1
1
from __future__ import annotations
2
2
3
- from typing import Any
3
+ from typing import (
4
+ TYPE_CHECKING ,
5
+ Any ,
6
+ )
4
7
5
8
import numpy as np
6
9
9
12
from pandas .errors import NoBufferPresent
10
13
from pandas .util ._decorators import cache_readonly
11
14
12
- from pandas .core .dtypes .dtypes import (
15
+ from pandas .core .dtypes .dtypes import BaseMaskedDtype
16
+
17
+ import pandas as pd
18
+ from pandas import (
13
19
ArrowDtype ,
14
- BaseMaskedDtype ,
15
20
DatetimeTZDtype ,
16
21
)
17
-
18
- import pandas as pd
19
22
from pandas .api .types import is_string_dtype
20
- from pandas .core .interchange .buffer import PandasBuffer
23
+ from pandas .core .interchange .buffer import (
24
+ PandasBuffer ,
25
+ PandasBufferPyarrow ,
26
+ )
21
27
from pandas .core .interchange .dataframe_protocol import (
22
28
Column ,
23
29
ColumnBuffers ,
30
36
dtype_to_arrow_c_fmt ,
31
37
)
32
38
39
+ if TYPE_CHECKING :
40
+ from pandas .core .interchange .dataframe_protocol import Buffer
41
+
33
42
_NP_KINDS = {
34
43
"i" : DtypeKind .INT ,
35
44
"u" : DtypeKind .UINT ,
@@ -157,6 +166,14 @@ def _dtype_from_pandasdtype(self, dtype) -> tuple[DtypeKind, int, str, str]:
157
166
else :
158
167
byteorder = dtype .byteorder
159
168
169
+ if dtype == "bool[pyarrow]" :
170
+ return (
171
+ kind ,
172
+ dtype .itemsize , # pyright: ignore[reportAttributeAccessIssue]
173
+ ArrowCTypes .BOOL ,
174
+ byteorder ,
175
+ )
176
+
160
177
return kind , dtype .itemsize * 8 , dtype_to_arrow_c_fmt (dtype ), byteorder
161
178
162
179
@property
@@ -194,6 +211,13 @@ def describe_null(self):
194
211
column_null_dtype = ColumnNullType .USE_BYTEMASK
195
212
null_value = 1
196
213
return column_null_dtype , null_value
214
+ if isinstance (self ._col .dtype , ArrowDtype ):
215
+ if all (
216
+ chunk .buffers ()[0 ] is None
217
+ for chunk in self ._col .array ._pa_array .chunks # type: ignore[attr-defined]
218
+ ):
219
+ return ColumnNullType .NON_NULLABLE , None
220
+ return ColumnNullType .USE_BITMASK , 0
197
221
kind = self .dtype [0 ]
198
222
try :
199
223
null , value = _NULL_DESCRIPTION [kind ]
@@ -278,7 +302,7 @@ def get_buffers(self) -> ColumnBuffers:
278
302
279
303
def _get_data_buffer (
280
304
self ,
281
- ) -> tuple [PandasBuffer , Any ]: # Any is for self.dtype tuple
305
+ ) -> tuple [Buffer , tuple [ DtypeKind , int , str , str ]]:
282
306
"""
283
307
Return the buffer containing the data and the buffer's associated dtype.
284
308
"""
@@ -289,7 +313,7 @@ def _get_data_buffer(
289
313
np_arr = self ._col .dt .tz_convert (None ).to_numpy ()
290
314
else :
291
315
np_arr = self ._col .to_numpy ()
292
- buffer = PandasBuffer (np_arr , allow_copy = self ._allow_copy )
316
+ buffer : Buffer = PandasBuffer (np_arr , allow_copy = self ._allow_copy )
293
317
dtype = (
294
318
DtypeKind .INT ,
295
319
64 ,
@@ -302,15 +326,27 @@ def _get_data_buffer(
302
326
DtypeKind .FLOAT ,
303
327
DtypeKind .BOOL ,
304
328
):
329
+ dtype = self .dtype
305
330
arr = self ._col .array
331
+ if isinstance (self ._col .dtype , ArrowDtype ):
332
+ buffer = PandasBufferPyarrow (
333
+ arr ._pa_array , # type: ignore[attr-defined]
334
+ is_validity = False ,
335
+ allow_copy = self ._allow_copy ,
336
+ )
337
+ if self .dtype [0 ] == DtypeKind .BOOL :
338
+ dtype = (
339
+ DtypeKind .BOOL ,
340
+ 1 ,
341
+ ArrowCTypes .BOOL ,
342
+ Endianness .NATIVE ,
343
+ )
344
+ return buffer , dtype
306
345
if isinstance (self ._col .dtype , BaseMaskedDtype ):
307
346
np_arr = arr ._data # type: ignore[attr-defined]
308
- elif isinstance (self ._col .dtype , ArrowDtype ):
309
- raise NotImplementedError ("ArrowDtype not handled yet" )
310
347
else :
311
348
np_arr = arr ._ndarray # type: ignore[attr-defined]
312
349
buffer = PandasBuffer (np_arr , allow_copy = self ._allow_copy )
313
- dtype = self .dtype
314
350
elif self .dtype [0 ] == DtypeKind .CATEGORICAL :
315
351
codes = self ._col .values ._codes
316
352
buffer = PandasBuffer (codes , allow_copy = self ._allow_copy )
@@ -343,14 +379,29 @@ def _get_data_buffer(
343
379
344
380
return buffer , dtype
345
381
346
- def _get_validity_buffer (self ) -> tuple [PandasBuffer , Any ]:
382
+ def _get_validity_buffer (self ) -> tuple [Buffer , Any ] | None :
347
383
"""
348
384
Return the buffer containing the mask values indicating missing data and
349
385
the buffer's associated dtype.
350
386
Raises NoBufferPresent if null representation is not a bit or byte mask.
351
387
"""
352
388
null , invalid = self .describe_null
353
389
390
+ if isinstance (self ._col .dtype , ArrowDtype ):
391
+ arr = self ._col .array
392
+ dtype = (DtypeKind .BOOL , 1 , ArrowCTypes .BOOL , Endianness .NATIVE )
393
+ if all (
394
+ chunk .buffers ()[0 ] is None
395
+ for chunk in arr ._pa_array .chunks # type: ignore[attr-defined]
396
+ ):
397
+ return None
398
+ buffer : Buffer = PandasBufferPyarrow (
399
+ arr ._pa_array , # type: ignore[attr-defined]
400
+ is_validity = True ,
401
+ allow_copy = self ._allow_copy ,
402
+ )
403
+ return buffer , dtype
404
+
354
405
if isinstance (self ._col .dtype , BaseMaskedDtype ):
355
406
mask = self ._col .array ._mask # type: ignore[attr-defined]
356
407
buffer = PandasBuffer (mask )
0 commit comments