File tree 5 files changed +41
-3
lines changed
5 files changed +41
-3
lines changed Original file line number Diff line number Diff line change @@ -182,4 +182,21 @@ def time_frame_from_arrays_sparse(self):
182
182
)
183
183
184
184
185
+ class From3rdParty :
186
+ # GH#44616
187
+
188
+ def setup (self ):
189
+ try :
190
+ import torch
191
+ except ImportError :
192
+ raise NotImplementedError
193
+
194
+ row = 700000
195
+ col = 64
196
+ self .val_tensor = torch .randn (row , col )
197
+
198
+ def time_from_torch (self ):
199
+ DataFrame (self .val_tensor )
200
+
201
+
185
202
from .pandas_vb_common import setup # noqa: F401 isort:skip
Original file line number Diff line number Diff line change @@ -33,6 +33,7 @@ dependencies:
33
33
- pyarrow>=1.0.1
34
34
- pymysql
35
35
- pytables
36
+ - pytorch
36
37
- python-snappy
37
38
- python-dateutil
38
39
- pytz
Original file line number Diff line number Diff line change @@ -602,6 +602,7 @@ Performance improvements
602
602
- Performance improvement in :func:`to_csv` when :class:`MultiIndex` contains a lot of unused levels (:issue:`37484`)
603
603
- Performance improvement in :func:`read_csv` when ``index_col`` was set with a numeric column (:issue:`44158`)
604
604
- Performance improvement in :func:`concat` (:issue:`43354`)
605
+ - Performance improvement in constructing a :class:`DataFrame` from array-like objects like a ``PyTorch`` tensor (:issue:`44616`)
605
606
-
606
607
607
608
.. ---------------------------------------------------------------------------
Original file line number Diff line number Diff line change @@ -702,11 +702,16 @@ def __init__(
702
702
# For data is list-like, or Iterable (will consume into list)
703
703
elif is_list_like (data ):
704
704
if not isinstance (data , (abc .Sequence , ExtensionArray )):
705
- data = list (data )
705
+ if hasattr (data , "__array__" ):
706
+ # GH#44616 big perf improvement for e.g. pytorch tensor
707
+ data = np .asarray (data )
708
+ else :
709
+ data = list (data )
706
710
if len (data ) > 0 :
707
711
if is_dataclass (data [0 ]):
708
712
data = dataclasses_to_dicts (data )
709
- if treat_as_nested (data ):
713
+ if not isinstance (data , np .ndarray ) and treat_as_nested (data ):
714
+ # exclude ndarray as we may have cast it a few lines above
710
715
if columns is not None :
711
716
# error: Argument 1 to "ensure_index" has incompatible type
712
717
# "Collection[Any]"; expected "Union[Union[Union[ExtensionArray,
Original file line number Diff line number Diff line change 5
5
import subprocess
6
6
import sys
7
7
8
- import numpy as np # noqa:F401 needed in namespace for statsmodels
8
+ import numpy as np
9
9
import pytest
10
10
11
11
import pandas .util ._test_decorators as td
@@ -176,6 +176,20 @@ def test_pyarrow(df):
176
176
tm .assert_frame_equal (result , df )
177
177
178
178
179
+ def test_torch_frame_construction (using_array_manager ):
180
+ # GH#44616
181
+ torch = import_module ("torch" )
182
+ val_tensor = torch .randn (700 , 64 )
183
+
184
+ df = DataFrame (val_tensor )
185
+
186
+ if not using_array_manager :
187
+ assert np .shares_memory (df , val_tensor )
188
+
189
+ ser = pd .Series (val_tensor [0 ])
190
+ assert np .shares_memory (ser , val_tensor )
191
+
192
+
179
193
def test_yaml_dump (df ):
180
194
# GH#42748
181
195
yaml = import_module ("yaml" )
You can’t perform that action at this time.
0 commit comments