Skip to content

Commit dc36ce1

Browse files
authored
ENH/TST: Add BaseInterfaceTests tests for ArrowExtensionArray PT2 (#47468)
1 parent e5c7543 commit dc36ce1

File tree

4 files changed

+567
-43
lines changed

4 files changed

+567
-43
lines changed

pandas/core/arrays/arrow/array.py

+40-2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
)
3232
from pandas.core.dtypes.missing import isna
3333

34+
from pandas.core.arraylike import OpsMixin
3435
from pandas.core.arrays.base import ExtensionArray
3536
from pandas.core.indexers import (
3637
check_array_indexer,
@@ -45,13 +46,22 @@
4546
from pandas.core.arrays.arrow._arrow_utils import fallback_performancewarning
4647
from pandas.core.arrays.arrow.dtype import ArrowDtype
4748

49+
ARROW_CMP_FUNCS = {
50+
"eq": pc.equal,
51+
"ne": pc.not_equal,
52+
"lt": pc.less,
53+
"gt": pc.greater,
54+
"le": pc.less_equal,
55+
"ge": pc.greater_equal,
56+
}
57+
4858
if TYPE_CHECKING:
4959
from pandas import Series
5060

5161
ArrowExtensionArrayT = TypeVar("ArrowExtensionArrayT", bound="ArrowExtensionArray")
5262

5363

54-
class ArrowExtensionArray(ExtensionArray):
64+
class ArrowExtensionArray(OpsMixin, ExtensionArray):
5565
"""
5666
Base class for ExtensionArray backed by Arrow ChunkedArray.
5767
"""
@@ -179,6 +189,34 @@ def __arrow_array__(self, type=None):
179189
"""Convert myself to a pyarrow ChunkedArray."""
180190
return self._data
181191

192+
def _cmp_method(self, other, op):
193+
from pandas.arrays import BooleanArray
194+
195+
pc_func = ARROW_CMP_FUNCS[op.__name__]
196+
if isinstance(other, ArrowExtensionArray):
197+
result = pc_func(self._data, other._data)
198+
elif isinstance(other, (np.ndarray, list)):
199+
result = pc_func(self._data, other)
200+
elif is_scalar(other):
201+
try:
202+
result = pc_func(self._data, pa.scalar(other))
203+
except (pa.lib.ArrowNotImplementedError, pa.lib.ArrowInvalid):
204+
mask = isna(self) | isna(other)
205+
valid = ~mask
206+
result = np.zeros(len(self), dtype="bool")
207+
result[valid] = op(np.array(self)[valid], other)
208+
return BooleanArray(result, mask)
209+
else:
210+
return NotImplementedError(
211+
f"{op.__name__} not implemented for {type(other)}"
212+
)
213+
214+
if pa_version_under2p0:
215+
result = result.to_pandas().values
216+
else:
217+
result = result.to_numpy()
218+
return BooleanArray._from_sequence(result)
219+
182220
def equals(self, other) -> bool:
183221
if not isinstance(other, ArrowExtensionArray):
184222
return False
@@ -589,7 +627,7 @@ def _replace_with_indices(
589627
# fast path for a contiguous set of indices
590628
arrays = [
591629
chunk[:start],
592-
pa.array(value, type=chunk.type),
630+
pa.array(value, type=chunk.type, from_pandas=True),
593631
chunk[stop + 1 :],
594632
]
595633
arrays = [arr for arr in arrays if len(arr)]

pandas/core/arrays/string_arrow.py

+1-39
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
)
3535
from pandas.core.dtypes.missing import isna
3636

37-
from pandas.core.arraylike import OpsMixin
3837
from pandas.core.arrays.arrow import ArrowExtensionArray
3938
from pandas.core.arrays.boolean import BooleanDtype
4039
from pandas.core.arrays.integer import Int64Dtype
@@ -51,15 +50,6 @@
5150

5251
from pandas.core.arrays.arrow._arrow_utils import fallback_performancewarning
5352

54-
ARROW_CMP_FUNCS = {
55-
"eq": pc.equal,
56-
"ne": pc.not_equal,
57-
"lt": pc.less,
58-
"gt": pc.greater,
59-
"le": pc.less_equal,
60-
"ge": pc.greater_equal,
61-
}
62-
6353
ArrowStringScalarOrNAT = Union[str, libmissing.NAType]
6454

6555

@@ -74,9 +64,7 @@ def _chk_pyarrow_available() -> None:
7464
# fallback for the ones that pyarrow doesn't yet support
7565

7666

77-
class ArrowStringArray(
78-
OpsMixin, ArrowExtensionArray, BaseStringArray, ObjectStringArrayMixin
79-
):
67+
class ArrowStringArray(ArrowExtensionArray, BaseStringArray, ObjectStringArrayMixin):
8068
"""
8169
Extension array for string data in a ``pyarrow.ChunkedArray``.
8270
@@ -190,32 +178,6 @@ def to_numpy(
190178
result[mask] = na_value
191179
return result
192180

193-
def _cmp_method(self, other, op):
194-
from pandas.arrays import BooleanArray
195-
196-
pc_func = ARROW_CMP_FUNCS[op.__name__]
197-
if isinstance(other, ArrowStringArray):
198-
result = pc_func(self._data, other._data)
199-
elif isinstance(other, (np.ndarray, list)):
200-
result = pc_func(self._data, other)
201-
elif is_scalar(other):
202-
try:
203-
result = pc_func(self._data, pa.scalar(other))
204-
except (pa.lib.ArrowNotImplementedError, pa.lib.ArrowInvalid):
205-
mask = isna(self) | isna(other)
206-
valid = ~mask
207-
result = np.zeros(len(self), dtype="bool")
208-
result[valid] = op(np.array(self)[valid], other)
209-
return BooleanArray(result, mask)
210-
else:
211-
return NotImplemented
212-
213-
if pa_version_under2p0:
214-
result = result.to_pandas().values
215-
else:
216-
result = result.to_numpy()
217-
return BooleanArray._from_sequence(result)
218-
219181
def insert(self, loc: int, item):
220182
if not isinstance(item, str) and item is not libmissing.NA:
221183
raise TypeError("Scalar must be NA or str")

pandas/tests/extension/arrow/arrays.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
take,
2424
)
2525
from pandas.api.types import is_scalar
26-
from pandas.core.arraylike import OpsMixin
2726
from pandas.core.arrays.arrow import ArrowExtensionArray as _ArrowExtensionArray
2827
from pandas.core.construction import extract_array
2928

@@ -72,7 +71,7 @@ def construct_array_type(cls) -> type_t[ArrowStringArray]:
7271
return ArrowStringArray
7372

7473

75-
class ArrowExtensionArray(OpsMixin, _ArrowExtensionArray):
74+
class ArrowExtensionArray(_ArrowExtensionArray):
7675
_data: pa.ChunkedArray
7776

7877
@classmethod

0 commit comments

Comments
 (0)