diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 043f682f7dfa8..c1380fcdbba06 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -31,7 +31,6 @@ ) from pandas.core.dtypes.missing import isna -from pandas.core.arraylike import OpsMixin from pandas.core.arrays.base import ExtensionArray from pandas.core.indexers import ( check_array_indexer, @@ -46,22 +45,13 @@ from pandas.core.arrays.arrow._arrow_utils import fallback_performancewarning from pandas.core.arrays.arrow.dtype import ArrowDtype - ARROW_CMP_FUNCS = { - "eq": pc.equal, - "ne": pc.not_equal, - "lt": pc.less, - "gt": pc.greater, - "le": pc.less_equal, - "ge": pc.greater_equal, - } - if TYPE_CHECKING: from pandas import Series ArrowExtensionArrayT = TypeVar("ArrowExtensionArrayT", bound="ArrowExtensionArray") -class ArrowExtensionArray(OpsMixin, ExtensionArray): +class ArrowExtensionArray(ExtensionArray): """ Base class for ExtensionArray backed by Arrow ChunkedArray. """ @@ -189,34 +179,6 @@ def __arrow_array__(self, type=None): """Convert myself to a pyarrow ChunkedArray.""" return self._data - def _cmp_method(self, other, op): - from pandas.arrays import BooleanArray - - pc_func = ARROW_CMP_FUNCS[op.__name__] - if isinstance(other, ArrowExtensionArray): - result = pc_func(self._data, other._data) - elif isinstance(other, (np.ndarray, list)): - result = pc_func(self._data, other) - elif is_scalar(other): - try: - result = pc_func(self._data, pa.scalar(other)) - except (pa.lib.ArrowNotImplementedError, pa.lib.ArrowInvalid): - mask = isna(self) | isna(other) - valid = ~mask - result = np.zeros(len(self), dtype="bool") - result[valid] = op(np.array(self)[valid], other) - return BooleanArray(result, mask) - else: - return NotImplementedError( - f"{op.__name__} not implemented for {type(other)}" - ) - - if pa_version_under2p0: - result = result.to_pandas().values - else: - result = result.to_numpy() - return BooleanArray._from_sequence(result) - def equals(self, other) -> bool: if not isinstance(other, ArrowExtensionArray): return False @@ -619,7 +581,7 @@ def _replace_with_indices( # fast path for a contiguous set of indices arrays = [ chunk[:start], - pa.array(value, type=chunk.type, from_pandas=True), + pa.array(value, type=chunk.type), chunk[stop + 1 :], ] arrays = [arr for arr in arrays if len(arr)] diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index c4d1a35315d7d..a07f748fa0c8c 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -34,6 +34,7 @@ ) from pandas.core.dtypes.missing import isna +from pandas.core.arraylike import OpsMixin from pandas.core.arrays.arrow import ArrowExtensionArray from pandas.core.arrays.boolean import BooleanDtype from pandas.core.arrays.integer import Int64Dtype @@ -50,6 +51,15 @@ from pandas.core.arrays.arrow._arrow_utils import fallback_performancewarning + ARROW_CMP_FUNCS = { + "eq": pc.equal, + "ne": pc.not_equal, + "lt": pc.less, + "gt": pc.greater, + "le": pc.less_equal, + "ge": pc.greater_equal, + } + ArrowStringScalarOrNAT = Union[str, libmissing.NAType] @@ -64,7 +74,9 @@ def _chk_pyarrow_available() -> None: # fallback for the ones that pyarrow doesn't yet support -class ArrowStringArray(ArrowExtensionArray, BaseStringArray, ObjectStringArrayMixin): +class ArrowStringArray( + OpsMixin, ArrowExtensionArray, BaseStringArray, ObjectStringArrayMixin +): """ Extension array for string data in a ``pyarrow.ChunkedArray``. @@ -178,6 +190,32 @@ def to_numpy( result[mask] = na_value return result + def _cmp_method(self, other, op): + from pandas.arrays import BooleanArray + + pc_func = ARROW_CMP_FUNCS[op.__name__] + if isinstance(other, ArrowStringArray): + result = pc_func(self._data, other._data) + elif isinstance(other, (np.ndarray, list)): + result = pc_func(self._data, other) + elif is_scalar(other): + try: + result = pc_func(self._data, pa.scalar(other)) + except (pa.lib.ArrowNotImplementedError, pa.lib.ArrowInvalid): + mask = isna(self) | isna(other) + valid = ~mask + result = np.zeros(len(self), dtype="bool") + result[valid] = op(np.array(self)[valid], other) + return BooleanArray(result, mask) + else: + return NotImplemented + + if pa_version_under2p0: + result = result.to_pandas().values + else: + result = result.to_numpy() + return BooleanArray._from_sequence(result) + def insert(self, loc: int, item): if not isinstance(item, str) and item is not libmissing.NA: raise TypeError("Scalar must be NA or str") diff --git a/pandas/tests/extension/arrow/arrays.py b/pandas/tests/extension/arrow/arrays.py index 26b94ebe5a8da..22595c4e461d7 100644 --- a/pandas/tests/extension/arrow/arrays.py +++ b/pandas/tests/extension/arrow/arrays.py @@ -23,6 +23,7 @@ take, ) from pandas.api.types import is_scalar +from pandas.core.arraylike import OpsMixin from pandas.core.arrays.arrow import ArrowExtensionArray as _ArrowExtensionArray from pandas.core.construction import extract_array @@ -71,7 +72,7 @@ def construct_array_type(cls) -> type_t[ArrowStringArray]: return ArrowStringArray -class ArrowExtensionArray(_ArrowExtensionArray): +class ArrowExtensionArray(OpsMixin, _ArrowExtensionArray): _data: pa.ChunkedArray @classmethod diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 9eeaf39959f29..95cb7045ac68d 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -93,18 +93,6 @@ def data_missing(data): return type(data)._from_sequence([None, data[0]]) -@pytest.fixture(params=["data", "data_missing"]) -def all_data(request, data, data_missing): - """Parametrized fixture returning 'data' or 'data_missing' integer arrays. - - Used to test dtype conversion with and without missing values. - """ - if request.param == "data": - return data - elif request.param == "data_missing": - return data_missing - - @pytest.fixture def na_value(): """The scalar missing value for this type. Default 'None'""" @@ -283,36 +271,6 @@ class TestBaseIndex(base.BaseIndexTests): pass -class TestBaseInterface(base.BaseInterfaceTests): - def test_contains(self, data, data_missing, request): - tz = getattr(data.dtype.pyarrow_dtype, "tz", None) - unit = getattr(data.dtype.pyarrow_dtype, "unit", None) - if pa_version_under2p0 and tz not in (None, "UTC") and unit == "us": - request.node.add_marker( - pytest.mark.xfail( - reason=( - f"Not supported by pyarrow < 2.0 " - f"with timestamp type {tz} and {unit}" - ) - ) - ) - super().test_contains(data, data_missing) - - @pytest.mark.xfail(reason="pyarrow.ChunkedArray does not support views.") - def test_view(self, data): - super().test_view(data) - - -class TestBaseMissing(base.BaseMissingTests): - pass - - -class TestBaseSetitemTests(base.BaseSetitemTests): - @pytest.mark.xfail(reason="GH 45419: pyarrow.ChunkedArray does not support views") - def test_setitem_preserves_views(self, data): - super().test_setitem_preserves_views(data) - - def test_arrowdtype_construct_from_string_type_with_unsupported_parameters(): with pytest.raises(NotImplementedError, match="Passing pyarrow type"): ArrowDtype.construct_from_string("timestamp[s, tz=UTC][pyarrow]")