Skip to content

REF: dont pass exception to check_opname #54365

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Aug 4, 2023
39 changes: 33 additions & 6 deletions pandas/tests/extension/base/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,32 @@


class BaseOpsUtil(BaseExtensionTests):
series_scalar_exc: type[Exception] | None = TypeError
frame_scalar_exc: type[Exception] | None = TypeError
series_array_exc: type[Exception] | None = TypeError

def get_expected_exception(
self, op_name: str, obj, other
) -> type[Exception] | None:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could follow up by allowing authors to specify the exception message here

# Find the Exception, if any we expect to raise calling
# obj.__op_name__(other)

# The self.foo_bar_exc pattern isn't great in part because it can depend
# on op_name or dtypes, but we use it here for backward-compatibility.
if op_name in ["divmod", "rdivmod"]:
return self.divmod_exc
if isinstance(obj, pd.Series) and isinstance(other, pd.Series):
return self.series_array_exc
elif isinstance(obj, pd.Series):
return self.series_scalar_exc
else:
return self.frame_scalar_exc

def get_op_from_name(self, op_name: str):
return tm.get_op_from_name(op_name)

def check_opname(self, ser: pd.Series, op_name: str, other, exc=Exception):
def check_opname(self, ser: pd.Series, op_name: str, other):
exc = self.get_expected_exception(op_name, ser, other)
op = self.get_op_from_name(op_name)

self._check_op(ser, op, other, op_name, exc)
Expand All @@ -30,6 +52,9 @@ def _combine(self, obj, other, op):
def _check_op(
self, ser: pd.Series, op, other, op_name: str, exc=NotImplementedError
):
# Check that the Series/DataFrame arithmetic/comparison method matches
# the pointwise result from _combine.

if exc is None:
result = op(ser, other)
expected = self._combine(ser, other, op)
Expand All @@ -41,6 +66,10 @@ def _check_op(

def _check_divmod_op(self, ser: pd.Series, op, other, exc=Exception):
# divmod has multiple return values, so check separately
if op is divmod:
exc = self.get_expected_exception("divmod", ser, other)
else:
exc = self.get_expected_exception("rdivmod", ser, other)
if exc is None:
result_div, result_mod = op(ser, other)
if op is divmod:
Expand Down Expand Up @@ -76,21 +105,19 @@ def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
# series & scalar
op_name = all_arithmetic_operators
ser = pd.Series(data)
self.check_opname(ser, op_name, ser.iloc[0], exc=self.series_scalar_exc)
self.check_opname(ser, op_name, ser.iloc[0])

def test_arith_frame_with_scalar(self, data, all_arithmetic_operators):
# frame & scalar
op_name = all_arithmetic_operators
df = pd.DataFrame({"A": data})
self.check_opname(df, op_name, data[0], exc=self.frame_scalar_exc)
self.check_opname(df, op_name, data[0])

def test_arith_series_with_array(self, data, all_arithmetic_operators):
# ndarray & other series
op_name = all_arithmetic_operators
ser = pd.Series(data)
self.check_opname(
ser, op_name, pd.Series([ser.iloc[0]] * len(ser)), exc=self.series_array_exc
)
self.check_opname(ser, op_name, pd.Series([ser.iloc[0]] * len(ser)))

def test_divmod(self, data):
ser = pd.Series(data)
Expand Down
10 changes: 8 additions & 2 deletions pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,14 @@ def test_astype_dispatches(frame):


class TestArithmeticOps(base.BaseArithmeticOpsTests):
def check_opname(self, s, op_name, other, exc=None):
super().check_opname(s, op_name, other, exc=None)
series_scalar_exc = None
frame_scalar_exc = None
series_array_exc = None

def get_expected_exception(
self, op_name: str, obj, other
) -> type[Exception] | None:
return None

def test_arith_series_with_array(self, data, all_arithmetic_operators):
op_name = all_arithmetic_operators
Expand Down
30 changes: 13 additions & 17 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,16 +955,24 @@ def _is_temporal_supported(self, opname, pa_dtype):
and pa.types.is_temporal(pa_dtype)
)

def _get_scalar_exception(self, opname, pa_dtype):
arrow_temporal_supported = self._is_temporal_supported(opname, pa_dtype)
if opname in {
def get_expected_exception(
self, op_name: str, obj, other
) -> type[Exception] | None:
if op_name == "divmod" or op_name == "rdivmod":
return self.divmod_exc

dtype = tm.get_dtype(obj)
pa_dtype = dtype.pyarrow_dtype

arrow_temporal_supported = self._is_temporal_supported(op_name, pa_dtype)
if op_name in {
"__mod__",
"__rmod__",
}:
exc = NotImplementedError
elif arrow_temporal_supported:
exc = None
elif opname in ["__add__", "__radd__"] and (
elif op_name in ["__add__", "__radd__"] and (
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
):
exc = None
Expand Down Expand Up @@ -1053,10 +1061,6 @@ def test_arith_series_with_scalar(
):
pytest.skip("Skip testing Python string formatting")

self.series_scalar_exc = self._get_scalar_exception(
all_arithmetic_operators, pa_dtype
)

mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype)
if mark is not None:
request.node.add_marker(mark)
Expand Down Expand Up @@ -1086,10 +1090,6 @@ def test_arith_frame_with_scalar(
):
pytest.skip("Skip testing Python string formatting")

self.frame_scalar_exc = self._get_scalar_exception(
all_arithmetic_operators, pa_dtype
)

mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype)
if mark is not None:
request.node.add_marker(mark)
Expand All @@ -1114,10 +1114,6 @@ def test_arith_series_with_array(
):
pa_dtype = data.dtype.pyarrow_dtype

self.series_array_exc = self._get_scalar_exception(
all_arithmetic_operators, pa_dtype
)

if (
all_arithmetic_operators
in (
Expand Down Expand Up @@ -1159,7 +1155,7 @@ def test_arith_series_with_array(
or pa.types.is_decimal(pa_dtype)
):
monkeypatch.setattr(TestBaseArithmeticOps, "_combine", self._patch_combine)
self.check_opname(ser, op_name, other, exc=self.series_array_exc)
self.check_opname(ser, op_name, other)

def test_add_series_with_extension_array(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
Expand Down
19 changes: 4 additions & 15 deletions pandas/tests/extension/test_boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,11 @@ class TestMissing(base.BaseMissingTests):
class TestArithmeticOps(base.BaseArithmeticOpsTests):
implements = {"__sub__", "__rsub__"}

def check_opname(self, s, op_name, other, exc=None):
# overwriting to indicate ops don't raise an error
exc = None
def get_expected_exception(self, op_name, obj, other):
if op_name.strip("_").lstrip("r") in ["pow", "truediv", "floordiv"]:
# match behavior with non-masked bool dtype
exc = NotImplementedError
super().check_opname(s, op_name, other, exc=exc)
return NotImplementedError
return None

def _check_op(self, obj, op, other, op_name, exc=NotImplementedError):
if exc is None:
Expand Down Expand Up @@ -168,18 +166,9 @@ def _check_op(self, obj, op, other, op_name, exc=NotImplementedError):
def test_divmod_series_array(self, data, data_for_twos):
super().test_divmod_series_array(data, data_for_twos)

@pytest.mark.xfail(
reason="Inconsistency between floordiv and divmod; we raise for floordiv "
"but not for divmod. This matches what we do for non-masked bool dtype."
)
def test_divmod(self, data):
super().test_divmod(data)


class TestComparisonOps(base.BaseComparisonOpsTests):
def check_opname(self, s, op_name, other, exc=None):
# overwriting to indicate ops don't raise an error
super().check_opname(s, op_name, other, exc=None)
pass


class TestReshaping(base.BaseReshapingTests):
Expand Down
3 changes: 0 additions & 3 deletions pandas/tests/extension/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,6 @@ def test_divmod_series_array(self):
# skipping because it is not implemented
pass

def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
return super()._check_divmod_op(s, op, other, exc=TypeError)


class TestComparisonOps(base.BaseComparisonOpsTests):
def _compare_other(self, s, data, op, other):
Expand Down
28 changes: 4 additions & 24 deletions pandas/tests/extension/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,22 +130,10 @@ class TestInterface(BaseDatetimeTests, base.BaseInterfaceTests):
class TestArithmeticOps(BaseDatetimeTests, base.BaseArithmeticOpsTests):
implements = {"__sub__", "__rsub__"}

def test_arith_frame_with_scalar(self, data, all_arithmetic_operators):
# frame & scalar
if all_arithmetic_operators in self.implements:
df = pd.DataFrame({"A": data})
self.check_opname(df, all_arithmetic_operators, data[0], exc=None)
else:
# ... but not the rest.
super().test_arith_frame_with_scalar(data, all_arithmetic_operators)

def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
if all_arithmetic_operators in self.implements:
ser = pd.Series(data)
self.check_opname(ser, all_arithmetic_operators, ser.iloc[0], exc=None)
else:
# ... but not the rest.
super().test_arith_series_with_scalar(data, all_arithmetic_operators)
def get_expected_exception(self, op_name, obj, other):
if op_name in self.implements:
return None
return super().get_expected_exception(op_name, obj, other)

def test_add_series_with_extension_array(self, data):
# Datetime + Datetime not implemented
Expand All @@ -154,14 +142,6 @@ def test_add_series_with_extension_array(self, data):
with pytest.raises(TypeError, match=msg):
ser + data

def test_arith_series_with_array(self, data, all_arithmetic_operators):
if all_arithmetic_operators in self.implements:
ser = pd.Series(data)
self.check_opname(ser, all_arithmetic_operators, ser.iloc[0], exc=None)
else:
# ... but not the rest.
super().test_arith_series_with_scalar(data, all_arithmetic_operators)

def test_divmod_series_array(self):
# GH 23287
# skipping because it is not implemented
Expand Down
17 changes: 8 additions & 9 deletions pandas/tests/extension/test_masked_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,15 +186,17 @@ def _check_op(self, s, op, other, op_name, exc=NotImplementedError):
with pytest.raises(exc):
op(s, other)

def check_opname(self, ser: pd.Series, op_name: str, other, exc=None):
# overwriting to indicate ops don't raise an error
super().check_opname(ser, op_name, other, exc=None)

def _check_divmod_op(self, ser: pd.Series, op, other, exc=None):
super()._check_divmod_op(ser, op, other, None)
series_scalar_exc = None
series_array_exc = None
frame_scalar_exc = None
divmod_exc = None


class TestComparisonOps(base.BaseComparisonOpsTests):
series_scalar_exc = None
series_array_exc = None
frame_scalar_exc = None

def _check_op(
self, ser: pd.Series, op, other, op_name: str, exc=NotImplementedError
):
Expand All @@ -207,9 +209,6 @@ def _check_op(
with pytest.raises(exc):
op(ser, other)

def check_opname(self, ser: pd.Series, op_name: str, other, exc=None):
super().check_opname(ser, op_name, other, exc=None)

def _compare_other(self, ser: pd.Series, data, op, other):
op_name = f"__{op.__name__}__"
self.check_opname(ser, op_name, other)
Expand Down
29 changes: 4 additions & 25 deletions pandas/tests/extension/test_period.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,31 +118,10 @@ class TestInterface(BasePeriodTests, base.BaseInterfaceTests):
class TestArithmeticOps(BasePeriodTests, base.BaseArithmeticOpsTests):
implements = {"__sub__", "__rsub__"}

def test_arith_frame_with_scalar(self, data, all_arithmetic_operators):
# frame & scalar
if all_arithmetic_operators in self.implements:
df = pd.DataFrame({"A": data})
self.check_opname(df, all_arithmetic_operators, data[0], exc=None)
else:
# ... but not the rest.
super().test_arith_frame_with_scalar(data, all_arithmetic_operators)

def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
# we implement substitution...
if all_arithmetic_operators in self.implements:
s = pd.Series(data)
self.check_opname(s, all_arithmetic_operators, s.iloc[0], exc=None)
else:
# ... but not the rest.
super().test_arith_series_with_scalar(data, all_arithmetic_operators)

def test_arith_series_with_array(self, data, all_arithmetic_operators):
if all_arithmetic_operators in self.implements:
s = pd.Series(data)
self.check_opname(s, all_arithmetic_operators, s.iloc[0], exc=None)
else:
# ... but not the rest.
super().test_arith_series_with_scalar(data, all_arithmetic_operators)
def get_expected_exception(self, op_name, obj, other):
if op_name in self.implements:
return None
return super().get_expected_exception(op_name, obj, other)

def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
super()._check_divmod_op(s, op, other, exc=TypeError)
Expand Down