diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 1f35013075751..b2a8ec6bf62e8 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -31,6 +31,7 @@ ) from pandas.core.dtypes.missing import isna +from pandas.core.arraylike import OpsMixin from pandas.core.arrays.base import ExtensionArray from pandas.core.indexers import ( check_array_indexer, @@ -45,13 +46,22 @@ from pandas.core.arrays.arrow._arrow_utils import fallback_performancewarning from pandas.core.arrays.arrow.dtype import ArrowDtype + ARROW_CMP_FUNCS = { + "eq": pc.equal, + "ne": pc.not_equal, + "lt": pc.less, + "gt": pc.greater, + "le": pc.less_equal, + "ge": pc.greater_equal, + } + if TYPE_CHECKING: from pandas import Series ArrowExtensionArrayT = TypeVar("ArrowExtensionArrayT", bound="ArrowExtensionArray") -class ArrowExtensionArray(ExtensionArray): +class ArrowExtensionArray(OpsMixin, ExtensionArray): """ Base class for ExtensionArray backed by Arrow ChunkedArray. """ @@ -179,6 +189,34 @@ def __arrow_array__(self, type=None): """Convert myself to a pyarrow ChunkedArray.""" return self._data + def _cmp_method(self, other, op): + from pandas.arrays import BooleanArray + + pc_func = ARROW_CMP_FUNCS[op.__name__] + if isinstance(other, ArrowExtensionArray): + result = pc_func(self._data, other._data) + elif isinstance(other, (np.ndarray, list)): + result = pc_func(self._data, other) + elif is_scalar(other): + try: + result = pc_func(self._data, pa.scalar(other)) + except (pa.lib.ArrowNotImplementedError, pa.lib.ArrowInvalid): + mask = isna(self) | isna(other) + valid = ~mask + result = np.zeros(len(self), dtype="bool") + result[valid] = op(np.array(self)[valid], other) + return BooleanArray(result, mask) + else: + return NotImplementedError( + f"{op.__name__} not implemented for {type(other)}" + ) + + if pa_version_under2p0: + result = result.to_pandas().values + else: + result = result.to_numpy() + return BooleanArray._from_sequence(result) + def equals(self, other) -> bool: if not isinstance(other, ArrowExtensionArray): return False @@ -581,7 +619,7 @@ def _replace_with_indices( # fast path for a contiguous set of indices arrays = [ chunk[:start], - pa.array(value, type=chunk.type), + pa.array(value, type=chunk.type, from_pandas=True), chunk[stop + 1 :], ] arrays = [arr for arr in arrays if len(arr)] diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index a07f748fa0c8c..c4d1a35315d7d 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -34,7 +34,6 @@ ) from pandas.core.dtypes.missing import isna -from pandas.core.arraylike import OpsMixin from pandas.core.arrays.arrow import ArrowExtensionArray from pandas.core.arrays.boolean import BooleanDtype from pandas.core.arrays.integer import Int64Dtype @@ -51,15 +50,6 @@ from pandas.core.arrays.arrow._arrow_utils import fallback_performancewarning - ARROW_CMP_FUNCS = { - "eq": pc.equal, - "ne": pc.not_equal, - "lt": pc.less, - "gt": pc.greater, - "le": pc.less_equal, - "ge": pc.greater_equal, - } - ArrowStringScalarOrNAT = Union[str, libmissing.NAType] @@ -74,9 +64,7 @@ def _chk_pyarrow_available() -> None: # fallback for the ones that pyarrow doesn't yet support -class ArrowStringArray( - OpsMixin, ArrowExtensionArray, BaseStringArray, ObjectStringArrayMixin -): +class ArrowStringArray(ArrowExtensionArray, BaseStringArray, ObjectStringArrayMixin): """ Extension array for string data in a ``pyarrow.ChunkedArray``. @@ -190,32 +178,6 @@ def to_numpy( result[mask] = na_value return result - def _cmp_method(self, other, op): - from pandas.arrays import BooleanArray - - pc_func = ARROW_CMP_FUNCS[op.__name__] - if isinstance(other, ArrowStringArray): - result = pc_func(self._data, other._data) - elif isinstance(other, (np.ndarray, list)): - result = pc_func(self._data, other) - elif is_scalar(other): - try: - result = pc_func(self._data, pa.scalar(other)) - except (pa.lib.ArrowNotImplementedError, pa.lib.ArrowInvalid): - mask = isna(self) | isna(other) - valid = ~mask - result = np.zeros(len(self), dtype="bool") - result[valid] = op(np.array(self)[valid], other) - return BooleanArray(result, mask) - else: - return NotImplemented - - if pa_version_under2p0: - result = result.to_pandas().values - else: - result = result.to_numpy() - return BooleanArray._from_sequence(result) - def insert(self, loc: int, item): if not isinstance(item, str) and item is not libmissing.NA: raise TypeError("Scalar must be NA or str") diff --git a/pandas/tests/extension/arrow/arrays.py b/pandas/tests/extension/arrow/arrays.py index 22595c4e461d7..26b94ebe5a8da 100644 --- a/pandas/tests/extension/arrow/arrays.py +++ b/pandas/tests/extension/arrow/arrays.py @@ -23,7 +23,6 @@ take, ) from pandas.api.types import is_scalar -from pandas.core.arraylike import OpsMixin from pandas.core.arrays.arrow import ArrowExtensionArray as _ArrowExtensionArray from pandas.core.construction import extract_array @@ -72,7 +71,7 @@ def construct_array_type(cls) -> type_t[ArrowStringArray]: return ArrowStringArray -class ArrowExtensionArray(OpsMixin, _ArrowExtensionArray): +class ArrowExtensionArray(_ArrowExtensionArray): _data: pa.ChunkedArray @classmethod diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 95cb7045ac68d..9eeaf39959f29 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -93,6 +93,18 @@ def data_missing(data): return type(data)._from_sequence([None, data[0]]) +@pytest.fixture(params=["data", "data_missing"]) +def all_data(request, data, data_missing): + """Parametrized fixture returning 'data' or 'data_missing' integer arrays. + + Used to test dtype conversion with and without missing values. + """ + if request.param == "data": + return data + elif request.param == "data_missing": + return data_missing + + @pytest.fixture def na_value(): """The scalar missing value for this type. Default 'None'""" @@ -271,6 +283,36 @@ class TestBaseIndex(base.BaseIndexTests): pass +class TestBaseInterface(base.BaseInterfaceTests): + def test_contains(self, data, data_missing, request): + tz = getattr(data.dtype.pyarrow_dtype, "tz", None) + unit = getattr(data.dtype.pyarrow_dtype, "unit", None) + if pa_version_under2p0 and tz not in (None, "UTC") and unit == "us": + request.node.add_marker( + pytest.mark.xfail( + reason=( + f"Not supported by pyarrow < 2.0 " + f"with timestamp type {tz} and {unit}" + ) + ) + ) + super().test_contains(data, data_missing) + + @pytest.mark.xfail(reason="pyarrow.ChunkedArray does not support views.") + def test_view(self, data): + super().test_view(data) + + +class TestBaseMissing(base.BaseMissingTests): + pass + + +class TestBaseSetitemTests(base.BaseSetitemTests): + @pytest.mark.xfail(reason="GH 45419: pyarrow.ChunkedArray does not support views") + def test_setitem_preserves_views(self, data): + super().test_setitem_preserves_views(data) + + def test_arrowdtype_construct_from_string_type_with_unsupported_parameters(): with pytest.raises(NotImplementedError, match="Passing pyarrow type"): ArrowDtype.construct_from_string("timestamp[s, tz=UTC][pyarrow]")