Skip to content

REF: dont pass exception to check_opname #54365

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Aug 4, 2023
65 changes: 49 additions & 16 deletions pandas/tests/extension/base/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,28 @@


class BaseOpsUtil(BaseExtensionTests):
series_scalar_exc: type[Exception] | None = TypeError
frame_scalar_exc: type[Exception] | None = TypeError
series_array_exc: type[Exception] | None = TypeError
divmod_exc: type[Exception] | None = TypeError

def _get_expected_exception(
self, op_name: str, obj, other
) -> type[Exception] | None:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could follow up by allowing authors to specify the exception message here

# Find the Exception, if any we expect to raise calling
# obj.__op_name__(other)

# The self.obj_bar_exc pattern isn't great in part because it can depend
# on op_name or dtypes, but we use it here for backward-compatibility.
if op_name in ["__divmod__", "__rdivmod__"]:
return self.divmod_exc
if isinstance(obj, pd.Series) and isinstance(other, pd.Series):
return self.series_array_exc
elif isinstance(obj, pd.Series):
return self.series_scalar_exc
else:
return self.frame_scalar_exc

def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
# In _check_op we check that the result of a pointwise operation
# (found via _combine) matches the result of the vectorized
Expand All @@ -24,17 +46,21 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
def get_op_from_name(self, op_name: str):
return tm.get_op_from_name(op_name)

def check_opname(self, ser: pd.Series, op_name: str, other, exc=Exception):
op = self.get_op_from_name(op_name)

self._check_op(ser, op, other, op_name, exc)

# Subclasses are not expected to need to override _check_op or _combine.
# Subclasses are not expected to need to override check_opname, _check_op,
# _check_divmod_op, or _combine.
# Ideally any relevant overriding can be done in _cast_pointwise_result,
# get_op_from_name, and the specification of `exc`. If you find a use
# case that still requires overriding _check_op or _combine, please let
# us know at github.com/pandas-dev/pandas/issues
@final
def check_opname(self, ser: pd.Series, op_name: str, other):
exc = self._get_expected_exception(op_name, ser, other)
op = self.get_op_from_name(op_name)

self._check_op(ser, op, other, op_name, exc)

# see comment on check_opname
@final
def _combine(self, obj, other, op):
if isinstance(obj, pd.DataFrame):
if len(obj.columns) != 1:
Expand All @@ -44,11 +70,14 @@ def _combine(self, obj, other, op):
expected = obj.combine(other, op)
return expected

# see comment on _combine
# see comment on check_opname
@final
def _check_op(
self, ser: pd.Series, op, other, op_name: str, exc=NotImplementedError
):
# Check that the Series/DataFrame arithmetic/comparison method matches
# the pointwise result from _combine.

if exc is None:
result = op(ser, other)
expected = self._combine(ser, other, op)
Expand All @@ -59,8 +88,14 @@ def _check_op(
with pytest.raises(exc):
op(ser, other)

def _check_divmod_op(self, ser: pd.Series, op, other, exc=Exception):
# divmod has multiple return values, so check separately
# see comment on check_opname
@final
def _check_divmod_op(self, ser: pd.Series, op, other):
# check that divmod behavior matches behavior of floordiv+mod
if op is divmod:
exc = self._get_expected_exception("__divmod__", ser, other)
else:
exc = self._get_expected_exception("__rdivmod__", ser, other)
if exc is None:
result_div, result_mod = op(ser, other)
if op is divmod:
Expand Down Expand Up @@ -96,26 +131,24 @@ def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
# series & scalar
op_name = all_arithmetic_operators
ser = pd.Series(data)
self.check_opname(ser, op_name, ser.iloc[0], exc=self.series_scalar_exc)
self.check_opname(ser, op_name, ser.iloc[0])

def test_arith_frame_with_scalar(self, data, all_arithmetic_operators):
# frame & scalar
op_name = all_arithmetic_operators
df = pd.DataFrame({"A": data})
self.check_opname(df, op_name, data[0], exc=self.frame_scalar_exc)
self.check_opname(df, op_name, data[0])

def test_arith_series_with_array(self, data, all_arithmetic_operators):
# ndarray & other series
op_name = all_arithmetic_operators
ser = pd.Series(data)
self.check_opname(
ser, op_name, pd.Series([ser.iloc[0]] * len(ser)), exc=self.series_array_exc
)
self.check_opname(ser, op_name, pd.Series([ser.iloc[0]] * len(ser)))

def test_divmod(self, data):
ser = pd.Series(data)
self._check_divmod_op(ser, divmod, 1, exc=self.divmod_exc)
self._check_divmod_op(1, ops.rdivmod, ser, exc=self.divmod_exc)
self._check_divmod_op(ser, divmod, 1)
self._check_divmod_op(1, ops.rdivmod, ser)

def test_divmod_series_array(self, data, data_for_twos):
ser = pd.Series(data)
Expand Down
16 changes: 10 additions & 6 deletions pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import decimal
import operator

Expand Down Expand Up @@ -311,8 +313,14 @@ def test_astype_dispatches(frame):


class TestArithmeticOps(base.BaseArithmeticOpsTests):
def check_opname(self, s, op_name, other, exc=None):
super().check_opname(s, op_name, other, exc=None)
series_scalar_exc = None
frame_scalar_exc = None
series_array_exc = None

def _get_expected_exception(
self, op_name: str, obj, other
) -> type[Exception] | None:
return None

def test_arith_series_with_array(self, data, all_arithmetic_operators):
op_name = all_arithmetic_operators
Expand All @@ -336,10 +344,6 @@ def test_arith_series_with_array(self, data, all_arithmetic_operators):
context.traps[decimal.DivisionByZero] = divbyzerotrap
context.traps[decimal.InvalidOperation] = invalidoptrap

def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
# We implement divmod
super()._check_divmod_op(s, op, other, exc=None)


class TestComparisonOps(base.BaseComparisonOpsTests):
def test_compare_scalar(self, data, comparison_op):
Expand Down
3 changes: 0 additions & 3 deletions pandas/tests/extension/json/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,6 @@ def test_divmod_series_array(self):
# skipping because it is not implemented
super().test_divmod_series_array()

def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
return super()._check_divmod_op(s, op, other, exc=TypeError)


class TestComparisonOps(BaseJSON, base.BaseComparisonOpsTests):
pass
Expand Down
34 changes: 17 additions & 17 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
classes (if they are relevant for the extension interface for all dtypes), or
be added to the array-specific tests in `pandas/tests/arrays/`.
"""
from __future__ import annotations

from datetime import (
date,
datetime,
Expand Down Expand Up @@ -964,16 +966,26 @@ def _is_temporal_supported(self, opname, pa_dtype):
and pa.types.is_temporal(pa_dtype)
)

def _get_scalar_exception(self, opname, pa_dtype):
arrow_temporal_supported = self._is_temporal_supported(opname, pa_dtype)
if opname in {
def _get_expected_exception(
self, op_name: str, obj, other
) -> type[Exception] | None:
if op_name in ("__divmod__", "__rdivmod__"):
return self.divmod_exc

dtype = tm.get_dtype(obj)
# error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" has no
# attribute "pyarrow_dtype"
pa_dtype = dtype.pyarrow_dtype # type: ignore[union-attr]

arrow_temporal_supported = self._is_temporal_supported(op_name, pa_dtype)
if op_name in {
"__mod__",
"__rmod__",
}:
exc = NotImplementedError
elif arrow_temporal_supported:
exc = None
elif opname in ["__add__", "__radd__"] and (
elif op_name in ["__add__", "__radd__"] and (
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
):
exc = None
Expand Down Expand Up @@ -1060,10 +1072,6 @@ def test_arith_series_with_scalar(self, data, all_arithmetic_operators, request)
):
pytest.skip("Skip testing Python string formatting")

self.series_scalar_exc = self._get_scalar_exception(
all_arithmetic_operators, pa_dtype
)

mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype)
if mark is not None:
request.node.add_marker(mark)
Expand All @@ -1078,10 +1086,6 @@ def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request):
):
pytest.skip("Skip testing Python string formatting")

self.frame_scalar_exc = self._get_scalar_exception(
all_arithmetic_operators, pa_dtype
)

mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype)
if mark is not None:
request.node.add_marker(mark)
Expand All @@ -1091,10 +1095,6 @@ def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request):
def test_arith_series_with_array(self, data, all_arithmetic_operators, request):
pa_dtype = data.dtype.pyarrow_dtype

self.series_array_exc = self._get_scalar_exception(
all_arithmetic_operators, pa_dtype
)

if (
all_arithmetic_operators
in (
Expand Down Expand Up @@ -1124,7 +1124,7 @@ def test_arith_series_with_array(self, data, all_arithmetic_operators, request):
# since ser.iloc[0] is a python scalar
other = pd.Series(pd.array([ser.iloc[0]] * len(ser), dtype=data.dtype))

self.check_opname(ser, op_name, other, exc=self.series_array_exc)
self.check_opname(ser, op_name, other)

def test_add_series_with_extension_array(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
Expand Down
22 changes: 5 additions & 17 deletions pandas/tests/extension/test_boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,17 +122,14 @@ class TestMissing(base.BaseMissingTests):
class TestArithmeticOps(base.BaseArithmeticOpsTests):
implements = {"__sub__", "__rsub__"}

def check_opname(self, s, op_name, other, exc=None):
# overwriting to indicate ops don't raise an error
exc = None
def _get_expected_exception(self, op_name, obj, other):
if op_name.strip("_").lstrip("r") in ["pow", "truediv", "floordiv"]:
# match behavior with non-masked bool dtype
exc = NotImplementedError
return NotImplementedError
elif op_name in self.implements:
# exception message would include "numpy boolean subtract""
exc = TypeError

super().check_opname(s, op_name, other, exc=exc)
return TypeError
return None

def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
if op_name in (
Expand Down Expand Up @@ -170,18 +167,9 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
def test_divmod_series_array(self, data, data_for_twos):
super().test_divmod_series_array(data, data_for_twos)

@pytest.mark.xfail(
reason="Inconsistency between floordiv and divmod; we raise for floordiv "
"but not for divmod. This matches what we do for non-masked bool dtype."
)
def test_divmod(self, data):
super().test_divmod(data)


class TestComparisonOps(base.BaseComparisonOpsTests):
def check_opname(self, s, op_name, other, exc=None):
# overwriting to indicate ops don't raise an error
super().check_opname(s, op_name, other, exc=None)
pass


class TestReshaping(base.BaseReshapingTests):
Expand Down
3 changes: 0 additions & 3 deletions pandas/tests/extension/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,6 @@ def test_divmod_series_array(self):
# skipping because it is not implemented
pass

def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
return super()._check_divmod_op(s, op, other, exc=TypeError)


class TestComparisonOps(base.BaseComparisonOpsTests):
def _compare_other(self, s, data, op, other):
Expand Down
28 changes: 4 additions & 24 deletions pandas/tests/extension/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,22 +130,10 @@ class TestInterface(BaseDatetimeTests, base.BaseInterfaceTests):
class TestArithmeticOps(BaseDatetimeTests, base.BaseArithmeticOpsTests):
implements = {"__sub__", "__rsub__"}

def test_arith_frame_with_scalar(self, data, all_arithmetic_operators):
# frame & scalar
if all_arithmetic_operators in self.implements:
df = pd.DataFrame({"A": data})
self.check_opname(df, all_arithmetic_operators, data[0], exc=None)
else:
# ... but not the rest.
super().test_arith_frame_with_scalar(data, all_arithmetic_operators)

def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
if all_arithmetic_operators in self.implements:
ser = pd.Series(data)
self.check_opname(ser, all_arithmetic_operators, ser.iloc[0], exc=None)
else:
# ... but not the rest.
super().test_arith_series_with_scalar(data, all_arithmetic_operators)
def _get_expected_exception(self, op_name, obj, other):
if op_name in self.implements:
return None
return super()._get_expected_exception(op_name, obj, other)

def test_add_series_with_extension_array(self, data):
# Datetime + Datetime not implemented
Expand All @@ -154,14 +142,6 @@ def test_add_series_with_extension_array(self, data):
with pytest.raises(TypeError, match=msg):
ser + data

def test_arith_series_with_array(self, data, all_arithmetic_operators):
if all_arithmetic_operators in self.implements:
ser = pd.Series(data)
self.check_opname(ser, all_arithmetic_operators, ser.iloc[0], exc=None)
else:
# ... but not the rest.
super().test_arith_series_with_scalar(data, all_arithmetic_operators)

def test_divmod_series_array(self):
# GH 23287
# skipping because it is not implemented
Expand Down
17 changes: 8 additions & 9 deletions pandas/tests/extension/test_masked_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,21 +163,20 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
expected = expected.astype(sdtype)
return expected

def check_opname(self, ser: pd.Series, op_name: str, other, exc=None):
# overwriting to indicate ops don't raise an error
super().check_opname(ser, op_name, other, exc=None)

def _check_divmod_op(self, ser: pd.Series, op, other, exc=None):
super()._check_divmod_op(ser, op, other, None)
series_scalar_exc = None
series_array_exc = None
frame_scalar_exc = None
divmod_exc = None


class TestComparisonOps(base.BaseComparisonOpsTests):
series_scalar_exc = None
series_array_exc = None
frame_scalar_exc = None

def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
return pointwise_result.astype("boolean")

def check_opname(self, ser: pd.Series, op_name: str, other, exc=None):
super().check_opname(ser, op_name, other, exc=None)

def _compare_other(self, ser: pd.Series, data, op, other):
op_name = f"__{op.__name__}__"
self.check_opname(ser, op_name, other)
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/extension/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def test_divmod(self, data):
@skip_nested
def test_divmod_series_array(self, data):
ser = pd.Series(data)
self._check_divmod_op(ser, divmod, data, exc=None)
self._check_divmod_op(ser, divmod, data)

@skip_nested
def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
Expand Down
Loading