From ca27202639dcfce8baf46083bc74c39827e647e0 Mon Sep 17 00:00:00 2001 From: Brock Date: Fri, 4 Aug 2023 10:51:59 -0700 Subject: [PATCH 1/2] REF: de-duplicate test_direct_arith_with_ndframe_returns_not_implemented --- pandas/_testing/__init__.py | 20 +++++++++++ pandas/conftest.py | 18 +--------- pandas/core/arrays/datetimelike.py | 29 +++++++++------- pandas/tests/extension/base/ops.py | 49 +++++++-------------------- pandas/tests/extension/test_period.py | 11 ------ 5 files changed, 51 insertions(+), 76 deletions(-) diff --git a/pandas/_testing/__init__.py b/pandas/_testing/__init__.py index 483c5ad59872f..73835252c0329 100644 --- a/pandas/_testing/__init__.py +++ b/pandas/_testing/__init__.py @@ -269,6 +269,26 @@ EMPTY_STRING_PATTERN = re.compile("^$") +arithmetic_dunder_methods = [ + "__add__", + "__radd__", + "__sub__", + "__rsub__", + "__mul__", + "__rmul__", + "__floordiv__", + "__rfloordiv__", + "__truediv__", + "__rtruediv__", + "__pow__", + "__rpow__", + "__mod__", + "__rmod__", +] + +comparison_dunder_methods = ["__eq__", "__ne__", "__le__", "__lt__", "__ge__", "__gt__"] + + def reset_display_options() -> None: """ Reset the display options for printing and representing objects. diff --git a/pandas/conftest.py b/pandas/conftest.py index 1dcf413f2edf6..9cb29903dc156 100644 --- a/pandas/conftest.py +++ b/pandas/conftest.py @@ -974,25 +974,9 @@ def ea_scalar_and_dtype(request): # ---------------------------------------------------------------- # Operators & Operations # ---------------------------------------------------------------- -_all_arithmetic_operators = [ - "__add__", - "__radd__", - "__sub__", - "__rsub__", - "__mul__", - "__rmul__", - "__floordiv__", - "__rfloordiv__", - "__truediv__", - "__rtruediv__", - "__pow__", - "__rpow__", - "__mod__", - "__rmod__", -] -@pytest.fixture(params=_all_arithmetic_operators) +@pytest.fixture(params=tm.arithmetic_dunder_methods) def all_arithmetic_operators(request): """ Fixture for dunder names for common arithmetic operations. diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 972a8fb800f92..2b43b090a43e0 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -157,6 +157,11 @@ DTScalarOrNaT = Union[DatetimeLikeScalar, NaTType] +def _make_unpacked_invalid_op(op_name: str): + op = make_invalid_op(op_name) + return unpack_zerodim_and_defer(op_name)(op) + + def _period_dispatch(meth: F) -> F: """ For PeriodArray methods, dispatch to DatetimeArray and re-wrap the results @@ -979,18 +984,18 @@ def _cmp_method(self, other, op): # pow is invalid for all three subclasses; TimedeltaArray will override # the multiplication and division ops - __pow__ = make_invalid_op("__pow__") - __rpow__ = make_invalid_op("__rpow__") - __mul__ = make_invalid_op("__mul__") - __rmul__ = make_invalid_op("__rmul__") - __truediv__ = make_invalid_op("__truediv__") - __rtruediv__ = make_invalid_op("__rtruediv__") - __floordiv__ = make_invalid_op("__floordiv__") - __rfloordiv__ = make_invalid_op("__rfloordiv__") - __mod__ = make_invalid_op("__mod__") - __rmod__ = make_invalid_op("__rmod__") - __divmod__ = make_invalid_op("__divmod__") - __rdivmod__ = make_invalid_op("__rdivmod__") + __pow__ = _make_unpacked_invalid_op("__pow__") + __rpow__ = _make_unpacked_invalid_op("__rpow__") + __mul__ = _make_unpacked_invalid_op("__mul__") + __rmul__ = _make_unpacked_invalid_op("__rmul__") + __truediv__ = _make_unpacked_invalid_op("__truediv__") + __rtruediv__ = _make_unpacked_invalid_op("__rtruediv__") + __floordiv__ = _make_unpacked_invalid_op("__floordiv__") + __rfloordiv__ = _make_unpacked_invalid_op("__rfloordiv__") + __mod__ = _make_unpacked_invalid_op("__mod__") + __rmod__ = _make_unpacked_invalid_op("__rmod__") + __divmod__ = _make_unpacked_invalid_op("__divmod__") + __rdivmod__ = _make_unpacked_invalid_op("__rdivmod__") @final def _get_i8_values_and_mask( diff --git a/pandas/tests/extension/base/ops.py b/pandas/tests/extension/base/ops.py index aafb1900a4236..92c9888718b4d 100644 --- a/pandas/tests/extension/base/ops.py +++ b/pandas/tests/extension/base/ops.py @@ -166,23 +166,20 @@ def test_add_series_with_extension_array(self, data): expected = pd.Series(data + data) tm.assert_series_equal(result, expected) - @pytest.mark.parametrize("box", [pd.Series, pd.DataFrame]) - def test_direct_arith_with_ndframe_returns_not_implemented( - self, request, data, box - ): - # EAs should return NotImplemented for ops with Series/DataFrame + @pytest.mark.parametrize("box", [pd.Series, pd.DataFrame, pd.Index]) + def test_direct_arith_with_ndframe_returns_not_implemented(self, data, box): + # EAs should return NotImplemented for ops with Series/DataFrame/Index # Pandas takes care of unboxing the series and calling the EA's op. - other = pd.Series(data) - if box is pd.DataFrame: - other = other.to_frame() - if not hasattr(data, "__add__"): - request.node.add_marker( - pytest.mark.xfail( - reason=f"{type(data).__name__} does not implement add" - ) - ) - result = data.__add__(other) - assert result is NotImplemented + other = box(data) + + op_names = tm.arithmetic_dunder_methods + tm.comparison_dunder_methods + op_names = [x for x in op_names if not x.startswith("__r")] + for op_name in op_names: + # We use a loop here instead of fixture to avoid overhead from + # re-creating 'data' many times. + if hasattr(data, op_name): + result = getattr(data, op_name)(other) + assert result is NotImplemented class BaseComparisonOpsTests(BaseOpsUtil): @@ -219,26 +216,6 @@ def test_compare_array(self, data, comparison_op): other = pd.Series([data[0]] * len(data)) self._compare_other(ser, data, comparison_op, other) - @pytest.mark.parametrize("box", [pd.Series, pd.DataFrame]) - def test_direct_arith_with_ndframe_returns_not_implemented(self, data, box): - # EAs should return NotImplemented for ops with Series/DataFrame - # Pandas takes care of unboxing the series and calling the EA's op. - other = pd.Series(data) - if box is pd.DataFrame: - other = other.to_frame() - - if hasattr(data, "__eq__"): - result = data.__eq__(other) - assert result is NotImplemented - else: - pytest.skip(f"{type(data).__name__} does not implement __eq__") - - if hasattr(data, "__ne__"): - result = data.__ne__(other) - assert result is NotImplemented - else: - pytest.skip(f"{type(data).__name__} does not implement __ne__") - class BaseUnaryOpsTests(BaseOpsUtil): def test_invert(self, data): diff --git a/pandas/tests/extension/test_period.py b/pandas/tests/extension/test_period.py index 7b6bc98ee8c05..1bb53ad01c201 100644 --- a/pandas/tests/extension/test_period.py +++ b/pandas/tests/extension/test_period.py @@ -131,17 +131,6 @@ def test_add_series_with_extension_array(self, data): with pytest.raises(TypeError, match=msg): s + data - def test_direct_arith_with_ndframe_returns_not_implemented( - self, data, frame_or_series - ): - # Override to use __sub__ instead of __add__ - other = pd.Series(data) - if frame_or_series is pd.DataFrame: - other = other.to_frame() - - result = data.__sub__(other) - assert result is NotImplemented - class TestCasting(BasePeriodTests, base.BaseCastingTests): pass From 759d817b903a8b63e2b0e7c18c88861790f57720 Mon Sep 17 00:00:00 2001 From: Brock Date: Fri, 4 Aug 2023 14:22:42 -0700 Subject: [PATCH 2/2] parametrize --- pandas/tests/extension/base/ops.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/pandas/tests/extension/base/ops.py b/pandas/tests/extension/base/ops.py index 92c9888718b4d..658018a7ac740 100644 --- a/pandas/tests/extension/base/ops.py +++ b/pandas/tests/extension/base/ops.py @@ -167,19 +167,24 @@ def test_add_series_with_extension_array(self, data): tm.assert_series_equal(result, expected) @pytest.mark.parametrize("box", [pd.Series, pd.DataFrame, pd.Index]) - def test_direct_arith_with_ndframe_returns_not_implemented(self, data, box): + @pytest.mark.parametrize( + "op_name", + [ + x + for x in tm.arithmetic_dunder_methods + tm.comparison_dunder_methods + if not x.startswith("__r") + ], + ) + def test_direct_arith_with_ndframe_returns_not_implemented( + self, data, box, op_name + ): # EAs should return NotImplemented for ops with Series/DataFrame/Index # Pandas takes care of unboxing the series and calling the EA's op. other = box(data) - op_names = tm.arithmetic_dunder_methods + tm.comparison_dunder_methods - op_names = [x for x in op_names if not x.startswith("__r")] - for op_name in op_names: - # We use a loop here instead of fixture to avoid overhead from - # re-creating 'data' many times. - if hasattr(data, op_name): - result = getattr(data, op_name)(other) - assert result is NotImplemented + if hasattr(data, op_name): + result = getattr(data, op_name)(other) + assert result is NotImplemented class BaseComparisonOpsTests(BaseOpsUtil):