diff --git a/asv_bench/benchmarks/frame_ctor.py b/asv_bench/benchmarks/frame_ctor.py index 912971257490c..eace665ba0bac 100644 --- a/asv_bench/benchmarks/frame_ctor.py +++ b/asv_bench/benchmarks/frame_ctor.py @@ -182,4 +182,21 @@ def time_frame_from_arrays_sparse(self): ) +class From3rdParty: + # GH#44616 + + def setup(self): + try: + import torch + except ImportError: + raise NotImplementedError + + row = 700000 + col = 64 + self.val_tensor = torch.randn(row, col) + + def time_from_torch(self): + DataFrame(self.val_tensor) + + from .pandas_vb_common import setup # noqa: F401 isort:skip diff --git a/ci/deps/actions-38-db.yaml b/ci/deps/actions-38-db.yaml index 05b9eb8446af8..f445225a44dcb 100644 --- a/ci/deps/actions-38-db.yaml +++ b/ci/deps/actions-38-db.yaml @@ -33,6 +33,7 @@ dependencies: - pyarrow>=1.0.1 - pymysql - pytables + - pytorch - python-snappy - python-dateutil - pytz diff --git a/doc/source/whatsnew/v1.4.0.rst b/doc/source/whatsnew/v1.4.0.rst index cdc0bbb1dfd6a..85442a876b988 100644 --- a/doc/source/whatsnew/v1.4.0.rst +++ b/doc/source/whatsnew/v1.4.0.rst @@ -601,6 +601,7 @@ Performance improvements - Performance improvement in :func:`to_csv` when :class:`MultiIndex` contains a lot of unused levels (:issue:`37484`) - Performance improvement in :func:`read_csv` when ``index_col`` was set with a numeric column (:issue:`44158`) - Performance improvement in :func:`concat` (:issue:`43354`) +- Performance improvement in constructing a :class:`DataFrame` from array-like objects like a ``Pytorch`` tensor (:issue:`44616`) - .. --------------------------------------------------------------------------- diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 34078d552e0b3..2370b9ea16d94 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -702,11 +702,16 @@ def __init__( # For data is list-like, or Iterable (will consume into list) elif is_list_like(data): if not isinstance(data, (abc.Sequence, ExtensionArray)): - data = list(data) + if hasattr(data, "__array__"): + # GH#44616 big perf improvement for e.g. pytorch tensor + data = np.asarray(data) + else: + data = list(data) if len(data) > 0: if is_dataclass(data[0]): data = dataclasses_to_dicts(data) - if treat_as_nested(data): + if not isinstance(data, np.ndarray) and treat_as_nested(data): + # exclude ndarray as we may have cast it a few lines above if columns is not None: # error: Argument 1 to "ensure_index" has incompatible type # "Collection[Any]"; expected "Union[Union[Union[ExtensionArray, diff --git a/pandas/tests/test_downstream.py b/pandas/tests/test_downstream.py index 1972fbbe0f414..3880b9ecd9da7 100644 --- a/pandas/tests/test_downstream.py +++ b/pandas/tests/test_downstream.py @@ -5,7 +5,7 @@ import subprocess import sys -import numpy as np # noqa:F401 needed in namespace for statsmodels +import numpy as np import pytest import pandas.util._test_decorators as td @@ -176,6 +176,20 @@ def test_pyarrow(df): tm.assert_frame_equal(result, df) +def test_torch_frame_construction(using_array_manager): + # GH#44616 + torch = import_module("torch") + val_tensor = torch.randn(700, 64) + + df = DataFrame(val_tensor) + + if not using_array_manager: + assert np.shares_memory(df, val_tensor) + + ser = pd.Series(val_tensor[0]) + assert np.shares_memory(ser, val_tensor) + + def test_yaml_dump(df): # GH#42748 yaml = import_module("yaml")