Skip to content

Commit 23b3ac8

Browse files
authored
REF: de-duplicate test_direct_arith_with_ndframe_returns_not_implemented (#54410)
* REF: de-duplicate test_direct_arith_with_ndframe_returns_not_implemented * parametrize
1 parent 39f9b33 commit 23b3ac8

File tree

5 files changed

+54
-74
lines changed

5 files changed

+54
-74
lines changed

pandas/_testing/__init__.py

+20
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,26 @@
269269
EMPTY_STRING_PATTERN = re.compile("^$")
270270

271271

272+
arithmetic_dunder_methods = [
273+
"__add__",
274+
"__radd__",
275+
"__sub__",
276+
"__rsub__",
277+
"__mul__",
278+
"__rmul__",
279+
"__floordiv__",
280+
"__rfloordiv__",
281+
"__truediv__",
282+
"__rtruediv__",
283+
"__pow__",
284+
"__rpow__",
285+
"__mod__",
286+
"__rmod__",
287+
]
288+
289+
comparison_dunder_methods = ["__eq__", "__ne__", "__le__", "__lt__", "__ge__", "__gt__"]
290+
291+
272292
def reset_display_options() -> None:
273293
"""
274294
Reset the display options for printing and representing objects.

pandas/conftest.py

+1-17
Original file line numberDiff line numberDiff line change
@@ -974,25 +974,9 @@ def ea_scalar_and_dtype(request):
974974
# ----------------------------------------------------------------
975975
# Operators & Operations
976976
# ----------------------------------------------------------------
977-
_all_arithmetic_operators = [
978-
"__add__",
979-
"__radd__",
980-
"__sub__",
981-
"__rsub__",
982-
"__mul__",
983-
"__rmul__",
984-
"__floordiv__",
985-
"__rfloordiv__",
986-
"__truediv__",
987-
"__rtruediv__",
988-
"__pow__",
989-
"__rpow__",
990-
"__mod__",
991-
"__rmod__",
992-
]
993977

994978

995-
@pytest.fixture(params=_all_arithmetic_operators)
979+
@pytest.fixture(params=tm.arithmetic_dunder_methods)
996980
def all_arithmetic_operators(request):
997981
"""
998982
Fixture for dunder names for common arithmetic operations.

pandas/core/arrays/datetimelike.py

+17-12
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,11 @@
157157
DTScalarOrNaT = Union[DatetimeLikeScalar, NaTType]
158158

159159

160+
def _make_unpacked_invalid_op(op_name: str):
161+
op = make_invalid_op(op_name)
162+
return unpack_zerodim_and_defer(op_name)(op)
163+
164+
160165
def _period_dispatch(meth: F) -> F:
161166
"""
162167
For PeriodArray methods, dispatch to DatetimeArray and re-wrap the results
@@ -979,18 +984,18 @@ def _cmp_method(self, other, op):
979984

980985
# pow is invalid for all three subclasses; TimedeltaArray will override
981986
# the multiplication and division ops
982-
__pow__ = make_invalid_op("__pow__")
983-
__rpow__ = make_invalid_op("__rpow__")
984-
__mul__ = make_invalid_op("__mul__")
985-
__rmul__ = make_invalid_op("__rmul__")
986-
__truediv__ = make_invalid_op("__truediv__")
987-
__rtruediv__ = make_invalid_op("__rtruediv__")
988-
__floordiv__ = make_invalid_op("__floordiv__")
989-
__rfloordiv__ = make_invalid_op("__rfloordiv__")
990-
__mod__ = make_invalid_op("__mod__")
991-
__rmod__ = make_invalid_op("__rmod__")
992-
__divmod__ = make_invalid_op("__divmod__")
993-
__rdivmod__ = make_invalid_op("__rdivmod__")
987+
__pow__ = _make_unpacked_invalid_op("__pow__")
988+
__rpow__ = _make_unpacked_invalid_op("__rpow__")
989+
__mul__ = _make_unpacked_invalid_op("__mul__")
990+
__rmul__ = _make_unpacked_invalid_op("__rmul__")
991+
__truediv__ = _make_unpacked_invalid_op("__truediv__")
992+
__rtruediv__ = _make_unpacked_invalid_op("__rtruediv__")
993+
__floordiv__ = _make_unpacked_invalid_op("__floordiv__")
994+
__rfloordiv__ = _make_unpacked_invalid_op("__rfloordiv__")
995+
__mod__ = _make_unpacked_invalid_op("__mod__")
996+
__rmod__ = _make_unpacked_invalid_op("__rmod__")
997+
__divmod__ = _make_unpacked_invalid_op("__divmod__")
998+
__rdivmod__ = _make_unpacked_invalid_op("__rdivmod__")
994999

9951000
@final
9961001
def _get_i8_values_and_mask(

pandas/tests/extension/base/ops.py

+16-34
Original file line numberDiff line numberDiff line change
@@ -166,23 +166,25 @@ def test_add_series_with_extension_array(self, data):
166166
expected = pd.Series(data + data)
167167
tm.assert_series_equal(result, expected)
168168

169-
@pytest.mark.parametrize("box", [pd.Series, pd.DataFrame])
169+
@pytest.mark.parametrize("box", [pd.Series, pd.DataFrame, pd.Index])
170+
@pytest.mark.parametrize(
171+
"op_name",
172+
[
173+
x
174+
for x in tm.arithmetic_dunder_methods + tm.comparison_dunder_methods
175+
if not x.startswith("__r")
176+
],
177+
)
170178
def test_direct_arith_with_ndframe_returns_not_implemented(
171-
self, request, data, box
179+
self, data, box, op_name
172180
):
173-
# EAs should return NotImplemented for ops with Series/DataFrame
181+
# EAs should return NotImplemented for ops with Series/DataFrame/Index
174182
# Pandas takes care of unboxing the series and calling the EA's op.
175-
other = pd.Series(data)
176-
if box is pd.DataFrame:
177-
other = other.to_frame()
178-
if not hasattr(data, "__add__"):
179-
request.node.add_marker(
180-
pytest.mark.xfail(
181-
reason=f"{type(data).__name__} does not implement add"
182-
)
183-
)
184-
result = data.__add__(other)
185-
assert result is NotImplemented
183+
other = box(data)
184+
185+
if hasattr(data, op_name):
186+
result = getattr(data, op_name)(other)
187+
assert result is NotImplemented
186188

187189

188190
class BaseComparisonOpsTests(BaseOpsUtil):
@@ -219,26 +221,6 @@ def test_compare_array(self, data, comparison_op):
219221
other = pd.Series([data[0]] * len(data))
220222
self._compare_other(ser, data, comparison_op, other)
221223

222-
@pytest.mark.parametrize("box", [pd.Series, pd.DataFrame])
223-
def test_direct_arith_with_ndframe_returns_not_implemented(self, data, box):
224-
# EAs should return NotImplemented for ops with Series/DataFrame
225-
# Pandas takes care of unboxing the series and calling the EA's op.
226-
other = pd.Series(data)
227-
if box is pd.DataFrame:
228-
other = other.to_frame()
229-
230-
if hasattr(data, "__eq__"):
231-
result = data.__eq__(other)
232-
assert result is NotImplemented
233-
else:
234-
pytest.skip(f"{type(data).__name__} does not implement __eq__")
235-
236-
if hasattr(data, "__ne__"):
237-
result = data.__ne__(other)
238-
assert result is NotImplemented
239-
else:
240-
pytest.skip(f"{type(data).__name__} does not implement __ne__")
241-
242224

243225
class BaseUnaryOpsTests(BaseOpsUtil):
244226
def test_invert(self, data):

pandas/tests/extension/test_period.py

-11
Original file line numberDiff line numberDiff line change
@@ -127,17 +127,6 @@ def test_add_series_with_extension_array(self, data):
127127
with pytest.raises(TypeError, match=msg):
128128
s + data
129129

130-
def test_direct_arith_with_ndframe_returns_not_implemented(
131-
self, data, frame_or_series
132-
):
133-
# Override to use __sub__ instead of __add__
134-
other = pd.Series(data)
135-
if frame_or_series is pd.DataFrame:
136-
other = other.to_frame()
137-
138-
result = data.__sub__(other)
139-
assert result is NotImplemented
140-
141130

142131
class TestCasting(BasePeriodTests, base.BaseCastingTests):
143132
pass

0 commit comments

Comments
 (0)