Skip to content

Commit d40c371

Browse files
authored
ENH/TST: Add BaseInterfaceTests tests for ArrowExtensionArray (#47377)
1 parent fa5a604 commit d40c371

File tree

4 files changed

+84
-43
lines changed

4 files changed

+84
-43
lines changed

pandas/core/arrays/arrow/array.py

+40-2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
)
3232
from pandas.core.dtypes.missing import isna
3333

34+
from pandas.core.arraylike import OpsMixin
3435
from pandas.core.arrays.base import ExtensionArray
3536
from pandas.core.indexers import (
3637
check_array_indexer,
@@ -45,13 +46,22 @@
4546
from pandas.core.arrays.arrow._arrow_utils import fallback_performancewarning
4647
from pandas.core.arrays.arrow.dtype import ArrowDtype
4748

49+
# Lookup table keyed by a comparison operator's ``__name__``; each value is
# the ``pc`` function that ``_cmp_method`` calls for that operator.
ARROW_CMP_FUNCS = {
    "eq": pc.equal,
    "ne": pc.not_equal,
    "lt": pc.less,
    "gt": pc.greater,
    "le": pc.less_equal,
    "ge": pc.greater_equal,
}
57+
4858
if TYPE_CHECKING:
4959
from pandas import Series
5060

5161
ArrowExtensionArrayT = TypeVar("ArrowExtensionArrayT", bound="ArrowExtensionArray")
5262

5363

54-
class ArrowExtensionArray(ExtensionArray):
64+
class ArrowExtensionArray(OpsMixin, ExtensionArray):
5565
"""
5666
Base class for ExtensionArray backed by Arrow ChunkedArray.
5767
"""
@@ -179,6 +189,34 @@ def __arrow_array__(self, type=None):
179189
"""Convert myself to a pyarrow ChunkedArray."""
180190
return self._data
181191

192+
def _cmp_method(self, other, op):
    """
    Compare this array element-wise with ``other`` using ``op``.

    Parameters
    ----------
    other : ArrowExtensionArray, numpy.ndarray, list or scalar
        Right-hand operand of the comparison.
    op : callable
        Comparison operator; its ``__name__`` must be a key of
        ``ARROW_CMP_FUNCS``.

    Returns
    -------
    BooleanArray
        Result of the comparison.

    Raises
    ------
    NotImplementedError
        If ``other`` is none of the supported operand types.
    """
    from pandas.arrays import BooleanArray

    pc_func = ARROW_CMP_FUNCS[op.__name__]
    if isinstance(other, ArrowExtensionArray):
        result = pc_func(self._data, other._data)
    elif isinstance(other, (np.ndarray, list)):
        result = pc_func(self._data, other)
    elif is_scalar(other):
        try:
            result = pc_func(self._data, pa.scalar(other))
        except (pa.lib.ArrowNotImplementedError, pa.lib.ArrowInvalid):
            # pyarrow rejected this scalar: fall back to applying ``op`` on
            # the non-missing positions of the array converted to numpy.
            mask = isna(self) | isna(other)
            valid = ~mask
            result = np.zeros(len(self), dtype="bool")
            result[valid] = op(np.array(self)[valid], other)
            return BooleanArray(result, mask)
    else:
        # Raise rather than return: returning the exception instance would
        # hand the caller an exception object as the comparison "result".
        raise NotImplementedError(
            f"{op.__name__} not implemented for {type(other)}"
        )

    if pa_version_under2p0:
        result = result.to_pandas().values
    else:
        result = result.to_numpy()
    return BooleanArray._from_sequence(result)
219+
182220
def equals(self, other) -> bool:
183221
if not isinstance(other, ArrowExtensionArray):
184222
return False
@@ -581,7 +619,7 @@ def _replace_with_indices(
581619
# fast path for a contiguous set of indices
582620
arrays = [
583621
chunk[:start],
584-
pa.array(value, type=chunk.type),
622+
pa.array(value, type=chunk.type, from_pandas=True),
585623
chunk[stop + 1 :],
586624
]
587625
arrays = [arr for arr in arrays if len(arr)]

pandas/core/arrays/string_arrow.py

+1-39
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
)
3535
from pandas.core.dtypes.missing import isna
3636

37-
from pandas.core.arraylike import OpsMixin
3837
from pandas.core.arrays.arrow import ArrowExtensionArray
3938
from pandas.core.arrays.boolean import BooleanDtype
4039
from pandas.core.arrays.integer import Int64Dtype
@@ -51,15 +50,6 @@
5150

5251
from pandas.core.arrays.arrow._arrow_utils import fallback_performancewarning
5352

54-
ARROW_CMP_FUNCS = {
55-
"eq": pc.equal,
56-
"ne": pc.not_equal,
57-
"lt": pc.less,
58-
"gt": pc.greater,
59-
"le": pc.less_equal,
60-
"ge": pc.greater_equal,
61-
}
62-
6353
ArrowStringScalarOrNAT = Union[str, libmissing.NAType]
6454

6555

@@ -74,9 +64,7 @@ def _chk_pyarrow_available() -> None:
7464
# fallback for the ones that pyarrow doesn't yet support
7565

7666

77-
class ArrowStringArray(
78-
OpsMixin, ArrowExtensionArray, BaseStringArray, ObjectStringArrayMixin
79-
):
67+
class ArrowStringArray(ArrowExtensionArray, BaseStringArray, ObjectStringArrayMixin):
8068
"""
8169
Extension array for string data in a ``pyarrow.ChunkedArray``.
8270
@@ -190,32 +178,6 @@ def to_numpy(
190178
result[mask] = na_value
191179
return result
192180

193-
def _cmp_method(self, other, op):
194-
from pandas.arrays import BooleanArray
195-
196-
pc_func = ARROW_CMP_FUNCS[op.__name__]
197-
if isinstance(other, ArrowStringArray):
198-
result = pc_func(self._data, other._data)
199-
elif isinstance(other, (np.ndarray, list)):
200-
result = pc_func(self._data, other)
201-
elif is_scalar(other):
202-
try:
203-
result = pc_func(self._data, pa.scalar(other))
204-
except (pa.lib.ArrowNotImplementedError, pa.lib.ArrowInvalid):
205-
mask = isna(self) | isna(other)
206-
valid = ~mask
207-
result = np.zeros(len(self), dtype="bool")
208-
result[valid] = op(np.array(self)[valid], other)
209-
return BooleanArray(result, mask)
210-
else:
211-
return NotImplemented
212-
213-
if pa_version_under2p0:
214-
result = result.to_pandas().values
215-
else:
216-
result = result.to_numpy()
217-
return BooleanArray._from_sequence(result)
218-
219181
def insert(self, loc: int, item):
220182
if not isinstance(item, str) and item is not libmissing.NA:
221183
raise TypeError("Scalar must be NA or str")

pandas/tests/extension/arrow/arrays.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
take,
2424
)
2525
from pandas.api.types import is_scalar
26-
from pandas.core.arraylike import OpsMixin
2726
from pandas.core.arrays.arrow import ArrowExtensionArray as _ArrowExtensionArray
2827
from pandas.core.construction import extract_array
2928

@@ -72,7 +71,7 @@ def construct_array_type(cls) -> type_t[ArrowStringArray]:
7271
return ArrowStringArray
7372

7473

75-
class ArrowExtensionArray(OpsMixin, _ArrowExtensionArray):
74+
class ArrowExtensionArray(_ArrowExtensionArray):
7675
_data: pa.ChunkedArray
7776

7877
@classmethod

pandas/tests/extension/test_arrow.py

+42
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,18 @@ def data_missing(data):
9393
return type(data)._from_sequence([None, data[0]])
9494

9595

96+
@pytest.fixture(params=["data", "data_missing"])
def all_data(request, data, data_missing):
    """
    Parametrized fixture returning the 'data' or 'data_missing' fixture value.

    The value returned is selected by ``request.param``.
    """
    selected = request.param
    if selected == "data":
        return data
    if selected == "data_missing":
        return data_missing
106+
107+
96108
@pytest.fixture
97109
def na_value():
98110
"""The scalar missing value for this type. Default 'None'"""
@@ -271,6 +283,36 @@ class TestBaseIndex(base.BaseIndexTests):
271283
pass
272284

273285

286+
class TestBaseInterface(base.BaseInterfaceTests):
    """``base.BaseInterfaceTests`` with xfail handling on two tests."""

    def test_contains(self, data, data_missing, request):
        """Run the inherited test, marking one timestamp case as xfail."""
        pa_type = data.dtype.pyarrow_dtype
        tz = getattr(pa_type, "tz", None)
        unit = getattr(pa_type, "unit", None)
        unsupported = pa_version_under2p0 and tz not in (None, "UTC") and unit == "us"
        if unsupported:
            reason = (
                "Not supported by pyarrow < 2.0 "
                f"with timestamp type {tz} and {unit}"
            )
            request.node.add_marker(pytest.mark.xfail(reason=reason))
        super().test_contains(data, data_missing)

    @pytest.mark.xfail(reason="pyarrow.ChunkedArray does not support views.")
    def test_view(self, data):
        """Inherited view test, expected to fail (see the xfail reason)."""
        super().test_view(data)
304+
305+
306+
class TestBaseMissing(base.BaseMissingTests):
    """Inherits every test from ``base.BaseMissingTests`` without overrides."""
308+
309+
310+
class TestBaseSetitemTests(base.BaseSetitemTests):
    """``base.BaseSetitemTests`` with the view-preservation test marked xfail."""

    @pytest.mark.xfail(reason="GH 45419: pyarrow.ChunkedArray does not support views")
    def test_setitem_preserves_views(self, data):
        """Inherited test, expected to fail (see the xfail reason)."""
        super().test_setitem_preserves_views(data)
314+
315+
274316
def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
275317
with pytest.raises(NotImplementedError, match="Passing pyarrow type"):
276318
ArrowDtype.construct_from_string("timestamp[s, tz=UTC][pyarrow]")

0 commit comments

Comments
 (0)