Skip to content

Commit 71ce89c

Browse files
Revert "ENH/TST: Add BaseInterfaceTests tests for ArrowExtensionArray (#47377)"
This reverts commit d40c371.
1 parent 2f3ac16 commit 71ce89c

File tree

4 files changed

+43
-84
lines changed

4 files changed

+43
-84
lines changed

pandas/core/arrays/arrow/array.py

+2-40
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
)
3232
from pandas.core.dtypes.missing import isna
3333

34-
from pandas.core.arraylike import OpsMixin
3534
from pandas.core.arrays.base import ExtensionArray
3635
from pandas.core.indexers import (
3736
check_array_indexer,
@@ -46,22 +45,13 @@
4645
from pandas.core.arrays.arrow._arrow_utils import fallback_performancewarning
4746
from pandas.core.arrays.arrow.dtype import ArrowDtype
4847

49-
ARROW_CMP_FUNCS = {
50-
"eq": pc.equal,
51-
"ne": pc.not_equal,
52-
"lt": pc.less,
53-
"gt": pc.greater,
54-
"le": pc.less_equal,
55-
"ge": pc.greater_equal,
56-
}
57-
5848
if TYPE_CHECKING:
5949
from pandas import Series
6050

6151
ArrowExtensionArrayT = TypeVar("ArrowExtensionArrayT", bound="ArrowExtensionArray")
6252

6353

64-
class ArrowExtensionArray(OpsMixin, ExtensionArray):
54+
class ArrowExtensionArray(ExtensionArray):
6555
"""
6656
Base class for ExtensionArray backed by Arrow ChunkedArray.
6757
"""
@@ -189,34 +179,6 @@ def __arrow_array__(self, type=None):
189179
"""Convert myself to a pyarrow ChunkedArray."""
190180
return self._data
191181

192-
def _cmp_method(self, other, op):
193-
from pandas.arrays import BooleanArray
194-
195-
pc_func = ARROW_CMP_FUNCS[op.__name__]
196-
if isinstance(other, ArrowExtensionArray):
197-
result = pc_func(self._data, other._data)
198-
elif isinstance(other, (np.ndarray, list)):
199-
result = pc_func(self._data, other)
200-
elif is_scalar(other):
201-
try:
202-
result = pc_func(self._data, pa.scalar(other))
203-
except (pa.lib.ArrowNotImplementedError, pa.lib.ArrowInvalid):
204-
mask = isna(self) | isna(other)
205-
valid = ~mask
206-
result = np.zeros(len(self), dtype="bool")
207-
result[valid] = op(np.array(self)[valid], other)
208-
return BooleanArray(result, mask)
209-
else:
210-
return NotImplementedError(
211-
f"{op.__name__} not implemented for {type(other)}"
212-
)
213-
214-
if pa_version_under2p0:
215-
result = result.to_pandas().values
216-
else:
217-
result = result.to_numpy()
218-
return BooleanArray._from_sequence(result)
219-
220182
def equals(self, other) -> bool:
221183
if not isinstance(other, ArrowExtensionArray):
222184
return False
@@ -619,7 +581,7 @@ def _replace_with_indices(
619581
# fast path for a contiguous set of indices
620582
arrays = [
621583
chunk[:start],
622-
pa.array(value, type=chunk.type, from_pandas=True),
584+
pa.array(value, type=chunk.type),
623585
chunk[stop + 1 :],
624586
]
625587
arrays = [arr for arr in arrays if len(arr)]

pandas/core/arrays/string_arrow.py

+39-1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
)
3535
from pandas.core.dtypes.missing import isna
3636

37+
from pandas.core.arraylike import OpsMixin
3738
from pandas.core.arrays.arrow import ArrowExtensionArray
3839
from pandas.core.arrays.boolean import BooleanDtype
3940
from pandas.core.arrays.integer import Int64Dtype
@@ -50,6 +51,15 @@
5051

5152
from pandas.core.arrays.arrow._arrow_utils import fallback_performancewarning
5253

54+
ARROW_CMP_FUNCS = {
55+
"eq": pc.equal,
56+
"ne": pc.not_equal,
57+
"lt": pc.less,
58+
"gt": pc.greater,
59+
"le": pc.less_equal,
60+
"ge": pc.greater_equal,
61+
}
62+
5363
ArrowStringScalarOrNAT = Union[str, libmissing.NAType]
5464

5565

@@ -64,7 +74,9 @@ def _chk_pyarrow_available() -> None:
6474
# fallback for the ones that pyarrow doesn't yet support
6575

6676

67-
class ArrowStringArray(ArrowExtensionArray, BaseStringArray, ObjectStringArrayMixin):
77+
class ArrowStringArray(
78+
OpsMixin, ArrowExtensionArray, BaseStringArray, ObjectStringArrayMixin
79+
):
6880
"""
6981
Extension array for string data in a ``pyarrow.ChunkedArray``.
7082
@@ -178,6 +190,32 @@ def to_numpy(
178190
result[mask] = na_value
179191
return result
180192

193+
def _cmp_method(self, other, op):
194+
from pandas.arrays import BooleanArray
195+
196+
pc_func = ARROW_CMP_FUNCS[op.__name__]
197+
if isinstance(other, ArrowStringArray):
198+
result = pc_func(self._data, other._data)
199+
elif isinstance(other, (np.ndarray, list)):
200+
result = pc_func(self._data, other)
201+
elif is_scalar(other):
202+
try:
203+
result = pc_func(self._data, pa.scalar(other))
204+
except (pa.lib.ArrowNotImplementedError, pa.lib.ArrowInvalid):
205+
mask = isna(self) | isna(other)
206+
valid = ~mask
207+
result = np.zeros(len(self), dtype="bool")
208+
result[valid] = op(np.array(self)[valid], other)
209+
return BooleanArray(result, mask)
210+
else:
211+
return NotImplemented
212+
213+
if pa_version_under2p0:
214+
result = result.to_pandas().values
215+
else:
216+
result = result.to_numpy()
217+
return BooleanArray._from_sequence(result)
218+
181219
def insert(self, loc: int, item):
182220
if not isinstance(item, str) and item is not libmissing.NA:
183221
raise TypeError("Scalar must be NA or str")

pandas/tests/extension/arrow/arrays.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
take,
2424
)
2525
from pandas.api.types import is_scalar
26+
from pandas.core.arraylike import OpsMixin
2627
from pandas.core.arrays.arrow import ArrowExtensionArray as _ArrowExtensionArray
2728
from pandas.core.construction import extract_array
2829

@@ -71,7 +72,7 @@ def construct_array_type(cls) -> type_t[ArrowStringArray]:
7172
return ArrowStringArray
7273

7374

74-
class ArrowExtensionArray(_ArrowExtensionArray):
75+
class ArrowExtensionArray(OpsMixin, _ArrowExtensionArray):
7576
_data: pa.ChunkedArray
7677

7778
@classmethod

pandas/tests/extension/test_arrow.py

-42
Original file line numberDiff line numberDiff line change
@@ -93,18 +93,6 @@ def data_missing(data):
9393
return type(data)._from_sequence([None, data[0]])
9494

9595

96-
@pytest.fixture(params=["data", "data_missing"])
97-
def all_data(request, data, data_missing):
98-
"""Parametrized fixture returning 'data' or 'data_missing' integer arrays.
99-
100-
Used to test dtype conversion with and without missing values.
101-
"""
102-
if request.param == "data":
103-
return data
104-
elif request.param == "data_missing":
105-
return data_missing
106-
107-
10896
@pytest.fixture
10997
def na_value():
11098
"""The scalar missing value for this type. Default 'None'"""
@@ -283,36 +271,6 @@ class TestBaseIndex(base.BaseIndexTests):
283271
pass
284272

285273

286-
class TestBaseInterface(base.BaseInterfaceTests):
287-
def test_contains(self, data, data_missing, request):
288-
tz = getattr(data.dtype.pyarrow_dtype, "tz", None)
289-
unit = getattr(data.dtype.pyarrow_dtype, "unit", None)
290-
if pa_version_under2p0 and tz not in (None, "UTC") and unit == "us":
291-
request.node.add_marker(
292-
pytest.mark.xfail(
293-
reason=(
294-
f"Not supported by pyarrow < 2.0 "
295-
f"with timestamp type {tz} and {unit}"
296-
)
297-
)
298-
)
299-
super().test_contains(data, data_missing)
300-
301-
@pytest.mark.xfail(reason="pyarrow.ChunkedArray does not support views.")
302-
def test_view(self, data):
303-
super().test_view(data)
304-
305-
306-
class TestBaseMissing(base.BaseMissingTests):
307-
pass
308-
309-
310-
class TestBaseSetitemTests(base.BaseSetitemTests):
311-
@pytest.mark.xfail(reason="GH 45419: pyarrow.ChunkedArray does not support views")
312-
def test_setitem_preserves_views(self, data):
313-
super().test_setitem_preserves_views(data)
314-
315-
316274
def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
317275
with pytest.raises(NotImplementedError, match="Passing pyarrow type"):
318276
ArrowDtype.construct_from_string("timestamp[s, tz=UTC][pyarrow]")

0 commit comments

Comments
 (0)