Skip to content

REF: avoid monkeypatch in arrow tests #54361

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
2 commits merged on Aug 1, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 40 additions & 52 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,17 +863,22 @@ def get_op_from_name(self, op_name):
short_opname = op_name.strip("_")
if short_opname == "rtruediv":
# use the numpy version that won't raise on division by zero
return lambda x, y: np.divide(y, x)

def rtruediv(x, y):
return np.divide(y, x)

return rtruediv
elif short_opname == "rfloordiv":
return lambda x, y: np.floor_divide(y, x)

return tm.get_op_from_name(op_name)

def _patch_combine(self, obj, other, op):
def _combine(self, obj, other, op):
# BaseOpsUtil._combine can upcast expected dtype
# (because it generates expected on python scalars)
# while ArrowExtensionArray maintains original type
expected = base.BaseArithmeticOpsTests._combine(self, obj, other, op)

was_frame = False
if isinstance(expected, pd.DataFrame):
was_frame = True
Expand All @@ -883,10 +888,37 @@ def _patch_combine(self, obj, other, op):
expected_data = expected
original_dtype = obj.dtype

orig_pa_type = original_dtype.pyarrow_dtype
if not was_frame and isinstance(other, pd.Series):
# i.e. test_arith_series_with_array
if not (
pa.types.is_floating(orig_pa_type)
or (
pa.types.is_integer(orig_pa_type)
and op.__name__ not in ["truediv", "rtruediv"]
)
or pa.types.is_duration(orig_pa_type)
or pa.types.is_timestamp(orig_pa_type)
or pa.types.is_date(orig_pa_type)
or pa.types.is_decimal(orig_pa_type)
):
# base class _combine always returns int64, while
# ArrowExtensionArray does not upcast
return expected
elif not (
(op is operator.floordiv and pa.types.is_integer(orig_pa_type))
or pa.types.is_duration(orig_pa_type)
or pa.types.is_timestamp(orig_pa_type)
or pa.types.is_date(orig_pa_type)
or pa.types.is_decimal(orig_pa_type)
):
# base class _combine always returns int64, while
# ArrowExtensionArray does not upcast
return expected

pa_expected = pa.array(expected_data._values)

if pa.types.is_duration(pa_expected.type):
orig_pa_type = original_dtype.pyarrow_dtype
if pa.types.is_date(orig_pa_type):
if pa.types.is_date64(orig_pa_type):
# TODO: why is this different vs date32?
Expand All @@ -907,7 +939,7 @@ def _patch_combine(self, obj, other, op):
pa_expected = pa_expected.cast(f"duration[{unit}]")

elif pa.types.is_decimal(pa_expected.type) and pa.types.is_decimal(
original_dtype.pyarrow_dtype
orig_pa_type
):
# decimal precision can resize in the result type depending on data
# just compare the float values
Expand All @@ -929,7 +961,7 @@ def _patch_combine(self, obj, other, op):
return expected.astype(alt_dtype)

else:
pa_expected = pa_expected.cast(original_dtype.pyarrow_dtype)
pa_expected = pa_expected.cast(orig_pa_type)

pd_expected = type(expected_data._values)(pa_expected)
if was_frame:
Expand Down Expand Up @@ -1043,9 +1075,7 @@ def _get_arith_xfail_marker(self, opname, pa_dtype):

return mark

def test_arith_series_with_scalar(
self, data, all_arithmetic_operators, request, monkeypatch
):
def test_arith_series_with_scalar(self, data, all_arithmetic_operators, request):
pa_dtype = data.dtype.pyarrow_dtype

if all_arithmetic_operators == "__rmod__" and (
Expand All @@ -1061,24 +1091,9 @@ def test_arith_series_with_scalar(
if mark is not None:
request.node.add_marker(mark)

if (
(
all_arithmetic_operators == "__floordiv__"
and pa.types.is_integer(pa_dtype)
)
or pa.types.is_duration(pa_dtype)
or pa.types.is_timestamp(pa_dtype)
or pa.types.is_date(pa_dtype)
or pa.types.is_decimal(pa_dtype)
):
# BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does
# not upcast
monkeypatch.setattr(TestBaseArithmeticOps, "_combine", self._patch_combine)
super().test_arith_series_with_scalar(data, all_arithmetic_operators)

def test_arith_frame_with_scalar(
self, data, all_arithmetic_operators, request, monkeypatch
):
def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request):
pa_dtype = data.dtype.pyarrow_dtype

if all_arithmetic_operators == "__rmod__" and (
Expand All @@ -1094,24 +1109,9 @@ def test_arith_frame_with_scalar(
if mark is not None:
request.node.add_marker(mark)

if (
(
all_arithmetic_operators == "__floordiv__"
and pa.types.is_integer(pa_dtype)
)
or pa.types.is_duration(pa_dtype)
or pa.types.is_timestamp(pa_dtype)
or pa.types.is_date(pa_dtype)
or pa.types.is_decimal(pa_dtype)
):
# BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does
# not upcast
monkeypatch.setattr(TestBaseArithmeticOps, "_combine", self._patch_combine)
super().test_arith_frame_with_scalar(data, all_arithmetic_operators)

def test_arith_series_with_array(
self, data, all_arithmetic_operators, request, monkeypatch
):
def test_arith_series_with_array(self, data, all_arithmetic_operators, request):
pa_dtype = data.dtype.pyarrow_dtype

self.series_array_exc = self._get_scalar_exception(
Expand Down Expand Up @@ -1147,18 +1147,6 @@ def test_arith_series_with_array(
# since ser.iloc[0] is a python scalar
other = pd.Series(pd.array([ser.iloc[0]] * len(ser), dtype=data.dtype))

if (
pa.types.is_floating(pa_dtype)
or (
pa.types.is_integer(pa_dtype)
and all_arithmetic_operators not in ["__truediv__", "__rtruediv__"]
)
or pa.types.is_duration(pa_dtype)
or pa.types.is_timestamp(pa_dtype)
or pa.types.is_date(pa_dtype)
or pa.types.is_decimal(pa_dtype)
):
monkeypatch.setattr(TestBaseArithmeticOps, "_combine", self._patch_combine)
self.check_opname(ser, op_name, other, exc=self.series_array_exc)

def test_add_series_with_extension_array(self, data, request):
Expand Down