 ColumnObject = Any


-def from_dataframe(df : DataFrameObject) -> pd.DataFrame:
+def from_dataframe(df : DataFrameObject,
+                   allow_copy : bool = True) -> pd.DataFrame:
     """
     Construct a pandas DataFrame from ``df`` if it supports ``__dataframe__``
     """
@@ -46,7 +47,7 @@ def from_dataframe(df : DataFrameObject) -> pd.DataFrame:
     if not hasattr(df, '__dataframe__'):
         raise ValueError("`df` does not support __dataframe__")

-    return _from_dataframe(df.__dataframe__())
+    return _from_dataframe(df.__dataframe__(allow_copy=allow_copy))


 def _from_dataframe(df : DataFrameObject) -> pd.DataFrame:
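
A minimal usage sketch of the new keyword (not part of the diff; `other_df` is an illustrative name). With the default ``allow_copy=True`` a strided column is silently copied; ``allow_copy=False`` makes the same call raise ``RuntimeError`` only when a copy would actually be needed:

    import pandas as pd

    other_df = pd.DataFrame({"a": [1, 2, 3]})   # gains __dataframe__ via the monkeypatch below
    pdf = from_dataframe(other_df)              # allow_copy=True by default
    pdf_strict = from_dataframe(other_df, allow_copy=False)  # fine here: the column is contiguous, no copy needed
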
@@ -63,19 +64,24 @@ def _from_dataframe(df : DataFrameObject) -> pd.DataFrame:
     # least for now, deal with non-numpy dtypes later).
     columns = dict()
     _k = _DtypeKind
+    _buffers = []  # hold on to buffers, keeps memory alive
     for name in df.column_names():
         col = df.get_column_by_name(name)
         if col.dtype[0] in (_k.INT, _k.UINT, _k.FLOAT, _k.BOOL):
             # Simple numerical or bool dtype, turn into numpy array
-            columns[name] = convert_column_to_ndarray(col)
+            columns[name], _buf = convert_column_to_ndarray(col)
         elif col.dtype[0] == _k.CATEGORICAL:
-            columns[name] = convert_categorical_column(col)
+            columns[name], _buf = convert_categorical_column(col)
         elif col.dtype[0] == _k.STRING:
-            columns[name] = convert_string_column(col)
+            columns[name], _buf = convert_string_column(col)
         else:
             raise NotImplementedError(f"Data type {col.dtype[0]} not handled yet")

-    return pd.DataFrame(columns)
+        _buffers.append(_buf)
+
+    df_new = pd.DataFrame(columns)
+    df_new._buffers = _buffers
+    return df_new


 class _DtypeKind(enum.IntEnum):
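
The ``_buffers`` list (and the ``_buffers`` attribute set on the result) keeps the producer's buffer objects referenced, because the arrays built from them are views rather than copies. A rough standalone illustration of the hazard, using plain NumPy rather than the protocol objects:

    import numpy as np

    owner = bytearray(np.arange(5, dtype="int64").tobytes())  # stands in for a producer's buffer
    view = np.frombuffer(owner, dtype="int64")                 # zero-copy view into `owner`
    # `view` is only valid while `owner` stays alive; holding a reference
    # (as `_buffers` does) is what prevents use-after-free.
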
@@ -100,7 +106,7 @@ def convert_column_to_ndarray(col : ColumnObject) -> np.ndarray:
                                   "sentinel values not handled yet")

     _buffer, _dtype = col.get_buffers()["data"]
-    return buffer_to_ndarray(_buffer, _dtype)
+    return buffer_to_ndarray(_buffer, _dtype), _buffer


 def buffer_to_ndarray(_buffer, _dtype) -> np.ndarray:
@@ -159,7 +165,7 @@ def convert_categorical_column(col : ColumnObject) -> pd.Series:
         raise NotImplementedError("Only categorical columns with sentinel "
                                   "value supported at the moment")

-    return series
+    return series, codes_buffer


 def convert_string_column(col : ColumnObject) -> np.ndarray:
@@ -218,10 +224,11 @@ def convert_string_column(col : ColumnObject) -> np.ndarray:
         str_list.append(s)

     # Convert the string list to a NumPy array
-    return np.asarray(str_list, dtype="object")
+    return np.asarray(str_list, dtype="object"), buffers


-def __dataframe__(cls, nan_as_null : bool = False) -> dict:
+def __dataframe__(cls, nan_as_null : bool = False,
+                  allow_copy : bool = True) -> dict:
     """
     The public method to attach to pd.DataFrame.

@@ -232,12 +239,21 @@ def __dataframe__(cls, nan_as_null : bool = False) -> dict:
     producer to overwrite null values in the data with ``NaN`` (or ``NaT``).
     This currently has no effect; once support for nullable extension
     dtypes is added, this value should be propagated to columns.
+
+    ``allow_copy`` is a keyword that defines whether or not the library is
+    allowed to make a copy of the data. For example, copying data would be
+    necessary if a library supports strided buffers, given that this protocol
+    specifies contiguous buffers.
+    Currently, if the flag is set to ``False`` and a copy is needed, a
+    ``RuntimeError`` will be raised.
     """
-    return _PandasDataFrame(cls, nan_as_null=nan_as_null)
+    return _PandasDataFrame(
+        cls, nan_as_null=nan_as_null, allow_copy=allow_copy)


 # Monkeypatch the Pandas DataFrame class to support the interchange protocol
 pd.DataFrame.__dataframe__ = __dataframe__
+pd.DataFrame._buffers = []


 # Implementation of interchange protocol
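
A quick sketch of how the monkeypatched entry point behaves (illustrative, not part of the diff; it assumes this module has been imported so the monkeypatch above is in effect). ``allow_copy`` is simply forwarded to the private wrapper and from there down to each column and buffer:

    import pandas as pd

    df = pd.DataFrame({"x": [1.0, 2.0]})
    xchg = df.__dataframe__(allow_copy=False)   # a _PandasDataFrame wrapper
    col = xchg.get_column_by_name("x")
    buf, dtype = col.get_buffers()["data"]      # zero-copy here, since "x" is contiguous
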
@@ -248,16 +264,18 @@ class _PandasBuffer:
     Data in the buffer is guaranteed to be contiguous in memory.
     """

-    def __init__(self, x : np.ndarray) -> None:
+    def __init__(self, x : np.ndarray, allow_copy : bool = True) -> None:
         """
         Handle only regular columns (= numpy arrays) for now.
         """
         if not x.strides == (x.dtype.itemsize,):
-            # Array is not contiguous - this is possible to get in Pandas,
-            # there was some discussion on whether to support it. Som extra
-            # complexity for libraries that don't support it (e.g. Arrow),
-            # but would help with numpy-based libraries like Pandas.
-            raise RuntimeError("Design needs fixing - non-contiguous buffer")
+            # The protocol does not support strided buffers, so a copy is
+            # necessary. If that's not allowed, we need to raise an exception.
+            if allow_copy:
+                x = x.copy()
+            else:
+                raise RuntimeError("Exports cannot be zero-copy in the case "
+                                   "of a non-contiguous buffer")

         # Store the numpy array in which the data resides as a private
         # attribute, so we can use it to retrieve the public attributes
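
The check above compares ``x.strides`` against ``(x.dtype.itemsize,)``. A short sketch of when that fails, which is also what the updated ``test_noncontiguous_columns`` below exercises:

    import numpy as np

    arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int64)
    col = arr[:, 0]             # column view over a row-major block
    print(col.strides)          # (24,): the row stride, not (8,), so not contiguous
    print(col.copy().strides)   # (8,): what allow_copy=True produces
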
@@ -313,7 +331,8 @@ class _PandasColumn:

     """

-    def __init__(self, column : pd.Series) -> None:
+    def __init__(self, column : pd.Series,
+                 allow_copy : bool = True) -> None:
         """
         Note: doesn't deal with extension arrays yet, just assume a regular
         Series/ndarray for now.
@@ -324,6 +343,7 @@ def __init__(self, column : pd.Series) -> None:

         # Store the column as a private attribute
         self._col = column
+        self._allow_copy = allow_copy

     @property
     def size(self) -> int:
@@ -560,11 +580,13 @@ def _get_data_buffer(self) -> Tuple[_PandasBuffer, Any]:  # Any is for self.dtyp
         """
         _k = _DtypeKind
         if self.dtype[0] in (_k.INT, _k.UINT, _k.FLOAT, _k.BOOL):
-            buffer = _PandasBuffer(self._col.to_numpy())
+            buffer = _PandasBuffer(
+                self._col.to_numpy(), allow_copy=self._allow_copy)
             dtype = self.dtype
         elif self.dtype[0] == _k.CATEGORICAL:
             codes = self._col.values.codes
-            buffer = _PandasBuffer(codes)
+            buffer = _PandasBuffer(
+                codes, allow_copy=self._allow_copy)
             dtype = self._dtype_from_pandasdtype(codes.dtype)
         elif self.dtype[0] == _k.STRING:
             # Marshal the strings from a NumPy object array into a byte array
@@ -677,7 +699,8 @@ class _PandasDataFrame:
     ``pd.DataFrame.__dataframe__`` as objects with the methods and
     attributes defined on this class.
     """
-    def __init__(self, df : pd.DataFrame, nan_as_null : bool = False) -> None:
+    def __init__(self, df : pd.DataFrame, nan_as_null : bool = False,
+                 allow_copy : bool = True) -> None:
         """
         Constructor - an instance of this (private) class is returned from
         `pd.DataFrame.__dataframe__`.
@@ -688,6 +711,7 @@ def __init__(self, df : pd.DataFrame, nan_as_null : bool = False) -> None:
         # This currently has no effect; once support for nullable extension
         # dtypes is added, this value should be propagated to columns.
         self._nan_as_null = nan_as_null
+        self._allow_copy = allow_copy

     @property
     def metadata(self):
@@ -708,13 +732,16 @@ def column_names(self) -> Iterable[str]:
         return self._df.columns.tolist()

     def get_column(self, i : int) -> _PandasColumn:
-        return _PandasColumn(self._df.iloc[:, i])
+        return _PandasColumn(
+            self._df.iloc[:, i], allow_copy=self._allow_copy)

     def get_column_by_name(self, name : str) -> _PandasColumn:
-        return _PandasColumn(self._df[name])
+        return _PandasColumn(
+            self._df[name], allow_copy=self._allow_copy)

     def get_columns(self) -> Iterable[_PandasColumn]:
-        return [_PandasColumn(self._df[name]) for name in self._df.columns]
+        return [_PandasColumn(self._df[name], allow_copy=self._allow_copy)
+                for name in self._df.columns]

     def select_columns(self, indices : Sequence[int]) -> '_PandasDataFrame':
         if not isinstance(indices, collections.Sequence):
@@ -752,13 +779,14 @@ def test_mixed_intfloat():


 def test_noncontiguous_columns():
-    # Currently raises: TBD whether it should work or not, see code comment
-    # where the RuntimeError is raised.
     arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
-    df = pd.DataFrame(arr)
-    assert df[0].to_numpy().strides == (24,)
-    pytest.raises(RuntimeError, from_dataframe, df)
-    #df2 = from_dataframe(df)
+    df = pd.DataFrame(arr, columns=['a', 'b', 'c'])
+    assert df['a'].to_numpy().strides == (24,)
+    df2 = from_dataframe(df)  # uses default of allow_copy=True
+    tm.assert_frame_equal(df, df2)
+
+    with pytest.raises(RuntimeError):
+        from_dataframe(df, allow_copy=False)


 def test_categorical_dtype():