
Commit bcb5024

rgommers and steff456 authored
Add allow_copy flag to interchange protocol (#51)
Add a flag to raise an exception if the export cannot be zero-copy (e.g. for pandas this can happen due to the block manager, where rows are contiguous in memory but columns are not).

- Add the `allow_copy` flag to the DataFrame class.
- Propagate the flag to the buffer and raise a `RuntimeError` when needed.
- Fix `test_noncontiguous_columns`.
- Update the requirements doc.

Co-authored-by: Stephannie Jimenez <[email protected]>
1 parent d9419b1 commit bcb5024
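
As a rough consumer-side illustration of the new behaviour (a sketch only, assuming the repository layout makes `protocol.pandas_implementation` importable; the DataFrame mirrors the one used in `test_noncontiguous_columns` below):

import numpy as np
import pandas as pd
import pytest

from protocol.pandas_implementation import from_dataframe

# A DataFrame backed by a single 2-D block, so each column is a strided view.
df = pd.DataFrame(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
                  columns=['a', 'b', 'c'])

df2 = from_dataframe(df)                    # default allow_copy=True: copies where needed
with pytest.raises(RuntimeError):
    from_dataframe(df, allow_copy=False)    # zero-copy export is impossible here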

2 files changed, +66 -31 lines

protocol/dataframe_protocol.py

Lines changed: 8 additions & 1 deletion
@@ -354,16 +354,23 @@ class DataFrame:
     ``__dataframe__`` method of a public data frame class in a library adhering
     to the dataframe interchange protocol specification.
     """
-    def __dataframe__(self, nan_as_null : bool = False) -> dict:
+    def __dataframe__(self, nan_as_null : bool = False,
+                      allow_copy : bool = True) -> dict:
         """
         Produces a dictionary object following the dataframe protocol specification.

         ``nan_as_null`` is a keyword intended for the consumer to tell the
         producer to overwrite null values in the data with ``NaN`` (or ``NaT``).
         It is intended for cases where the consumer does not support the bit
         mask or byte mask that is the producer's native representation.
+
+        ``allow_copy`` is a keyword that defines whether or not the library is
+        allowed to make a copy of the data. For example, copying data would be
+        necessary if a library supports strided buffers, given that this protocol
+        specifies contiguous buffers.
         """
         self._nan_as_null = nan_as_null
+        self._allow_zero_zopy = allow_copy
         return {
             "dataframe": self,  # DataFrame object adhering to the protocol
             "version": 0        # Version number of the protocol

protocol/pandas_implementation.py

Lines changed: 58 additions & 30 deletions
@@ -35,7 +35,8 @@
 ColumnObject = Any


-def from_dataframe(df : DataFrameObject) -> pd.DataFrame:
+def from_dataframe(df : DataFrameObject,
+                   allow_copy : bool = True) -> pd.DataFrame:
     """
     Construct a pandas DataFrame from ``df`` if it supports ``__dataframe__``
     """
@@ -46,7 +47,7 @@ def from_dataframe(df : DataFrameObject) -> pd.DataFrame:
     if not hasattr(df, '__dataframe__'):
         raise ValueError("`df` does not support __dataframe__")

-    return _from_dataframe(df.__dataframe__())
+    return _from_dataframe(df.__dataframe__(allow_copy=allow_copy))


 def _from_dataframe(df : DataFrameObject) -> pd.DataFrame:
@@ -63,19 +64,24 @@ def _from_dataframe(df : DataFrameObject) -> pd.DataFrame:
     # least for now, deal with non-numpy dtypes later).
     columns = dict()
     _k = _DtypeKind
+    _buffers = []  # hold on to buffers, keeps memory alive
     for name in df.column_names():
         col = df.get_column_by_name(name)
         if col.dtype[0] in (_k.INT, _k.UINT, _k.FLOAT, _k.BOOL):
             # Simple numerical or bool dtype, turn into numpy array
-            columns[name] = convert_column_to_ndarray(col)
+            columns[name], _buf = convert_column_to_ndarray(col)
         elif col.dtype[0] == _k.CATEGORICAL:
-            columns[name] = convert_categorical_column(col)
+            columns[name], _buf = convert_categorical_column(col)
         elif col.dtype[0] == _k.STRING:
-            columns[name] = convert_string_column(col)
+            columns[name], _buf = convert_string_column(col)
         else:
             raise NotImplementedError(f"Data type {col.dtype[0]} not handled yet")

-    return pd.DataFrame(columns)
+        _buffers.append(_buf)
+
+    df_new = pd.DataFrame(columns)
+    df_new._buffers = _buffers
+    return df_new


 class _DtypeKind(enum.IntEnum):
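
The `_buffers` list added in this hunk exists because the consumer builds ndarrays directly on top of the producer's memory. A sketch of the underlying issue (illustrative names only; it mimics the ctypes-based reconstruction the consumer code performs):

import ctypes
import numpy as np

producer_data = np.arange(5, dtype=np.int64)   # stands in for a producer buffer
ptr = producer_data.ctypes.data                # raw address, carries no ownership

# Rebuild an array from the raw pointer, as the consumer does.
view = np.ctypeslib.as_array(
    ctypes.cast(ptr, ctypes.POINTER(ctypes.c_int64)), shape=(5,))

# `view` does not own its memory, so the producer-side buffer must stay
# referenced for as long as `view` is in use -- hence `df_new._buffers`.
assert not view.flags.owndata
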
@@ -100,7 +106,7 @@ def convert_column_to_ndarray(col : ColumnObject) -> np.ndarray:
                                   "sentinel values not handled yet")

     _buffer, _dtype = col.get_buffers()["data"]
-    return buffer_to_ndarray(_buffer, _dtype)
+    return buffer_to_ndarray(_buffer, _dtype), _buffer


 def buffer_to_ndarray(_buffer, _dtype) -> np.ndarray:
@@ -159,7 +165,7 @@ def convert_categorical_column(col : ColumnObject) -> pd.Series:
         raise NotImplementedError("Only categorical columns with sentinel "
                                   "value supported at the moment")

-    return series
+    return series, codes_buffer


 def convert_string_column(col : ColumnObject) -> np.ndarray:
@@ -218,10 +224,11 @@ def convert_string_column(col : ColumnObject) -> np.ndarray:
         str_list.append(s)

     # Convert the string list to a NumPy array
-    return np.asarray(str_list, dtype="object")
+    return np.asarray(str_list, dtype="object"), buffers


-def __dataframe__(cls, nan_as_null : bool = False) -> dict:
+def __dataframe__(cls, nan_as_null : bool = False,
+                  allow_copy : bool = True) -> dict:
     """
     The public method to attach to pd.DataFrame.

@@ -232,12 +239,21 @@ def __dataframe__(cls, nan_as_null : bool = False) -> dict:
     producer to overwrite null values in the data with ``NaN`` (or ``NaT``).
     This currently has no effect; once support for nullable extension
     dtypes is added, this value should be propagated to columns.
+
+    ``allow_copy`` is a keyword that defines whether or not the library is
+    allowed to make a copy of the data. For example, copying data would be
+    necessary if a library supports strided buffers, given that this protocol
+    specifies contiguous buffers.
+    Currently, if the flag is set to ``False`` and a copy is needed, a
+    ``RuntimeError`` will be raised.
     """
-    return _PandasDataFrame(cls, nan_as_null=nan_as_null)
+    return _PandasDataFrame(
+        cls, nan_as_null=nan_as_null, allow_copy=allow_copy)


 # Monkeypatch the Pandas DataFrame class to support the interchange protocol
 pd.DataFrame.__dataframe__ = __dataframe__
+pd.DataFrame._buffers = []


 # Implementation of interchange protocol
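
Note that with this design the flag only takes effect once buffers are actually exported: calling `__dataframe__(allow_copy=False)` succeeds, and the `RuntimeError` surfaces when a column's data buffer is built. A sketch (same import assumption as above; importing the module applies the monkeypatch):

import numpy as np
import pandas as pd
import pytest

import protocol.pandas_implementation  # noqa: F401  (monkeypatches pd.DataFrame)

df = pd.DataFrame(np.arange(9).reshape(3, 3), columns=['a', 'b', 'c'])

xdf = df.__dataframe__(allow_copy=False)   # fine: only stores the flag
col = xdf.get_column_by_name('a')          # fine: wraps the column, no export yet

with pytest.raises(RuntimeError):
    col.get_buffers()                      # building the data buffer hits the check
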
@@ -248,16 +264,18 @@ class _PandasBuffer:
     Data in the buffer is guaranteed to be contiguous in memory.
     """

-    def __init__(self, x : np.ndarray) -> None:
+    def __init__(self, x : np.ndarray, allow_copy : bool = True) -> None:
         """
         Handle only regular columns (= numpy arrays) for now.
         """
         if not x.strides == (x.dtype.itemsize,):
-            # Array is not contiguous - this is possible to get in Pandas,
-            # there was some discussion on whether to support it. Som extra
-            # complexity for libraries that don't support it (e.g. Arrow),
-            # but would help with numpy-based libraries like Pandas.
-            raise RuntimeError("Design needs fixing - non-contiguous buffer")
+            # The protocol does not support strided buffers, so a copy is
+            # necessary. If that's not allowed, we need to raise an exception.
+            if allow_copy:
+                x = x.copy()
+            else:
+                raise RuntimeError("Exports cannot be zero-copy in the case "
+                                   "of a non-contiguous buffer")

         # Store the numpy array in which the data resides as a private
         # attribute, so we can use it to retrieve the public attributes
@@ -313,7 +331,8 @@ class _PandasColumn:

     """

-    def __init__(self, column : pd.Series) -> None:
+    def __init__(self, column : pd.Series,
+                 allow_copy : bool = True) -> None:
         """
         Note: doesn't deal with extension arrays yet, just assume a regular
         Series/ndarray for now.
@@ -324,6 +343,7 @@ def __init__(self, column : pd.Series) -> None:

         # Store the column as a private attribute
         self._col = column
+        self._allow_copy = allow_copy

     @property
     def size(self) -> int:
@@ -560,11 +580,13 @@ def _get_data_buffer(self) -> Tuple[_PandasBuffer, Any]:  # Any is for self.dtyp
         """
         _k = _DtypeKind
         if self.dtype[0] in (_k.INT, _k.UINT, _k.FLOAT, _k.BOOL):
-            buffer = _PandasBuffer(self._col.to_numpy())
+            buffer = _PandasBuffer(
+                self._col.to_numpy(), allow_copy=self._allow_copy)
             dtype = self.dtype
         elif self.dtype[0] == _k.CATEGORICAL:
             codes = self._col.values.codes
-            buffer = _PandasBuffer(codes)
+            buffer = _PandasBuffer(
+                codes, allow_copy=self._allow_copy)
             dtype = self._dtype_from_pandasdtype(codes.dtype)
         elif self.dtype[0] == _k.STRING:
             # Marshal the strings from a NumPy object array into a byte array
@@ -677,7 +699,8 @@ class _PandasDataFrame:
     ``pd.DataFrame.__dataframe__`` as objects with the methods and
     attributes defined on this class.
     """
-    def __init__(self, df : pd.DataFrame, nan_as_null : bool = False) -> None:
+    def __init__(self, df : pd.DataFrame, nan_as_null : bool = False,
+                 allow_copy : bool = True) -> None:
         """
         Constructor - an instance of this (private) class is returned from
         `pd.DataFrame.__dataframe__`.
@@ -688,6 +711,7 @@ def __init__(self, df : pd.DataFrame, nan_as_null : bool = False) -> None:
         # This currently has no effect; once support for nullable extension
         # dtypes is added, this value should be propagated to columns.
         self._nan_as_null = nan_as_null
+        self._allow_copy = allow_copy

     @property
     def metadata(self):
@@ -708,13 +732,16 @@ def column_names(self) -> Iterable[str]:
         return self._df.columns.tolist()

     def get_column(self, i: int) -> _PandasColumn:
-        return _PandasColumn(self._df.iloc[:, i])
+        return _PandasColumn(
+            self._df.iloc[:, i], allow_copy=self._allow_copy)

     def get_column_by_name(self, name: str) -> _PandasColumn:
-        return _PandasColumn(self._df[name])
+        return _PandasColumn(
+            self._df[name], allow_copy=self._allow_copy)

     def get_columns(self) -> Iterable[_PandasColumn]:
-        return [_PandasColumn(self._df[name]) for name in self._df.columns]
+        return [_PandasColumn(self._df[name], allow_copy=self._allow_copy)
+                for name in self._df.columns]

     def select_columns(self, indices: Sequence[int]) -> '_PandasDataFrame':
         if not isinstance(indices, collections.Sequence):
@@ -752,13 +779,14 @@ def test_mixed_intfloat():


 def test_noncontiguous_columns():
-    # Currently raises: TBD whether it should work or not, see code comment
-    # where the RuntimeError is raised.
     arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
-    df = pd.DataFrame(arr)
-    assert df[0].to_numpy().strides == (24,)
-    pytest.raises(RuntimeError, from_dataframe, df)
-    #df2 = from_dataframe(df)
+    df = pd.DataFrame(arr, columns=['a', 'b', 'c'])
+    assert df['a'].to_numpy().strides == (24,)
+    df2 = from_dataframe(df)  # uses default of allow_copy=True
+    tm.assert_frame_equal(df, df2)
+
+    with pytest.raises(RuntimeError):
+        from_dataframe(df, allow_copy=False)


 def test_categorical_dtype():
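
The reworked test exercises both paths: with the default `allow_copy=True` the non-contiguous frame round-trips (the copy makes it work), and with `allow_copy=False` the export is asserted to raise. A quick way to run just this test (a sketch; assumes pytest is installed and the repository root is the working directory):

import pytest

pytest.main(["-k", "noncontiguous", "protocol/pandas_implementation.py"])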
