Skip to content

Commit fd557f2

Browse files
committed
fix pyarrow interchange
1 parent dc19148 commit fd557f2

File tree

5 files changed

+277
-46
lines changed

5 files changed

+277
-46
lines changed

doc/source/whatsnew/v2.2.2.rst

+3-1
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,16 @@ including other versions of pandas.
1414
Fixed regressions
1515
~~~~~~~~~~~~~~~~~
1616
- :meth:`DataFrame.__dataframe__` was producing incorrect data buffers when a column's type was a pandas nullable one with missing values (:issue:`56702`)
17+
- :meth:`DataFrame.__dataframe__` was producing incorrect data buffers when a column's type was a pyarrow nullable one with missing values (:issue:`57664`)
1718
-
1819

1920
.. ---------------------------------------------------------------------------
2021
.. _whatsnew_222.bug_fixes:
2122

2223
Bug fixes
2324
~~~~~~~~~
24-
-
25+
- :meth:`DataFrame.__dataframe__` was showing bytemask instead of bitmask for ``'string[pyarrow]'`` validity buffer (:issue:`57762`)
26+
- :meth:`DataFrame.__dataframe__` was showing a non-null validity buffer (instead of ``None``) for ``'string[pyarrow]'`` without missing values (:issue:`57761`)
2527

2628
.. ---------------------------------------------------------------------------
2729
.. _whatsnew_222.other:

pandas/core/interchange/buffer.py

+76
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
if TYPE_CHECKING:
1414
import numpy as np
15+
import pyarrow as pa
1516

1617

1718
class PandasBuffer(Buffer):
@@ -76,3 +77,78 @@ def __repr__(self) -> str:
7677
)
7778
+ ")"
7879
)
80+
81+
82+
class PandasBufferPyarrow(Buffer):
    """
    Interchange-protocol buffer backed by a pyarrow chunked array.

    Data in the buffer is guaranteed to be contiguous in memory.
    """

    def __init__(
        self,
        chunked_array: pa.ChunkedArray,
        *,
        is_validity: bool,
        allow_copy: bool = True,
    ) -> None:
        """
        Wrap a pyarrow ChunkedArray as a single contiguous buffer.

        Parameters
        ----------
        chunked_array : pa.ChunkedArray
            Source data; must be single-chunk, or copying must be allowed.
        is_validity : bool
            If True, expose the validity (null) bitmap buffer; otherwise
            expose the values buffer.
        allow_copy : bool, default True
            Whether combining multiple chunks (which copies) is permitted.

        Raises
        ------
        RuntimeError
            If the array has more than one chunk and ``allow_copy`` is False.
        """
        if len(chunked_array.chunks) != 1:
            # Multiple chunks are not contiguous in memory; combining
            # them requires a copy, which the caller must have allowed.
            if not allow_copy:
                raise RuntimeError(
                    "Found multi-chunk pyarrow array, but `allow_copy` is False"
                )
            contiguous = chunked_array.combine_chunks()
        else:
            contiguous = chunked_array.chunks[0]
        # Arrow layout: buffers()[0] is the validity bitmap,
        # buffers()[1] holds the values.
        buffer_index = 0 if is_validity else 1
        self._buffer = contiguous.buffers()[buffer_index]
        self._length = len(contiguous)
        # `__dlpack__` is only present on pyarrow>=15; keep None otherwise
        # so __dlpack__ below can raise a helpful error.
        self._dlpack = getattr(contiguous, "__dlpack__", None)
        self._is_validity = is_validity

    @property
    def bufsize(self) -> int:
        """
        Buffer size in bytes.
        """
        return self._buffer.size

    @property
    def ptr(self) -> int:
        """
        Pointer to start of the buffer as an integer.
        """
        return self._buffer.address

    def __dlpack__(self) -> Any:
        """
        Represent this structure as DLPack interface.
        """
        if self._dlpack is None:
            raise NotImplementedError(
                "pyarrow>=15.0.0 is required for DLPack support for pyarrow-backed buffers"
            )
        return self._dlpack()

    def __dlpack_device__(self) -> tuple[DlpackDeviceType, int | None]:
        """
        Device type and device ID for where the data in the buffer resides.
        """
        return (DlpackDeviceType.CPU, None)

    def __repr__(self) -> str:
        # Produces the same text as the original str(dict) formatting.
        details = {
            "bufsize": self.bufsize,
            "ptr": self.ptr,
            "device": "CPU",
        }
        return f"PandasBuffer[pyarrow]({details})"

pandas/core/interchange/column.py

+63-12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from __future__ import annotations
22

3-
from typing import Any
3+
from typing import (
4+
TYPE_CHECKING,
5+
Any,
6+
)
47

58
import numpy as np
69

@@ -9,15 +12,18 @@
912
from pandas.errors import NoBufferPresent
1013
from pandas.util._decorators import cache_readonly
1114

12-
from pandas.core.dtypes.dtypes import (
15+
from pandas.core.dtypes.dtypes import BaseMaskedDtype
16+
17+
import pandas as pd
18+
from pandas import (
1319
ArrowDtype,
14-
BaseMaskedDtype,
1520
DatetimeTZDtype,
1621
)
17-
18-
import pandas as pd
1922
from pandas.api.types import is_string_dtype
20-
from pandas.core.interchange.buffer import PandasBuffer
23+
from pandas.core.interchange.buffer import (
24+
PandasBuffer,
25+
PandasBufferPyarrow,
26+
)
2127
from pandas.core.interchange.dataframe_protocol import (
2228
Column,
2329
ColumnBuffers,
@@ -30,6 +36,9 @@
3036
dtype_to_arrow_c_fmt,
3137
)
3238

39+
if TYPE_CHECKING:
40+
from pandas.core.interchange.dataframe_protocol import Buffer
41+
3342
_NP_KINDS = {
3443
"i": DtypeKind.INT,
3544
"u": DtypeKind.UINT,
@@ -157,6 +166,14 @@ def _dtype_from_pandasdtype(self, dtype) -> tuple[DtypeKind, int, str, str]:
157166
else:
158167
byteorder = dtype.byteorder
159168

169+
if dtype == "bool[pyarrow]":
170+
return (
171+
kind,
172+
dtype.itemsize, # pyright: ignore[reportAttributeAccessIssue]
173+
ArrowCTypes.BOOL,
174+
byteorder,
175+
)
176+
160177
return kind, dtype.itemsize * 8, dtype_to_arrow_c_fmt(dtype), byteorder
161178

162179
@property
@@ -194,6 +211,13 @@ def describe_null(self):
194211
column_null_dtype = ColumnNullType.USE_BYTEMASK
195212
null_value = 1
196213
return column_null_dtype, null_value
214+
if isinstance(self._col.dtype, ArrowDtype):
215+
if all(
216+
chunk.buffers()[0] is None
217+
for chunk in self._col.array._pa_array.chunks # type: ignore[attr-defined]
218+
):
219+
return ColumnNullType.NON_NULLABLE, None
220+
return ColumnNullType.USE_BITMASK, 0
197221
kind = self.dtype[0]
198222
try:
199223
null, value = _NULL_DESCRIPTION[kind]
@@ -278,7 +302,7 @@ def get_buffers(self) -> ColumnBuffers:
278302

279303
def _get_data_buffer(
280304
self,
281-
) -> tuple[PandasBuffer, Any]: # Any is for self.dtype tuple
305+
) -> tuple[Buffer, tuple[DtypeKind, int, str, str]]:
282306
"""
283307
Return the buffer containing the data and the buffer's associated dtype.
284308
"""
@@ -289,7 +313,7 @@ def _get_data_buffer(
289313
np_arr = self._col.dt.tz_convert(None).to_numpy()
290314
else:
291315
np_arr = self._col.to_numpy()
292-
buffer = PandasBuffer(np_arr, allow_copy=self._allow_copy)
316+
buffer: Buffer = PandasBuffer(np_arr, allow_copy=self._allow_copy)
293317
dtype = (
294318
DtypeKind.INT,
295319
64,
@@ -302,15 +326,27 @@ def _get_data_buffer(
302326
DtypeKind.FLOAT,
303327
DtypeKind.BOOL,
304328
):
329+
dtype = self.dtype
305330
arr = self._col.array
331+
if isinstance(self._col.dtype, ArrowDtype):
332+
buffer = PandasBufferPyarrow(
333+
arr._pa_array, # type: ignore[attr-defined]
334+
is_validity=False,
335+
allow_copy=self._allow_copy,
336+
)
337+
if self.dtype[0] == DtypeKind.BOOL:
338+
dtype = (
339+
DtypeKind.BOOL,
340+
1,
341+
ArrowCTypes.BOOL,
342+
Endianness.NATIVE,
343+
)
344+
return buffer, dtype
306345
if isinstance(self._col.dtype, BaseMaskedDtype):
307346
np_arr = arr._data # type: ignore[attr-defined]
308-
elif isinstance(self._col.dtype, ArrowDtype):
309-
raise NotImplementedError("ArrowDtype not handled yet")
310347
else:
311348
np_arr = arr._ndarray # type: ignore[attr-defined]
312349
buffer = PandasBuffer(np_arr, allow_copy=self._allow_copy)
313-
dtype = self.dtype
314350
elif self.dtype[0] == DtypeKind.CATEGORICAL:
315351
codes = self._col.values._codes
316352
buffer = PandasBuffer(codes, allow_copy=self._allow_copy)
@@ -343,14 +379,29 @@ def _get_data_buffer(
343379

344380
return buffer, dtype
345381

346-
def _get_validity_buffer(self) -> tuple[PandasBuffer, Any]:
382+
def _get_validity_buffer(self) -> tuple[Buffer, Any] | None:
347383
"""
348384
Return the buffer containing the mask values indicating missing data and
349385
the buffer's associated dtype.
350386
Raises NoBufferPresent if null representation is not a bit or byte mask.
351387
"""
352388
null, invalid = self.describe_null
353389

390+
if isinstance(self._col.dtype, ArrowDtype):
391+
arr = self._col.array
392+
dtype = (DtypeKind.BOOL, 1, ArrowCTypes.BOOL, Endianness.NATIVE)
393+
if all(
394+
chunk.buffers()[0] is None
395+
for chunk in arr._pa_array.chunks # type: ignore[attr-defined]
396+
):
397+
return None
398+
buffer: Buffer = PandasBufferPyarrow(
399+
arr._pa_array, # type: ignore[attr-defined]
400+
is_validity=True,
401+
allow_copy=self._allow_copy,
402+
)
403+
return buffer, dtype
404+
354405
if isinstance(self._col.dtype, BaseMaskedDtype):
355406
mask = self._col.array._mask # type: ignore[attr-defined]
356407
buffer = PandasBuffer(mask)

pandas/core/interchange/from_dataframe.py

+16-13
Original file line numberDiff line numberDiff line change
@@ -298,13 +298,14 @@ def string_column_to_ndarray(col: Column) -> tuple[np.ndarray, Any]:
298298

299299
null_pos = None
300300
if null_kind in (ColumnNullType.USE_BITMASK, ColumnNullType.USE_BYTEMASK):
301-
assert buffers["validity"], "Validity buffers cannot be empty for masks"
302-
valid_buff, valid_dtype = buffers["validity"]
303-
null_pos = buffer_to_ndarray(
304-
valid_buff, valid_dtype, offset=col.offset, length=col.size()
305-
)
306-
if sentinel_val == 0:
307-
null_pos = ~null_pos
301+
validity = buffers["validity"]
302+
if validity is not None:
303+
valid_buff, valid_dtype = validity
304+
null_pos = buffer_to_ndarray(
305+
valid_buff, valid_dtype, offset=col.offset, length=col.size()
306+
)
307+
if sentinel_val == 0:
308+
null_pos = ~null_pos
308309

309310
# Assemble the strings from the code units
310311
str_list: list[None | float | str] = [None] * col.size()
@@ -516,19 +517,21 @@ def set_nulls(
516517
np.ndarray or pd.Series
517518
Data with the nulls being set.
518519
"""
520+
if validity is None:
521+
return data
519522
null_kind, sentinel_val = col.describe_null
520523
null_pos = None
521524

522525
if null_kind == ColumnNullType.USE_SENTINEL:
523526
null_pos = pd.Series(data) == sentinel_val
524527
elif null_kind in (ColumnNullType.USE_BITMASK, ColumnNullType.USE_BYTEMASK):
525-
assert validity, "Expected to have a validity buffer for the mask"
526528
valid_buff, valid_dtype = validity
527-
null_pos = buffer_to_ndarray(
528-
valid_buff, valid_dtype, offset=col.offset, length=col.size()
529-
)
530-
if sentinel_val == 0:
531-
null_pos = ~null_pos
529+
if valid_buff is not None:
530+
null_pos = buffer_to_ndarray(
531+
valid_buff, valid_dtype, offset=col.offset, length=col.size()
532+
)
533+
if sentinel_val == 0:
534+
null_pos = ~null_pos
532535
elif null_kind in (ColumnNullType.NON_NULLABLE, ColumnNullType.USE_NAN):
533536
pass
534537
else:

0 commit comments

Comments
 (0)