diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 9295cf7873d98..e3433ffcb24e8 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -33,7 +33,6 @@ infer_dtype_from_scalar, ) from pandas.core.dtypes.common import ( - CategoricalDtype, is_array_like, is_bool_dtype, is_float_dtype, @@ -730,9 +729,7 @@ def __setstate__(self, state) -> None: def _cmp_method(self, other, op) -> ArrowExtensionArray: pc_func = ARROW_CMP_FUNCS[op.__name__] - if isinstance( - other, (ArrowExtensionArray, np.ndarray, list, BaseMaskedArray) - ) or isinstance(getattr(other, "dtype", None), CategoricalDtype): + if isinstance(other, (ExtensionArray, np.ndarray, list)): try: result = pc_func(self._pa_array, self._box_pa(other)) except pa.ArrowNotImplementedError: diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index 7227ea77ca433..1a0a07f3a686b 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -1018,7 +1018,30 @@ def searchsorted( return super().searchsorted(value=value, side=side, sorter=sorter) def _cmp_method(self, other, op): - from pandas.arrays import BooleanArray + from pandas.arrays import ( + ArrowExtensionArray, + BooleanArray, + ) + + if ( + isinstance(other, BaseStringArray) + and self.dtype.na_value is not libmissing.NA + and other.dtype.na_value is libmissing.NA + ): + # NA has priority of NaN semantics + return NotImplemented + + if isinstance(other, ArrowExtensionArray): + if isinstance(other, BaseStringArray): + # pyarrow storage has priority over python storage + # (except if we have NA semantics and other not) + if not ( + self.dtype.na_value is libmissing.NA + and other.dtype.na_value is not libmissing.NA + ): + return NotImplemented + else: + return NotImplemented if isinstance(other, StringArray): other = other._ndarray diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index d35083fd892a8..dc7343d0ea616 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -473,6 +473,14 @@ def value_counts(self, dropna: bool = True) -> Series: return result def _cmp_method(self, other, op): + if ( + isinstance(other, BaseStringArray) + and self.dtype.na_value is not libmissing.NA + and other.dtype.na_value is libmissing.NA + ): + # NA has priority of NaN semantics + return NotImplemented + result = super()._cmp_method(other, op) if self.dtype.na_value is np.nan: if op == operator.ne: diff --git a/pandas/core/ops/invalid.py b/pandas/core/ops/invalid.py index 395db1617cb63..62aa79a881717 100644 --- a/pandas/core/ops/invalid.py +++ b/pandas/core/ops/invalid.py @@ -25,7 +25,7 @@ def invalid_comparison( left: ArrayLike, - right: ArrayLike | Scalar, + right: ArrayLike | list | Scalar, op: Callable[[Any, Any], bool], ) -> npt.NDArray[np.bool_]: """ diff --git a/pandas/tests/arrays/string_/test_string.py b/pandas/tests/arrays/string_/test_string.py index 336a0fef69170..975a539a79724 100644 --- a/pandas/tests/arrays/string_/test_string.py +++ b/pandas/tests/arrays/string_/test_string.py @@ -10,6 +10,7 @@ from pandas._config import using_string_dtype +from pandas.compat import HAS_PYARROW from pandas.compat.pyarrow import ( pa_version_under12p0, pa_version_under19p0, @@ -45,6 +46,25 @@ def cls(dtype): return dtype.construct_array_type() +def string_dtype_highest_priority(dtype1, dtype2): + if HAS_PYARROW: + DTYPE_HIERARCHY = [ + pd.StringDtype("python", na_value=np.nan), + pd.StringDtype("pyarrow", na_value=np.nan), + pd.StringDtype("python", na_value=pd.NA), + pd.StringDtype("pyarrow", na_value=pd.NA), + ] + else: + DTYPE_HIERARCHY = [ + pd.StringDtype("python", na_value=np.nan), + pd.StringDtype("python", na_value=pd.NA), + ] + + h1 = DTYPE_HIERARCHY.index(dtype1) + h2 = DTYPE_HIERARCHY.index(dtype2) + return DTYPE_HIERARCHY[max(h1, h2)] + + def test_dtype_constructor(): pytest.importorskip("pyarrow") @@ -319,13 +339,18 @@ def test_comparison_methods_scalar_not_string(comparison_op, dtype): tm.assert_extension_array_equal(result, expected) -def test_comparison_methods_array(comparison_op, dtype): +def test_comparison_methods_array(comparison_op, dtype, dtype2): op_name = f"__{comparison_op.__name__}__" a = pd.array(["a", None, "c"], dtype=dtype) - other = [None, None, "c"] - result = getattr(a, op_name)(other) - if dtype.na_value is np.nan: + other = pd.array([None, None, "c"], dtype=dtype2) + result = comparison_op(a, other) + + # ensure operation is commutative + result2 = comparison_op(other, a) + tm.assert_equal(result, result2) + + if dtype.na_value is np.nan and dtype2.na_value is np.nan: if operator.ne == comparison_op: expected = np.array([True, True, False]) else: @@ -333,11 +358,36 @@ def test_comparison_methods_array(comparison_op, dtype): expected[-1] = getattr(other[-1], op_name)(a[-1]) tm.assert_numpy_array_equal(result, expected) - result = getattr(a, op_name)(pd.NA) + else: + max_dtype = string_dtype_highest_priority(dtype, dtype2) + if max_dtype.storage == "python": + expected_dtype = "boolean" + else: + expected_dtype = "bool[pyarrow]" + + expected = np.full(len(a), fill_value=None, dtype="object") + expected[-1] = getattr(other[-1], op_name)(a[-1]) + expected = pd.array(expected, dtype=expected_dtype) + tm.assert_extension_array_equal(result, expected) + + +def test_comparison_methods_list(comparison_op, dtype): + op_name = f"__{comparison_op.__name__}__" + + a = pd.array(["a", None, "c"], dtype=dtype) + other = [None, None, "c"] + result = comparison_op(a, other) + + # ensure operation is commutative + result2 = comparison_op(other, a) + tm.assert_equal(result, result2) + + if dtype.na_value is np.nan: if operator.ne == comparison_op: - expected = np.array([True, True, True]) + expected = np.array([True, True, False]) else: expected = np.array([False, False, False]) + expected[-1] = getattr(other[-1], op_name)(a[-1]) tm.assert_numpy_array_equal(result, expected) else: @@ -347,10 +397,6 @@ def test_comparison_methods_array(comparison_op, dtype): expected = pd.array(expected, dtype=expected_dtype) tm.assert_extension_array_equal(result, expected) - result = getattr(a, op_name)(pd.NA) - expected = pd.array([None, None, None], dtype=expected_dtype) - tm.assert_extension_array_equal(result, expected) - def test_constructor_raises(cls): if cls is pd.arrays.StringArray: diff --git a/pandas/tests/extension/test_string.py b/pandas/tests/extension/test_string.py index 25129111180d6..6ea8ac59ca3e6 100644 --- a/pandas/tests/extension/test_string.py +++ b/pandas/tests/extension/test_string.py @@ -31,6 +31,7 @@ from pandas.api.types import is_string_dtype from pandas.core.arrays import ArrowStringArray from pandas.core.arrays.string_ import StringDtype +from pandas.tests.arrays.string_.test_string import string_dtype_highest_priority from pandas.tests.extension import base @@ -202,10 +203,13 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result): dtype = cast(StringDtype, tm.get_dtype(obj)) if op_name in ["__add__", "__radd__"]: cast_to = dtype + dtype_other = tm.get_dtype(other) if not isinstance(other, str) else None + if isinstance(dtype_other, StringDtype): + cast_to = string_dtype_highest_priority(dtype, dtype_other) elif dtype.na_value is np.nan: cast_to = np.bool_ # type: ignore[assignment] elif dtype.storage == "pyarrow": - cast_to = "boolean[pyarrow]" # type: ignore[assignment] + cast_to = "bool[pyarrow]" # type: ignore[assignment] else: cast_to = "boolean" # type: ignore[assignment] return pointwise_result.astype(cast_to) @@ -237,9 +241,11 @@ def test_arith_series_with_array( using_infer_string and all_arithmetic_operators == "__radd__" and ( - (dtype.na_value is pd.NA) or (dtype.storage == "python" and HAS_PYARROW) + dtype.na_value is pd.NA + and not (not HAS_PYARROW and dtype.storage == "python") ) ): + # TODO(infer_string) mark = pytest.mark.xfail( reason="The pointwise operation result will be inferred to " "string[nan, pyarrow], which does not match the input dtype"