Skip to content

Commit 7e8d492

Browse files
authored
Backport PR #57764 on branch 2.2.x (BUG: PyArrow dtypes were not supported in the interchange protocol) (#57947)
1 parent 78f7a02 commit 7e8d492

File tree

7 files changed

+326
-38
lines changed

7 files changed

+326
-38
lines changed

doc/source/whatsnew/v2.2.2.rst

+3-1
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,16 @@ including other versions of pandas.
1414
Fixed regressions
1515
~~~~~~~~~~~~~~~~~
1616
- :meth:`DataFrame.__dataframe__` was producing incorrect data buffers when a column's type was a pandas nullable dtype with missing values (:issue:`56702`)
17+
- :meth:`DataFrame.__dataframe__` was producing incorrect data buffers when a column's type was a pyarrow nullable dtype with missing values (:issue:`57664`)
1718
-
1819

1920
.. ---------------------------------------------------------------------------
2021
.. _whatsnew_222.bug_fixes:
2122

2223
Bug fixes
2324
~~~~~~~~~
24-
-
25+
- :meth:`DataFrame.__dataframe__` was showing bytemask instead of bitmask for ``'string[pyarrow]'`` validity buffer (:issue:`57762`)
26+
- :meth:`DataFrame.__dataframe__` was showing a non-null validity buffer (instead of ``None``) for ``'string[pyarrow]'`` without missing values (:issue:`57761`)
2527

2628
.. ---------------------------------------------------------------------------
2729
.. _whatsnew_222.other:

pandas/core/interchange/buffer.py

+58
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
if TYPE_CHECKING:
1414
import numpy as np
15+
import pyarrow as pa
1516

1617

1718
class PandasBuffer(Buffer):
@@ -76,3 +77,60 @@ def __repr__(self) -> str:
7677
)
7778
+ ")"
7879
)
80+
81+
82+
class PandasBufferPyarrow(Buffer):
83+
"""
84+
Data in the buffer is guaranteed to be contiguous in memory.
85+
"""
86+
87+
def __init__(
88+
self,
89+
buffer: pa.Buffer,
90+
*,
91+
length: int,
92+
) -> None:
93+
"""
94+
Handle pyarrow chunked arrays.
95+
"""
96+
self._buffer = buffer
97+
self._length = length
98+
99+
@property
100+
def bufsize(self) -> int:
101+
"""
102+
Buffer size in bytes.
103+
"""
104+
return self._buffer.size
105+
106+
@property
107+
def ptr(self) -> int:
108+
"""
109+
Pointer to start of the buffer as an integer.
110+
"""
111+
return self._buffer.address
112+
113+
def __dlpack__(self) -> Any:
114+
"""
115+
Represent this structure as DLPack interface.
116+
"""
117+
raise NotImplementedError()
118+
119+
def __dlpack_device__(self) -> tuple[DlpackDeviceType, int | None]:
120+
"""
121+
Device type and device ID for where the data in the buffer resides.
122+
"""
123+
return (DlpackDeviceType.CPU, None)
124+
125+
def __repr__(self) -> str:
126+
return (
127+
"PandasBuffer[pyarrow]("
128+
+ str(
129+
{
130+
"bufsize": self.bufsize,
131+
"ptr": self.ptr,
132+
"device": "CPU",
133+
}
134+
)
135+
+ ")"
136+
)

pandas/core/interchange/column.py

+56-10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from __future__ import annotations
22

3-
from typing import Any
3+
from typing import (
4+
TYPE_CHECKING,
5+
Any,
6+
)
47

58
import numpy as np
69

@@ -9,15 +12,18 @@
912
from pandas.errors import NoBufferPresent
1013
from pandas.util._decorators import cache_readonly
1114

12-
from pandas.core.dtypes.dtypes import (
15+
from pandas.core.dtypes.dtypes import BaseMaskedDtype
16+
17+
import pandas as pd
18+
from pandas import (
1319
ArrowDtype,
14-
BaseMaskedDtype,
1520
DatetimeTZDtype,
1621
)
17-
18-
import pandas as pd
1922
from pandas.api.types import is_string_dtype
20-
from pandas.core.interchange.buffer import PandasBuffer
23+
from pandas.core.interchange.buffer import (
24+
PandasBuffer,
25+
PandasBufferPyarrow,
26+
)
2127
from pandas.core.interchange.dataframe_protocol import (
2228
Column,
2329
ColumnBuffers,
@@ -30,6 +36,9 @@
3036
dtype_to_arrow_c_fmt,
3137
)
3238

39+
if TYPE_CHECKING:
40+
from pandas.core.interchange.dataframe_protocol import Buffer
41+
3342
_NP_KINDS = {
3443
"i": DtypeKind.INT,
3544
"u": DtypeKind.UINT,
@@ -157,6 +166,16 @@ def _dtype_from_pandasdtype(self, dtype) -> tuple[DtypeKind, int, str, str]:
157166
else:
158167
byteorder = dtype.byteorder
159168

169+
if dtype == "bool[pyarrow]":
170+
# return early to avoid the `* 8` below, as this is a bitmask
171+
# rather than a bytemask
172+
return (
173+
kind,
174+
dtype.itemsize, # pyright: ignore[reportGeneralTypeIssues]
175+
ArrowCTypes.BOOL,
176+
byteorder,
177+
)
178+
160179
return kind, dtype.itemsize * 8, dtype_to_arrow_c_fmt(dtype), byteorder
161180

162181
@property
@@ -194,6 +213,12 @@ def describe_null(self):
194213
column_null_dtype = ColumnNullType.USE_BYTEMASK
195214
null_value = 1
196215
return column_null_dtype, null_value
216+
if isinstance(self._col.dtype, ArrowDtype):
217+
# We already rechunk (if necessary / allowed) upon initialization, so this
218+
# is already single-chunk by the time we get here.
219+
if self._col.array._pa_array.chunks[0].buffers()[0] is None: # type: ignore[attr-defined]
220+
return ColumnNullType.NON_NULLABLE, None
221+
return ColumnNullType.USE_BITMASK, 0
197222
kind = self.dtype[0]
198223
try:
199224
null, value = _NULL_DESCRIPTION[kind]
@@ -278,10 +303,11 @@ def get_buffers(self) -> ColumnBuffers:
278303

279304
def _get_data_buffer(
280305
self,
281-
) -> tuple[PandasBuffer, Any]: # Any is for self.dtype tuple
306+
) -> tuple[Buffer, tuple[DtypeKind, int, str, str]]:
282307
"""
283308
Return the buffer containing the data and the buffer's associated dtype.
284309
"""
310+
buffer: Buffer
285311
if self.dtype[0] in (
286312
DtypeKind.INT,
287313
DtypeKind.UINT,
@@ -291,18 +317,25 @@ def _get_data_buffer(
291317
):
292318
# self.dtype[2] is an ArrowCTypes.TIMESTAMP where the tz will make
293319
# it longer than 4 characters
320+
dtype = self.dtype
294321
if self.dtype[0] == DtypeKind.DATETIME and len(self.dtype[2]) > 4:
295322
np_arr = self._col.dt.tz_convert(None).to_numpy()
296323
else:
297324
arr = self._col.array
298325
if isinstance(self._col.dtype, BaseMaskedDtype):
299326
np_arr = arr._data # type: ignore[attr-defined]
300327
elif isinstance(self._col.dtype, ArrowDtype):
301-
raise NotImplementedError("ArrowDtype not handled yet")
328+
# We already rechunk (if necessary / allowed) upon initialization,
329+
# so this is already single-chunk by the time we get here.
330+
arr = arr._pa_array.chunks[0] # type: ignore[attr-defined]
331+
buffer = PandasBufferPyarrow(
332+
arr.buffers()[1], # type: ignore[attr-defined]
333+
length=len(arr),
334+
)
335+
return buffer, dtype
302336
else:
303337
np_arr = arr._ndarray # type: ignore[attr-defined]
304338
buffer = PandasBuffer(np_arr, allow_copy=self._allow_copy)
305-
dtype = self.dtype
306339
elif self.dtype[0] == DtypeKind.CATEGORICAL:
307340
codes = self._col.values._codes
308341
buffer = PandasBuffer(codes, allow_copy=self._allow_copy)
@@ -330,13 +363,26 @@ def _get_data_buffer(
330363

331364
return buffer, dtype
332365

333-
def _get_validity_buffer(self) -> tuple[PandasBuffer, Any]:
366+
def _get_validity_buffer(self) -> tuple[Buffer, Any] | None:
334367
"""
335368
Return the buffer containing the mask values indicating missing data and
336369
the buffer's associated dtype.
337370
Raises NoBufferPresent if null representation is not a bit or byte mask.
338371
"""
339372
null, invalid = self.describe_null
373+
buffer: Buffer
374+
if isinstance(self._col.dtype, ArrowDtype):
375+
# We already rechunk (if necessary / allowed) upon initialization, so this
376+
# is already single-chunk by the time we get here.
377+
arr = self._col.array._pa_array.chunks[0] # type: ignore[attr-defined]
378+
dtype = (DtypeKind.BOOL, 1, ArrowCTypes.BOOL, Endianness.NATIVE)
379+
if arr.buffers()[0] is None:
380+
return None
381+
buffer = PandasBufferPyarrow(
382+
arr.buffers()[0],
383+
length=len(arr),
384+
)
385+
return buffer, dtype
340386

341387
if isinstance(self._col.dtype, BaseMaskedDtype):
342388
mask = self._col.array._mask # type: ignore[attr-defined]

pandas/core/interchange/dataframe.py

+5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from pandas.core.interchange.column import PandasColumn
77
from pandas.core.interchange.dataframe_protocol import DataFrame as DataFrameXchg
8+
from pandas.core.interchange.utils import maybe_rechunk
89

910
if TYPE_CHECKING:
1011
from collections.abc import (
@@ -34,6 +35,10 @@ def __init__(self, df: DataFrame, allow_copy: bool = True) -> None:
3435
"""
3536
self._df = df.rename(columns=str, copy=False)
3637
self._allow_copy = allow_copy
38+
for i, _col in enumerate(self._df.columns):
39+
rechunked = maybe_rechunk(self._df.iloc[:, i], allow_copy=allow_copy)
40+
if rechunked is not None:
41+
self._df.isetitem(i, rechunked)
3742

3843
def __dataframe__(
3944
self, nan_as_null: bool = False, allow_copy: bool = True

pandas/core/interchange/from_dataframe.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -295,13 +295,14 @@ def string_column_to_ndarray(col: Column) -> tuple[np.ndarray, Any]:
295295

296296
null_pos = None
297297
if null_kind in (ColumnNullType.USE_BITMASK, ColumnNullType.USE_BYTEMASK):
298-
assert buffers["validity"], "Validity buffers cannot be empty for masks"
299-
valid_buff, valid_dtype = buffers["validity"]
300-
null_pos = buffer_to_ndarray(
301-
valid_buff, valid_dtype, offset=col.offset, length=col.size()
302-
)
303-
if sentinel_val == 0:
304-
null_pos = ~null_pos
298+
validity = buffers["validity"]
299+
if validity is not None:
300+
valid_buff, valid_dtype = validity
301+
null_pos = buffer_to_ndarray(
302+
valid_buff, valid_dtype, offset=col.offset, length=col.size()
303+
)
304+
if sentinel_val == 0:
305+
null_pos = ~null_pos
305306

306307
# Assemble the strings from the code units
307308
str_list: list[None | float | str] = [None] * col.size()
@@ -486,6 +487,8 @@ def set_nulls(
486487
np.ndarray or pd.Series
487488
Data with the nulls being set.
488489
"""
490+
if validity is None:
491+
return data
489492
null_kind, sentinel_val = col.describe_null
490493
null_pos = None
491494

pandas/core/interchange/utils.py

+28
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
DatetimeTZDtype,
1717
)
1818

19+
import pandas as pd
20+
1921
if typing.TYPE_CHECKING:
2022
from pandas._typing import DtypeObj
2123

@@ -145,3 +147,29 @@ def dtype_to_arrow_c_fmt(dtype: DtypeObj) -> str:
145147
raise NotImplementedError(
146148
f"Conversion of {dtype} to Arrow C format string is not implemented."
147149
)
150+
151+
152+
def maybe_rechunk(series: pd.Series, *, allow_copy: bool) -> pd.Series | None:
153+
"""
154+
Rechunk a multi-chunk pyarrow array into a single-chunk array, if necessary.
155+
156+
- Returns `None` if the input series is not backed by a multi-chunk pyarrow array
157+
(and so doesn't need rechunking)
158+
- Returns a single-chunk-backed-Series if the input is backed by a multi-chunk
159+
pyarrow array and `allow_copy` is `True`.
160+
- Raises a `RuntimeError` if `allow_copy` is `False` and the input is
161+
backed by a multi-chunk pyarrow array.
162+
"""
163+
if not isinstance(series.dtype, pd.ArrowDtype):
164+
return None
165+
chunked_array = series.array._pa_array # type: ignore[attr-defined]
166+
if len(chunked_array.chunks) == 1:
167+
return None
168+
if not allow_copy:
169+
raise RuntimeError(
170+
"Found multi-chunk pyarrow array, but `allow_copy` is False. "
171+
"Please rechunk the array before calling this function, or set "
172+
"`allow_copy=True`."
173+
)
174+
arr = chunked_array.combine_chunks()
175+
return pd.Series(arr, dtype=series.dtype, name=series.name, index=series.index)

0 commit comments

Comments
 (0)