Skip to content

Commit a02e3ee

Browse files
jbrockmendelmroeschke
authored andcommitted
REF: de-duplicate _compare_other (pandas-dev#54424)
* REF: de-duplicate _compare_other * mypy fixup
1 parent 079ed2e commit a02e3ee

File tree

6 files changed

+23
-36
lines changed

6 files changed

+23
-36
lines changed

pandas/tests/extension/base/ops.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ def _compare_other(self, ser: pd.Series, data, op, other):
206206
# comparison should match point-wise comparisons
207207
result = op(ser, other)
208208
expected = ser.combine(other, op)
209+
expected = self._cast_pointwise_result(op.__name__, ser, other, expected)
209210
tm.assert_series_equal(result, expected)
210211

211212
else:
@@ -218,6 +219,9 @@ def _compare_other(self, ser: pd.Series, data, op, other):
218219
if exc is None:
219220
# Didn't error, then should match pointwise behavior
220221
expected = ser.combine(other, op)
222+
expected = self._cast_pointwise_result(
223+
op.__name__, ser, other, expected
224+
)
221225
tm.assert_series_equal(result, expected)
222226
else:
223227
with pytest.raises(type(exc)):
@@ -229,7 +233,7 @@ def test_compare_scalar(self, data, comparison_op):
229233

230234
def test_compare_array(self, data, comparison_op):
231235
ser = pd.Series(data)
232-
other = pd.Series([data[0]] * len(data))
236+
other = pd.Series([data[0]] * len(data), dtype=data.dtype)
233237
self._compare_other(ser, data, comparison_op, other)
234238

235239

pandas/tests/extension/json/test_json.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,11 @@ def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request):
315315

316316

317317
class TestComparisonOps(BaseJSON, base.BaseComparisonOpsTests):
318-
pass
318+
def test_compare_array(self, data, comparison_op, request):
319+
if comparison_op.__name__ in ["eq", "ne"]:
320+
mark = pytest.mark.xfail(reason="Comparison methods not implemented")
321+
request.node.add_marker(mark)
322+
super().test_compare_array(data, comparison_op)
319323

320324

321325
class TestPrinting(BaseJSON, base.BasePrintingTests):

pandas/tests/extension/test_arrow.py

+1-13
Original file line numberDiff line numberDiff line change
@@ -1155,19 +1155,7 @@ def test_compare_array(self, data, comparison_op, na_value):
11551155
tm.assert_series_equal(result, expected)
11561156

11571157
else:
1158-
exc = None
1159-
try:
1160-
result = comparison_op(ser, other)
1161-
except Exception as err:
1162-
exc = err
1163-
1164-
if exc is None:
1165-
# Didn't error, then should match point-wise behavior
1166-
expected = ser.combine(other, comparison_op)
1167-
tm.assert_series_equal(result, expected)
1168-
else:
1169-
with pytest.raises(type(exc)):
1170-
ser.combine(other, comparison_op)
1158+
return super().test_compare_array(data, comparison_op)
11711159

11721160
def test_invalid_other_comp(self, data, comparison_op):
11731161
# GH 48833

pandas/tests/extension/test_categorical.py

+3-11
Original file line numberDiff line numberDiff line change
@@ -267,20 +267,12 @@ def test_arith_series_with_scalar(self, data, all_arithmetic_operators, request)
267267
class TestComparisonOps(base.BaseComparisonOpsTests):
268268
def _compare_other(self, s, data, op, other):
269269
op_name = f"__{op.__name__}__"
270-
if op_name == "__eq__":
271-
result = op(s, other)
272-
expected = s.combine(other, lambda x, y: x == y)
273-
assert (result == expected).all()
274-
275-
elif op_name == "__ne__":
276-
result = op(s, other)
277-
expected = s.combine(other, lambda x, y: x != y)
278-
assert (result == expected).all()
279-
280-
else:
270+
if op_name not in ["__eq__", "__ne__"]:
281271
msg = "Unordered Categoricals can only compare equality or not"
282272
with pytest.raises(TypeError, match=msg):
283273
op(data, other)
274+
else:
275+
return super()._compare_other(s, data, op, other)
284276

285277
@pytest.mark.parametrize(
286278
"categories",

pandas/tests/extension/test_masked_numeric.py

-4
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,6 @@ class TestComparisonOps(base.BaseComparisonOpsTests):
177177
def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
178178
return pointwise_result.astype("boolean")
179179

180-
def _compare_other(self, ser: pd.Series, data, op, other):
181-
op_name = f"__{op.__name__}__"
182-
self.check_opname(ser, op_name, other)
183-
184180

185181
class TestInterface(base.BaseInterfaceTests):
186182
pass

pandas/tests/extension/test_string.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -190,12 +190,15 @@ class TestCasting(base.BaseCastingTests):
190190

191191

192192
class TestComparisonOps(base.BaseComparisonOpsTests):
193-
def _compare_other(self, ser, data, op, other):
194-
op_name = f"__{op.__name__}__"
195-
result = getattr(ser, op_name)(other)
196-
dtype = "boolean[pyarrow]" if ser.dtype.storage == "pyarrow" else "boolean"
197-
expected = getattr(ser.astype(object), op_name)(other).astype(dtype)
198-
tm.assert_series_equal(result, expected)
193+
def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
194+
dtype = tm.get_dtype(obj)
195+
# error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" has no
196+
# attribute "storage"
197+
if dtype.storage == "pyarrow": # type: ignore[union-attr]
198+
cast_to = "boolean[pyarrow]"
199+
else:
200+
cast_to = "boolean"
201+
return pointwise_result.astype(cast_to)
199202

200203
def test_compare_scalar(self, data, comparison_op):
201204
ser = pd.Series(data)

0 commit comments

Comments
 (0)