1
1
from __future__ import annotations
2
2
3
- from typing import Any
3
+ from typing import (
4
+ TYPE_CHECKING ,
5
+ Any ,
6
+ )
4
7
5
8
import numpy as np
6
9
9
12
from pandas .errors import NoBufferPresent
10
13
from pandas .util ._decorators import cache_readonly
11
14
12
- from pandas .core .dtypes .dtypes import (
15
+ from pandas .core .dtypes .dtypes import BaseMaskedDtype
16
+
17
+ import pandas as pd
18
+ from pandas import (
13
19
ArrowDtype ,
14
- BaseMaskedDtype ,
15
20
DatetimeTZDtype ,
16
21
)
17
-
18
- import pandas as pd
19
22
from pandas .api .types import is_string_dtype
20
- from pandas .core .interchange .buffer import PandasBuffer
23
+ from pandas .core .interchange .buffer import (
24
+ PandasBuffer ,
25
+ PandasBufferPyarrow ,
26
+ )
21
27
from pandas .core .interchange .dataframe_protocol import (
22
28
Column ,
23
29
ColumnBuffers ,
30
36
dtype_to_arrow_c_fmt ,
31
37
)
32
38
39
+ if TYPE_CHECKING :
40
+ from pandas .core .interchange .dataframe_protocol import Buffer
41
+
33
42
_NP_KINDS = {
34
43
"i" : DtypeKind .INT ,
35
44
"u" : DtypeKind .UINT ,
@@ -157,6 +166,16 @@ def _dtype_from_pandasdtype(self, dtype) -> tuple[DtypeKind, int, str, str]:
157
166
else :
158
167
byteorder = dtype .byteorder
159
168
169
+ if dtype == "bool[pyarrow]" :
170
+ # return early to avoid the `* 8` below, as this is a bitmask
171
+ # rather than a bytemask
172
+ return (
173
+ kind ,
174
+ dtype .itemsize , # pyright: ignore[reportAttributeAccessIssue]
175
+ ArrowCTypes .BOOL ,
176
+ byteorder ,
177
+ )
178
+
160
179
return kind , dtype .itemsize * 8 , dtype_to_arrow_c_fmt (dtype ), byteorder
161
180
162
181
@property
@@ -194,6 +213,12 @@ def describe_null(self):
194
213
column_null_dtype = ColumnNullType .USE_BYTEMASK
195
214
null_value = 1
196
215
return column_null_dtype , null_value
216
+ if isinstance (self ._col .dtype , ArrowDtype ):
217
+ # We already rechunk (if necessary / allowed) upon initialization, so this
218
+ # is already single-chunk by the time we get here.
219
+ if self ._col .array ._pa_array .chunks [0 ].buffers ()[0 ] is None : # type: ignore[attr-defined]
220
+ return ColumnNullType .NON_NULLABLE , None
221
+ return ColumnNullType .USE_BITMASK , 0
197
222
kind = self .dtype [0 ]
198
223
try :
199
224
null , value = _NULL_DESCRIPTION [kind ]
@@ -278,10 +303,11 @@ def get_buffers(self) -> ColumnBuffers:
278
303
279
304
def _get_data_buffer (
280
305
self ,
281
- ) -> tuple [PandasBuffer , Any ]: # Any is for self.dtype tuple
306
+ ) -> tuple [Buffer , tuple [ DtypeKind , int , str , str ]]:
282
307
"""
283
308
Return the buffer containing the data and the buffer's associated dtype.
284
309
"""
310
+ buffer : Buffer
285
311
if self .dtype [0 ] == DtypeKind .DATETIME :
286
312
# self.dtype[2] is an ArrowCTypes.TIMESTAMP where the tz will make
287
313
# it longer than 4 characters
@@ -302,15 +328,22 @@ def _get_data_buffer(
302
328
DtypeKind .FLOAT ,
303
329
DtypeKind .BOOL ,
304
330
):
331
+ dtype = self .dtype
305
332
arr = self ._col .array
333
+ if isinstance (self ._col .dtype , ArrowDtype ):
334
+ # We already rechunk (if necessary / allowed) upon initialization, so
335
+ # this is already single-chunk by the time we get here.
336
+ arr = arr ._pa_array .chunks [0 ] # type: ignore[attr-defined]
337
+ buffer = PandasBufferPyarrow (
338
+ arr .buffers ()[1 ], # type: ignore[attr-defined]
339
+ length = len (arr ),
340
+ )
341
+ return buffer , dtype
306
342
if isinstance (self ._col .dtype , BaseMaskedDtype ):
307
343
np_arr = arr ._data # type: ignore[attr-defined]
308
- elif isinstance (self ._col .dtype , ArrowDtype ):
309
- raise NotImplementedError ("ArrowDtype not handled yet" )
310
344
else :
311
345
np_arr = arr ._ndarray # type: ignore[attr-defined]
312
346
buffer = PandasBuffer (np_arr , allow_copy = self ._allow_copy )
313
- dtype = self .dtype
314
347
elif self .dtype [0 ] == DtypeKind .CATEGORICAL :
315
348
codes = self ._col .values ._codes
316
349
buffer = PandasBuffer (codes , allow_copy = self ._allow_copy )
@@ -343,13 +376,26 @@ def _get_data_buffer(
343
376
344
377
return buffer , dtype
345
378
346
- def _get_validity_buffer (self ) -> tuple [PandasBuffer , Any ]:
379
+ def _get_validity_buffer (self ) -> tuple [Buffer , Any ] | None :
347
380
"""
348
381
Return the buffer containing the mask values indicating missing data and
349
382
the buffer's associated dtype.
350
383
Raises NoBufferPresent if null representation is not a bit or byte mask.
351
384
"""
352
385
null , invalid = self .describe_null
386
+ buffer : Buffer
387
+ if isinstance (self ._col .dtype , ArrowDtype ):
388
+ # We already rechunk (if necessary / allowed) upon initialization, so this
389
+ # is already single-chunk by the time we get here.
390
+ arr = self ._col .array ._pa_array .chunks [0 ] # type: ignore[attr-defined]
391
+ dtype = (DtypeKind .BOOL , 1 , ArrowCTypes .BOOL , Endianness .NATIVE )
392
+ if arr .buffers ()[0 ] is None :
393
+ return None
394
+ buffer = PandasBufferPyarrow (
395
+ arr .buffers ()[0 ],
396
+ length = len (arr ),
397
+ )
398
+ return buffer , dtype
353
399
354
400
if isinstance (self ._col .dtype , BaseMaskedDtype ):
355
401
mask = self ._col .array ._mask # type: ignore[attr-defined]
0 commit comments