Skip to content

TST: parametrize arithmetic tests #44395

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions pandas/tests/arithmetic/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,26 @@
array,
)
import pandas._testing as tm
from pandas.core.arrays import PandasArray
from pandas.core.arrays import (
BooleanArray,
PandasArray,
)


def assert_cannot_add(left, right, msg="cannot add"):
"""
Helper to assert that left and right cannot be added.

Parameters
----------
left : object
right : object
msg : str, default "cannot add"
"""
with pytest.raises(TypeError, match=msg):
left + right
with pytest.raises(TypeError, match=msg):
right + left


def assert_invalid_addsub_type(left, right, msg=None):
Expand Down Expand Up @@ -79,21 +98,29 @@ def xbox2(x):
# just exclude PandasArray[bool]
if isinstance(x, PandasArray):
return x._ndarray
if isinstance(x, BooleanArray):
# NB: we are assuming no pd.NAs for now
return x.astype(bool)
return x

# rev_box: box to use for reversed comparisons
rev_box = xbox
if isinstance(right, Index) and isinstance(left, Series):
rev_box = np.array

result = xbox2(left == right)
expected = xbox(np.zeros(result.shape, dtype=np.bool_))

tm.assert_equal(result, expected)

result = xbox2(right == left)
tm.assert_equal(result, expected)
tm.assert_equal(result, rev_box(expected))

result = xbox2(left != right)
tm.assert_equal(result, ~expected)

result = xbox2(right != left)
tm.assert_equal(result, ~expected)
tm.assert_equal(result, rev_box(~expected))

msg = "|".join(
[
Expand Down
104 changes: 30 additions & 74 deletions pandas/tests/arithmetic/test_datetime64.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
)
from pandas.core.ops import roperator
from pandas.tests.arithmetic.common import (
assert_cannot_add,
assert_invalid_addsub_type,
assert_invalid_comparison,
get_upcast_box,
Expand Down Expand Up @@ -99,6 +100,7 @@ def test_dt64arr_cmp_scalar_invalid(self, other, tz_naive_fixture, box_with_arra
@pytest.mark.parametrize(
"other",
[
# GH#4968 invalid date/int comparisons
list(range(10)),
np.arange(10),
np.arange(10).astype(np.float32),
Expand All @@ -111,13 +113,14 @@ def test_dt64arr_cmp_scalar_invalid(self, other, tz_naive_fixture, box_with_arra
pd.period_range("1971-01-01", freq="D", periods=10).astype(object),
],
)
def test_dt64arr_cmp_arraylike_invalid(self, other, tz_naive_fixture):
# We don't parametrize this over box_with_array because listlike
# other plays poorly with assert_invalid_comparison reversed checks
def test_dt64arr_cmp_arraylike_invalid(
self, other, tz_naive_fixture, box_with_array
):
tz = tz_naive_fixture

dta = date_range("1970-01-01", freq="ns", periods=10, tz=tz)._data
assert_invalid_comparison(dta, other, tm.to_array)
obj = tm.box_expected(dta, box_with_array)
assert_invalid_comparison(obj, other, box_with_array)

def test_dt64arr_cmp_mixed_invalid(self, tz_naive_fixture):
tz = tz_naive_fixture
Expand Down Expand Up @@ -215,18 +218,6 @@ def test_nat_comparisons(

tm.assert_series_equal(result, expected)

def test_comparison_invalid(self, tz_naive_fixture, box_with_array):
# GH#4968
# invalid date/int comparisons
tz = tz_naive_fixture
ser = Series(range(5))
ser2 = Series(date_range("20010101", periods=5, tz=tz))

ser = tm.box_expected(ser, box_with_array)
ser2 = tm.box_expected(ser2, box_with_array)

assert_invalid_comparison(ser, ser2, box_with_array)

@pytest.mark.parametrize(
"data",
[
Expand Down Expand Up @@ -315,8 +306,8 @@ def test_timestamp_compare_series(self, left, right):
tm.assert_series_equal(result, expected)

# Compare to NaT with series containing NaT
expected = left_f(s_nat, Timestamp("nat"))
result = right_f(Timestamp("nat"), s_nat)
expected = left_f(s_nat, NaT)
result = right_f(NaT, s_nat)
tm.assert_series_equal(result, expected)

def test_dt64arr_timestamp_equality(self, box_with_array):
Expand Down Expand Up @@ -832,17 +823,6 @@ def test_dt64arr_add_timedeltalike_scalar(
result = rng + two_hours
tm.assert_equal(result, expected)

def test_dt64arr_iadd_timedeltalike_scalar(
self, tz_naive_fixture, two_hours, box_with_array
):
tz = tz_naive_fixture

rng = date_range("2000-01-01", "2000-02-01", tz=tz)
expected = date_range("2000-01-01 02:00", "2000-02-01 02:00", tz=tz)

rng = tm.box_expected(rng, box_with_array)
expected = tm.box_expected(expected, box_with_array)

rng += two_hours
tm.assert_equal(rng, expected)

Expand All @@ -860,17 +840,6 @@ def test_dt64arr_sub_timedeltalike_scalar(
result = rng - two_hours
tm.assert_equal(result, expected)

def test_dt64arr_isub_timedeltalike_scalar(
self, tz_naive_fixture, two_hours, box_with_array
):
tz = tz_naive_fixture

rng = date_range("2000-01-01", "2000-02-01", tz=tz)
expected = date_range("1999-12-31 22:00", "2000-01-31 22:00", tz=tz)

rng = tm.box_expected(rng, box_with_array)
expected = tm.box_expected(expected, box_with_array)

rng -= two_hours
tm.assert_equal(rng, expected)

Expand Down Expand Up @@ -1071,21 +1040,14 @@ def test_dt64arr_add_dt64ndarray_raises(self, tz_naive_fixture, box_with_array):
dt64vals = dti.values

dtarr = tm.box_expected(dti, box_with_array)
msg = "cannot add"
with pytest.raises(TypeError, match=msg):
dtarr + dt64vals
with pytest.raises(TypeError, match=msg):
dt64vals + dtarr
assert_cannot_add(dtarr, dt64vals)

def test_dt64arr_add_timestamp_raises(self, box_with_array):
# GH#22163 ensure DataFrame doesn't cast Timestamp to i8
idx = DatetimeIndex(["2011-01-01", "2011-01-02"])
ts = idx[0]
idx = tm.box_expected(idx, box_with_array)
msg = "cannot add"
with pytest.raises(TypeError, match=msg):
idx + Timestamp("2011-01-01")
with pytest.raises(TypeError, match=msg):
Timestamp("2011-01-01") + idx
assert_cannot_add(idx, ts)

# -------------------------------------------------------------
# Other Invalid Addition/Subtraction
Expand Down Expand Up @@ -1267,13 +1229,12 @@ def test_dti_add_tick_tzaware(self, tz_aware_fixture, box_with_array):
dates = tm.box_expected(dates, box_with_array)
expected = tm.box_expected(expected, box_with_array)

# TODO: parametrize over the scalar being added? radd? sub?
offset = dates + pd.offsets.Hour(5)
tm.assert_equal(offset, expected)
offset = dates + np.timedelta64(5, "h")
tm.assert_equal(offset, expected)
offset = dates + timedelta(hours=5)
tm.assert_equal(offset, expected)
# TODO: sub?
for scalar in [pd.offsets.Hour(5), np.timedelta64(5, "h"), timedelta(hours=5)]:
offset = dates + scalar
tm.assert_equal(offset, expected)
offset = scalar + dates
tm.assert_equal(offset, expected)

# -------------------------------------------------------------
# RelativeDelta DateOffsets
Expand Down Expand Up @@ -1941,30 +1902,24 @@ def test_dt64_mul_div_numeric_invalid(self, one, dt64_series):
one / dt64_series

# TODO: parametrize over box
@pytest.mark.parametrize("op", ["__add__", "__radd__", "__sub__", "__rsub__"])
def test_dt64_series_add_intlike(self, tz_naive_fixture, op):
def test_dt64_series_add_intlike(self, tz_naive_fixture):
# GH#19123
tz = tz_naive_fixture
dti = DatetimeIndex(["2016-01-02", "2016-02-03", "NaT"], tz=tz)
ser = Series(dti)

other = Series([20, 30, 40], dtype="uint8")

method = getattr(ser, op)
msg = "|".join(
[
"Addition/subtraction of integers and integer-arrays",
"cannot subtract .* from ndarray",
]
)
with pytest.raises(TypeError, match=msg):
method(1)
with pytest.raises(TypeError, match=msg):
method(other)
with pytest.raises(TypeError, match=msg):
method(np.array(other))
with pytest.raises(TypeError, match=msg):
method(pd.Index(other))
assert_invalid_addsub_type(ser, 1, msg)
assert_invalid_addsub_type(ser, other, msg)
assert_invalid_addsub_type(ser, np.array(other), msg)
assert_invalid_addsub_type(ser, pd.Index(other), msg)

# -------------------------------------------------------------
# Timezone-Centric Tests
Expand Down Expand Up @@ -2062,7 +2017,9 @@ def test_dti_add_intarray_tick(self, int_holder, freq):
dti = date_range("2016-01-01", periods=2, freq=freq)
other = int_holder([4, -1])

msg = "Addition/subtraction of integers|cannot subtract DatetimeArray from"
msg = "|".join(
["Addition/subtraction of integers", "cannot subtract DatetimeArray from"]
)
assert_invalid_addsub_type(dti, other, msg)

@pytest.mark.parametrize("freq", ["W", "M", "MS", "Q"])
Expand All @@ -2072,7 +2029,9 @@ def test_dti_add_intarray_non_tick(self, int_holder, freq):
dti = date_range("2016-01-01", periods=2, freq=freq)
other = int_holder([4, -1])

msg = "Addition/subtraction of integers|cannot subtract DatetimeArray from"
msg = "|".join(
["Addition/subtraction of integers", "cannot subtract DatetimeArray from"]
)
assert_invalid_addsub_type(dti, other, msg)

@pytest.mark.parametrize("int_holder", [np.array, pd.Index])
Expand Down Expand Up @@ -2222,10 +2181,7 @@ def test_add_datetimelike_and_dtarr(self, box_with_array, addend, tz):
dtarr = tm.box_expected(dti, box_with_array)
msg = "cannot add DatetimeArray and"

with pytest.raises(TypeError, match=msg):
dtarr + addend
with pytest.raises(TypeError, match=msg):
addend + dtarr
assert_cannot_add(dtarr, addend, msg)

# -------------------------------------------------------------

Expand Down
21 changes: 5 additions & 16 deletions pandas/tests/arithmetic/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
UInt64Index,
)
from pandas.core.computation import expressions as expr
from pandas.tests.arithmetic.common import assert_invalid_comparison


@pytest.fixture(params=[Index, Series, tm.to_array])
Expand Down Expand Up @@ -84,25 +85,13 @@ def test_operator_series_comparison_zerorank(self):
expected = 0.0 > Series([1, 2, 3])
tm.assert_series_equal(result, expected)

def test_df_numeric_cmp_dt64_raises(self):
def test_df_numeric_cmp_dt64_raises(self, box_with_array):
# GH#8932, GH#22163
ts = pd.Timestamp.now()
df = pd.DataFrame({"x": range(5)})
obj = np.array(range(5))
obj = tm.box_expected(obj, box_with_array)

msg = (
"'[<>]' not supported between instances of 'numpy.ndarray' and 'Timestamp'"
)
with pytest.raises(TypeError, match=msg):
df > ts
with pytest.raises(TypeError, match=msg):
df < ts
with pytest.raises(TypeError, match=msg):
ts < df
with pytest.raises(TypeError, match=msg):
ts > df

assert not (df == ts).any().any()
assert (df != ts).all().all()
assert_invalid_comparison(obj, ts, box_with_array)

def test_compare_invalid(self):
# GH#8058
Expand Down
28 changes: 15 additions & 13 deletions pandas/tests/arithmetic/test_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,15 @@


class TestObjectComparisons:
def test_comparison_object_numeric_nas(self):
def test_comparison_object_numeric_nas(self, comparison_op):
ser = Series(np.random.randn(10), dtype=object)
shifted = ser.shift(2)

ops = ["lt", "le", "gt", "ge", "eq", "ne"]
for op in ops:
func = getattr(operator, op)
func = comparison_op

result = func(ser, shifted)
expected = func(ser.astype(float), shifted.astype(float))
tm.assert_series_equal(result, expected)
result = func(ser, shifted)
expected = func(ser.astype(float), shifted.astype(float))
tm.assert_series_equal(result, expected)

def test_object_comparisons(self):
ser = Series(["a", "b", np.nan, "c", "a"])
Expand Down Expand Up @@ -141,11 +139,13 @@ def test_objarr_radd_str_invalid(self, dtype, data, box_with_array):
ser = Series(data, dtype=dtype)

ser = tm.box_expected(ser, box_with_array)
msg = (
"can only concatenate str|"
"did not contain a loop with signature matching types|"
"unsupported operand type|"
"must be str"
msg = "|".join(
[
"can only concatenate str",
"did not contain a loop with signature matching types",
"unsupported operand type",
"must be str",
]
)
with pytest.raises(TypeError, match=msg):
"foo_" + ser
Expand All @@ -159,7 +159,9 @@ def test_objarr_add_invalid(self, op, box_with_array):
obj_ser.name = "objects"

obj_ser = tm.box_expected(obj_ser, box)
msg = "can only concatenate str|unsupported operand type|must be str"
msg = "|".join(
["can only concatenate str", "unsupported operand type", "must be str"]
)
with pytest.raises(Exception, match=msg):
op(obj_ser, 1)
with pytest.raises(Exception, match=msg):
Expand Down
Loading