Skip to content

Commit 0f3c5e9

Browse files
authored
PERF: DataFrame(pytorch_tensor) (#45007)
1 parent 245eddf commit 0f3c5e9

File tree

5 files changed

+41
-3
lines changed

5 files changed

+41
-3
lines changed

asv_bench/benchmarks/frame_ctor.py

+17
Original file line numberDiff line numberDiff line change
@@ -182,4 +182,21 @@ def time_frame_from_arrays_sparse(self):
182182
)
183183

184184

185+
class From3rdParty:
186+
# GH#44616
187+
188+
def setup(self):
189+
try:
190+
import torch
191+
except ImportError:
192+
raise NotImplementedError
193+
194+
row = 700000
195+
col = 64
196+
self.val_tensor = torch.randn(row, col)
197+
198+
def time_from_torch(self):
199+
DataFrame(self.val_tensor)
200+
201+
185202
from .pandas_vb_common import setup # noqa: F401 isort:skip

ci/deps/actions-38-db.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ dependencies:
3333
- pyarrow>=1.0.1
3434
- pymysql
3535
- pytables
36+
- pytorch
3637
- python-snappy
3738
- python-dateutil
3839
- pytz

doc/source/whatsnew/v1.4.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,7 @@ Performance improvements
602602
- Performance improvement in :func:`to_csv` when :class:`MultiIndex` contains a lot of unused levels (:issue:`37484`)
603603
- Performance improvement in :func:`read_csv` when ``index_col`` was set with a numeric column (:issue:`44158`)
604604
- Performance improvement in :func:`concat` (:issue:`43354`)
605+
- Performance improvement in constructing a :class:`DataFrame` from array-like objects such as a ``PyTorch`` tensor (:issue:`44616`)
605606
-
606607

607608
.. ---------------------------------------------------------------------------

pandas/core/frame.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -702,11 +702,16 @@ def __init__(
702702
# Data is list-like or an Iterable (consumed into an ndarray or a list below)
703703
elif is_list_like(data):
704704
if not isinstance(data, (abc.Sequence, ExtensionArray)):
705-
data = list(data)
705+
if hasattr(data, "__array__"):
706+
# GH#44616 big perf improvement for objects such as PyTorch tensors
707+
data = np.asarray(data)
708+
else:
709+
data = list(data)
706710
if len(data) > 0:
707711
if is_dataclass(data[0]):
708712
data = dataclasses_to_dicts(data)
709-
if treat_as_nested(data):
713+
if not isinstance(data, np.ndarray) and treat_as_nested(data):
714+
# exclude ndarray as we may have cast it a few lines above
710715
if columns is not None:
711716
# error: Argument 1 to "ensure_index" has incompatible type
712717
# "Collection[Any]"; expected "Union[Union[Union[ExtensionArray,

pandas/tests/test_downstream.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import subprocess
66
import sys
77

8-
import numpy as np # noqa:F401 needed in namespace for statsmodels
8+
import numpy as np
99
import pytest
1010

1111
import pandas.util._test_decorators as td
@@ -176,6 +176,20 @@ def test_pyarrow(df):
176176
tm.assert_frame_equal(result, df)
177177

178178

179+
def test_torch_frame_construction(using_array_manager):
180+
# GH#44616
181+
torch = import_module("torch")
182+
val_tensor = torch.randn(700, 64)
183+
184+
df = DataFrame(val_tensor)
185+
186+
if not using_array_manager:
187+
assert np.shares_memory(df, val_tensor)
188+
189+
ser = pd.Series(val_tensor[0])
190+
assert np.shares_memory(ser, val_tensor)
191+
192+
179193
def test_yaml_dump(df):
180194
# GH#42748
181195
yaml = import_module("yaml")

0 commit comments

Comments
 (0)