1
1
from __future__ import annotations
2
2
3
- from typing import Any
3
+ from typing import (
4
+ TYPE_CHECKING ,
5
+ Any ,
6
+ )
4
7
5
8
import numpy as np
6
9
9
12
from pandas .errors import NoBufferPresent
10
13
from pandas .util ._decorators import cache_readonly
11
14
12
- from pandas .core .dtypes .dtypes import (
15
+ from pandas .core .dtypes .dtypes import BaseMaskedDtype
16
+
17
+ import pandas as pd
18
+ from pandas import (
13
19
ArrowDtype ,
14
- BaseMaskedDtype ,
15
20
DatetimeTZDtype ,
16
21
)
17
-
18
- import pandas as pd
19
22
from pandas .api .types import is_string_dtype
20
- from pandas .core .interchange .buffer import PandasBuffer
23
+ from pandas .core .interchange .buffer import (
24
+ PandasBuffer ,
25
+ PandasBufferPyarrow ,
26
+ )
21
27
from pandas .core .interchange .dataframe_protocol import (
22
28
Column ,
23
29
ColumnBuffers ,
30
36
dtype_to_arrow_c_fmt ,
31
37
)
32
38
39
+ if TYPE_CHECKING :
40
+ from pandas .core .interchange .dataframe_protocol import Buffer
41
+
33
42
_NP_KINDS = {
34
43
"i" : DtypeKind .INT ,
35
44
"u" : DtypeKind .UINT ,
@@ -157,6 +166,16 @@ def _dtype_from_pandasdtype(self, dtype) -> tuple[DtypeKind, int, str, str]:
157
166
else :
158
167
byteorder = dtype .byteorder
159
168
169
+ if dtype == "bool[pyarrow]" :
170
+ # return early to avoid the `* 8` below, as this is a bitmask
171
+ # rather than a bytemask
172
+ return (
173
+ kind ,
174
+ dtype .itemsize , # pyright: ignore[reportGeneralTypeIssues]
175
+ ArrowCTypes .BOOL ,
176
+ byteorder ,
177
+ )
178
+
160
179
return kind , dtype .itemsize * 8 , dtype_to_arrow_c_fmt (dtype ), byteorder
161
180
162
181
@property
@@ -194,6 +213,12 @@ def describe_null(self):
194
213
column_null_dtype = ColumnNullType .USE_BYTEMASK
195
214
null_value = 1
196
215
return column_null_dtype , null_value
216
+ if isinstance (self ._col .dtype , ArrowDtype ):
217
+ # We already rechunk (if necessary / allowed) upon initialization, so this
218
+ # is already single-chunk by the time we get here.
219
+ if self ._col .array ._pa_array .chunks [0 ].buffers ()[0 ] is None : # type: ignore[attr-defined]
220
+ return ColumnNullType .NON_NULLABLE , None
221
+ return ColumnNullType .USE_BITMASK , 0
197
222
kind = self .dtype [0 ]
198
223
try :
199
224
null , value = _NULL_DESCRIPTION [kind ]
@@ -278,10 +303,11 @@ def get_buffers(self) -> ColumnBuffers:
278
303
279
304
def _get_data_buffer (
280
305
self ,
281
- ) -> tuple [PandasBuffer , Any ]: # Any is for self.dtype tuple
306
+ ) -> tuple [Buffer , tuple [ DtypeKind , int , str , str ]]:
282
307
"""
283
308
Return the buffer containing the data and the buffer's associated dtype.
284
309
"""
310
+ buffer : Buffer
285
311
if self .dtype [0 ] in (
286
312
DtypeKind .INT ,
287
313
DtypeKind .UINT ,
@@ -291,18 +317,25 @@ def _get_data_buffer(
291
317
):
292
318
# self.dtype[2] is an ArrowCTypes.TIMESTAMP where the tz will make
293
319
# it longer than 4 characters
320
+ dtype = self .dtype
294
321
if self .dtype [0 ] == DtypeKind .DATETIME and len (self .dtype [2 ]) > 4 :
295
322
np_arr = self ._col .dt .tz_convert (None ).to_numpy ()
296
323
else :
297
324
arr = self ._col .array
298
325
if isinstance (self ._col .dtype , BaseMaskedDtype ):
299
326
np_arr = arr ._data # type: ignore[attr-defined]
300
327
elif isinstance (self ._col .dtype , ArrowDtype ):
301
- raise NotImplementedError ("ArrowDtype not handled yet" )
328
+ # We already rechunk (if necessary / allowed) upon initialization,
329
+ # so this is already single-chunk by the time we get here.
330
+ arr = arr ._pa_array .chunks [0 ] # type: ignore[attr-defined]
331
+ buffer = PandasBufferPyarrow (
332
+ arr .buffers ()[1 ], # type: ignore[attr-defined]
333
+ length = len (arr ),
334
+ )
335
+ return buffer , dtype
302
336
else :
303
337
np_arr = arr ._ndarray # type: ignore[attr-defined]
304
338
buffer = PandasBuffer (np_arr , allow_copy = self ._allow_copy )
305
- dtype = self .dtype
306
339
elif self .dtype [0 ] == DtypeKind .CATEGORICAL :
307
340
codes = self ._col .values ._codes
308
341
buffer = PandasBuffer (codes , allow_copy = self ._allow_copy )
@@ -330,13 +363,26 @@ def _get_data_buffer(
330
363
331
364
return buffer , dtype
332
365
333
- def _get_validity_buffer (self ) -> tuple [PandasBuffer , Any ]:
366
+ def _get_validity_buffer (self ) -> tuple [Buffer , Any ] | None :
334
367
"""
335
368
Return the buffer containing the mask values indicating missing data and
336
369
the buffer's associated dtype.
337
370
Raises NoBufferPresent if null representation is not a bit or byte mask.
338
371
"""
339
372
null , invalid = self .describe_null
373
+ buffer : Buffer
374
+ if isinstance (self ._col .dtype , ArrowDtype ):
375
+ # We already rechunk (if necessary / allowed) upon initialization, so this
376
+ # is already single-chunk by the time we get here.
377
+ arr = self ._col .array ._pa_array .chunks [0 ] # type: ignore[attr-defined]
378
+ dtype = (DtypeKind .BOOL , 1 , ArrowCTypes .BOOL , Endianness .NATIVE )
379
+ if arr .buffers ()[0 ] is None :
380
+ return None
381
+ buffer = PandasBufferPyarrow (
382
+ arr .buffers ()[0 ],
383
+ length = len (arr ),
384
+ )
385
+ return buffer , dtype
340
386
341
387
if isinstance (self ._col .dtype , BaseMaskedDtype ):
342
388
mask = self ._col .array ._mask # type: ignore[attr-defined]
0 commit comments