diff --git a/pandas/tests/extension/base/ops.py b/pandas/tests/extension/base/ops.py index 658018a7ac740..1160c781b4827 100644 --- a/pandas/tests/extension/base/ops.py +++ b/pandas/tests/extension/base/ops.py @@ -195,6 +195,7 @@ def _compare_other(self, ser: pd.Series, data, op, other): # comparison should match point-wise comparisons result = op(ser, other) expected = ser.combine(other, op) + expected = self._cast_pointwise_result(op.__name__, ser, other, expected) tm.assert_series_equal(result, expected) else: @@ -207,6 +208,9 @@ def _compare_other(self, ser: pd.Series, data, op, other): if exc is None: # Didn't error, then should match pointwise behavior expected = ser.combine(other, op) + expected = self._cast_pointwise_result( + op.__name__, ser, other, expected + ) tm.assert_series_equal(result, expected) else: with pytest.raises(type(exc)): @@ -218,7 +222,7 @@ def test_compare_scalar(self, data, comparison_op): def test_compare_array(self, data, comparison_op): ser = pd.Series(data) - other = pd.Series([data[0]] * len(data)) + other = pd.Series([data[0]] * len(data), dtype=data.dtype) self._compare_other(ser, data, comparison_op, other) diff --git a/pandas/tests/extension/json/test_json.py b/pandas/tests/extension/json/test_json.py index d5002a8fb91bf..65a417b46686b 100644 --- a/pandas/tests/extension/json/test_json.py +++ b/pandas/tests/extension/json/test_json.py @@ -320,7 +320,11 @@ def test_add_series_with_extension_array(self, data): class TestComparisonOps(BaseJSON, base.BaseComparisonOpsTests): - pass + def test_compare_array(self, data, comparison_op, request): + if comparison_op.__name__ in ["eq", "ne"]: + mark = pytest.mark.xfail(reason="Comparison methods not implemented") + request.node.add_marker(mark) + super().test_compare_array(data, comparison_op) class TestPrinting(BaseJSON, base.BasePrintingTests): diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 3b54d8e948b14..a5fbfa4524964 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1177,19 +1177,7 @@ def test_compare_array(self, data, comparison_op, na_value): tm.assert_series_equal(result, expected) else: - exc = None - try: - result = comparison_op(ser, other) - except Exception as err: - exc = err - - if exc is None: - # Didn't error, then should match point-wise behavior - expected = ser.combine(other, comparison_op) - tm.assert_series_equal(result, expected) - else: - with pytest.raises(type(exc)): - ser.combine(other, comparison_op) + return super().test_compare_array(data, comparison_op) def test_invalid_other_comp(self, data, comparison_op): # GH 48833 diff --git a/pandas/tests/extension/test_categorical.py b/pandas/tests/extension/test_categorical.py index 828cdfa538ba5..63700e8a27930 100644 --- a/pandas/tests/extension/test_categorical.py +++ b/pandas/tests/extension/test_categorical.py @@ -272,20 +272,12 @@ def test_add_series_with_extension_array(self, data): class TestComparisonOps(base.BaseComparisonOpsTests): def _compare_other(self, s, data, op, other): op_name = f"__{op.__name__}__" - if op_name == "__eq__": - result = op(s, other) - expected = s.combine(other, lambda x, y: x == y) - assert (result == expected).all() - - elif op_name == "__ne__": - result = op(s, other) - expected = s.combine(other, lambda x, y: x != y) - assert (result == expected).all() - - else: + if op_name not in ["__eq__", "__ne__"]: msg = "Unordered Categoricals can only compare equality or not" with pytest.raises(TypeError, match=msg): op(data, other) + else: + return super()._compare_other(s, data, op, other) @pytest.mark.parametrize( "categories", diff --git a/pandas/tests/extension/test_masked_numeric.py b/pandas/tests/extension/test_masked_numeric.py index e949d444222aa..d2b8e813a16ae 100644 --- a/pandas/tests/extension/test_masked_numeric.py +++ b/pandas/tests/extension/test_masked_numeric.py @@ -177,10 +177,6 @@ class TestComparisonOps(base.BaseComparisonOpsTests): def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result): return pointwise_result.astype("boolean") - def _compare_other(self, ser: pd.Series, data, op, other): - op_name = f"__{op.__name__}__" - self.check_opname(ser, op_name, other) - class TestInterface(base.BaseInterfaceTests): pass diff --git a/pandas/tests/extension/test_string.py b/pandas/tests/extension/test_string.py index 7256ea5837bbf..b439a0b1131a5 100644 --- a/pandas/tests/extension/test_string.py +++ b/pandas/tests/extension/test_string.py @@ -190,12 +190,15 @@ class TestCasting(base.BaseCastingTests): class TestComparisonOps(base.BaseComparisonOpsTests): - def _compare_other(self, ser, data, op, other): - op_name = f"__{op.__name__}__" - result = getattr(ser, op_name)(other) - dtype = "boolean[pyarrow]" if ser.dtype.storage == "pyarrow" else "boolean" - expected = getattr(ser.astype(object), op_name)(other).astype(dtype) - tm.assert_series_equal(result, expected) + def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result): + dtype = tm.get_dtype(obj) + # error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" has no + # attribute "storage" + if dtype.storage == "pyarrow": # type: ignore[union-attr] + cast_to = "boolean[pyarrow]" + else: + cast_to = "boolean" + return pointwise_result.astype(cast_to) def test_compare_scalar(self, data, comparison_op): ser = pd.Series(data)