Skip to content

Commit 7dea5ae

Browse files
authored
REF: implement ArrowExtensionArray base class (#46102)
1 parent c2188de commit 7dea5ae

File tree

6 files changed

+102
-94
lines changed

6 files changed

+102
-94
lines changed

pandas/core/arrays/_arrow_utils.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import json
24

35
import numpy as np

pandas/core/arrays/_mixins.py

+89
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
npt,
2929
type_t,
3030
)
31+
from pandas.compat import pa_version_under2p0
3132
from pandas.errors import AbstractMethodError
3233
from pandas.util._decorators import doc
3334
from pandas.util._validators import (
@@ -66,6 +67,8 @@
6667

6768
if TYPE_CHECKING:
6869

70+
import pyarrow as pa
71+
6972
from pandas._typing import (
7073
NumpySorter,
7174
NumpyValueArrayLike,
@@ -508,3 +511,89 @@ def _empty(
508511
arr = cls._from_sequence([], dtype=dtype)
509512
backing = np.empty(shape, dtype=arr._ndarray.dtype)
510513
return arr._from_backing_data(backing)
514+
515+
516+
ArrowExtensionArrayT = TypeVar("ArrowExtensionArrayT", bound="ArrowExtensionArray")
517+
518+
519+
class ArrowExtensionArray(ExtensionArray):
    """
    Base class for ExtensionArray backed by Arrow array.

    The wrapped ``pyarrow.ChunkedArray`` is stored on ``_data`` and the
    methods defined here delegate to it.
    """

    _data: pa.ChunkedArray

    def __init__(self, values: pa.ChunkedArray):
        self._data = values

    def __arrow_array__(self, type=None):
        """Convert myself to a pyarrow Array or ChunkedArray."""
        # NOTE(review): ``type`` is accepted but not used here; the stored
        # data is returned unchanged.
        return self._data

    def equals(self, other) -> bool:
        """
        Return whether ``other`` holds the same Arrow data as this array.

        Parameters
        ----------
        other : object
            Object to compare against.

        Returns
        -------
        bool
            False if ``other`` is not an ArrowExtensionArray, otherwise the
            result of comparing the two underlying ChunkedArrays.
        """
        if not isinstance(other, ArrowExtensionArray):
            return False
        # Call ChunkedArray.equals explicitly rather than going through
        # ``==``, so the comparison uses the documented content-equality API.
        return self._data.equals(other._data)

    @property
    def nbytes(self) -> int:
        """
        The number of bytes needed to store this object in memory.
        """
        return self._data.nbytes

    def __len__(self) -> int:
        """
        Length of this array.

        Returns
        -------
        length : int
        """
        return len(self._data)

    def isna(self) -> npt.NDArray[np.bool_]:
        """
        Boolean NumPy array indicating if each value is missing.

        This should return a 1-D array the same length as 'self'.
        """
        if pa_version_under2p0:
            # Older pyarrow: convert the null mask through pandas.
            # NOTE(review): presumably ChunkedArray.to_numpy is unavailable
            # before pyarrow 2.0 -- confirm.
            return self._data.is_null().to_pandas().values
        else:
            return self._data.is_null().to_numpy()

    def copy(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
        """
        Return a shallow copy of the array.

        Underlying ChunkedArray is immutable, so a deep copy is unnecessary.

        Returns
        -------
        type(self)
        """
        return type(self)(self._data)

    @classmethod
    def _concat_same_type(
        cls: type[ArrowExtensionArrayT], to_concat
    ) -> ArrowExtensionArrayT:
        """
        Concatenate multiple ArrowExtensionArrays.

        Parameters
        ----------
        to_concat : sequence of ArrowExtensionArrays

        Returns
        -------
        ArrowExtensionArray
        """
        import pyarrow as pa

        to_concat = list(to_concat)
        chunks = [array for ea in to_concat for array in ea._data.iterchunks()]
        if not chunks and to_concat:
            # Every input had zero chunks, so pyarrow has nothing to infer
            # the type from; reuse the first input's Arrow type.
            arr = pa.chunked_array(chunks, type=to_concat[0]._data.type)
        else:
            arr = pa.chunked_array(chunks)
        return cls(arr)

pandas/core/arrays/string_.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@
3939
from pandas.core import ops
4040
from pandas.core.array_algos import masked_reductions
4141
from pandas.core.arrays import (
42+
ExtensionArray,
4243
FloatingArray,
4344
IntegerArray,
4445
PandasArray,
4546
)
46-
from pandas.core.arrays.base import ExtensionArray
4747
from pandas.core.arrays.floating import FloatingDtype
4848
from pandas.core.arrays.integer import IntegerDtype
4949
from pandas.core.construction import extract_array
@@ -224,6 +224,10 @@ def __from_arrow__(
224224

225225

226226
class BaseStringArray(ExtensionArray):
    """
    Mixin class for StringArray, ArrowStringArray.
    """
228232

229233

pandas/core/arrays/string_arrow.py

+4-62
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from pandas.core.dtypes.missing import isna
4949

5050
from pandas.core.arraylike import OpsMixin
51+
from pandas.core.arrays._mixins import ArrowExtensionArray
5152
from pandas.core.arrays.base import ExtensionArray
5253
from pandas.core.arrays.boolean import BooleanDtype
5354
from pandas.core.arrays.integer import Int64Dtype
@@ -94,7 +95,9 @@ def _chk_pyarrow_available() -> None:
9495
# fallback for the ones that pyarrow doesn't yet support
9596

9697

97-
class ArrowStringArray(OpsMixin, BaseStringArray, ObjectStringArrayMixin):
98+
class ArrowStringArray(
99+
OpsMixin, ArrowExtensionArray, BaseStringArray, ObjectStringArrayMixin
100+
):
98101
"""
99102
Extension array for string data in a ``pyarrow.ChunkedArray``.
100103
@@ -191,10 +194,6 @@ def __array__(self, dtype: NpDtype | None = None) -> np.ndarray:
191194
"""Correctly construct numpy arrays when passed to `np.asarray()`."""
192195
return self.to_numpy(dtype=dtype)
193196

194-
def __arrow_array__(self, type=None):
195-
"""Convert myself to a pyarrow Array or ChunkedArray."""
196-
return self._data
197-
198197
def to_numpy(
199198
self,
200199
dtype: npt.DTypeLike | None = None,
@@ -216,16 +215,6 @@ def to_numpy(
216215
result[mask] = na_value
217216
return result
218217

219-
def __len__(self) -> int:
220-
"""
221-
Length of this array.
222-
223-
Returns
224-
-------
225-
length : int
226-
"""
227-
return len(self._data)
228-
229218
@doc(ExtensionArray.factorize)
230219
def factorize(self, na_sentinel: int = -1) -> tuple[np.ndarray, ExtensionArray]:
231220
encoded = self._data.dictionary_encode()
@@ -243,25 +232,6 @@ def factorize(self, na_sentinel: int = -1) -> tuple[np.ndarray, ExtensionArray]:
243232

244233
return indices.values, uniques
245234

246-
@classmethod
247-
def _concat_same_type(cls, to_concat) -> ArrowStringArray:
248-
"""
249-
Concatenate multiple ArrowStringArray.
250-
251-
Parameters
252-
----------
253-
to_concat : sequence of ArrowStringArray
254-
255-
Returns
256-
-------
257-
ArrowStringArray
258-
"""
259-
return cls(
260-
pa.chunked_array(
261-
[array for ea in to_concat for array in ea._data.iterchunks()]
262-
)
263-
)
264-
265235
@overload
266236
def __getitem__(self, item: ScalarIndexer) -> ArrowStringScalarOrNAT:
267237
...
@@ -342,34 +312,6 @@ def _as_pandas_scalar(self, arrow_scalar: pa.Scalar):
342312
else:
343313
return scalar
344314

345-
@property
346-
def nbytes(self) -> int:
347-
"""
348-
The number of bytes needed to store this object in memory.
349-
"""
350-
return self._data.nbytes
351-
352-
def isna(self) -> np.ndarray:
353-
"""
354-
Boolean NumPy array indicating if each value is missing.
355-
356-
This should return a 1-D array the same length as 'self'.
357-
"""
358-
# TODO: Implement .to_numpy for ChunkedArray
359-
return self._data.is_null().to_pandas().values
360-
361-
def copy(self) -> ArrowStringArray:
362-
"""
363-
Return a shallow copy of the array.
364-
365-
Underlying ChunkedArray is immutable, so a deep copy is unnecessary.
366-
367-
Returns
368-
-------
369-
ArrowStringArray
370-
"""
371-
return type(self)(self._data)
372-
373315
def _cmp_method(self, other, op):
374316
from pandas.arrays import BooleanArray
375317

pandas/tests/extension/arrow/arrays.py

+2-22
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
"""
99
from __future__ import annotations
1010

11-
import copy
1211
import itertools
1312
import operator
1413

@@ -19,13 +18,13 @@
1918

2019
import pandas as pd
2120
from pandas.api.extensions import (
22-
ExtensionArray,
2321
ExtensionDtype,
2422
register_extension_dtype,
2523
take,
2624
)
2725
from pandas.api.types import is_scalar
2826
from pandas.core.arraylike import OpsMixin
27+
from pandas.core.arrays._mixins import ArrowExtensionArray as _ArrowExtensionArray
2928
from pandas.core.construction import extract_array
3029

3130

@@ -73,7 +72,7 @@ def construct_array_type(cls) -> type_t[ArrowStringArray]:
7372
return ArrowStringArray
7473

7574

76-
class ArrowExtensionArray(OpsMixin, ExtensionArray):
75+
class ArrowExtensionArray(OpsMixin, _ArrowExtensionArray):
7776
_data: pa.ChunkedArray
7877

7978
@classmethod
@@ -111,9 +110,6 @@ def __getitem__(self, item):
111110
vals = self._data.to_pandas()[item]
112111
return type(self)._from_sequence(vals)
113112

114-
def __len__(self):
115-
return len(self._data)
116-
117113
def astype(self, dtype, copy=True):
118114
# needed to fix this astype for the Series constructor.
119115
if isinstance(dtype, type(self.dtype)) and dtype == self.dtype:
@@ -142,19 +138,6 @@ def __eq__(self, other):
142138

143139
return self._logical_method(other, operator.eq)
144140

145-
@property
146-
def nbytes(self) -> int:
147-
return sum(
148-
x.size
149-
for chunk in self._data.chunks
150-
for x in chunk.buffers()
151-
if x is not None
152-
)
153-
154-
def isna(self):
155-
nas = pd.isna(self._data.to_pandas())
156-
return type(self)._from_sequence(nas)
157-
158141
def take(self, indices, allow_fill=False, fill_value=None):
159142
data = self._data.to_pandas()
160143
data = extract_array(data, extract_numpy=True)
@@ -165,9 +148,6 @@ def take(self, indices, allow_fill=False, fill_value=None):
165148
result = take(data, indices, fill_value=fill_value, allow_fill=allow_fill)
166149
return self._from_sequence(result, dtype=self.dtype)
167150

168-
def copy(self):
169-
return type(self)(copy.copy(self._data))
170-
171151
@classmethod
172152
def _concat_same_type(cls, to_concat):
173153
chunks = list(itertools.chain.from_iterable(x._data.chunks for x in to_concat))

pandas/tests/extension/arrow/test_bool.py

-9
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,6 @@ def test_contains(self, data, data_missing):
6262

6363

6464
class TestConstructors(BaseArrowTests, base.BaseConstructorsTests):
65-
# seems like some bug in isna on empty BoolArray returning floats.
66-
@pytest.mark.xfail(reason="bad is-na for empty data")
67-
def test_from_sequence_from_cls(self, data):
68-
super().test_from_sequence_from_cls(data)
69-
7065
@pytest.mark.xfail(reason="pa.NULL is not recognised as scalar, GH-33899")
7166
def test_series_constructor_no_data_with_index(self, dtype, na_value):
7267
# pyarrow.lib.ArrowInvalid: only handle 1-dimensional arrays
@@ -77,10 +72,6 @@ def test_series_constructor_scalar_na_with_index(self, dtype, na_value):
7772
# pyarrow.lib.ArrowInvalid: only handle 1-dimensional arrays
7873
super().test_series_constructor_scalar_na_with_index(dtype, na_value)
7974

80-
@pytest.mark.xfail(reason="ufunc 'invert' not supported for the input types")
81-
def test_construct_empty_dataframe(self, dtype):
82-
super().test_construct_empty_dataframe(dtype)
83-
8475
@pytest.mark.xfail(reason="_from_sequence ignores dtype keyword")
8576
def test_empty(self, dtype):
8677
super().test_empty(dtype)

0 commit comments

Comments
 (0)