Skip to content

Commit 7a31ca8

Browse files
jbrockmendelnickleus27
authored andcommitted
TST: parametrize arithmetic tests (pandas-dev#44395)
1 parent a23f5df commit 7a31ca8

File tree

6 files changed

+175
-227
lines changed

6 files changed

+175
-227
lines changed

pandas/tests/arithmetic/common.py

+30-3
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,26 @@
1111
array,
1212
)
1313
import pandas._testing as tm
14-
from pandas.core.arrays import PandasArray
14+
from pandas.core.arrays import (
15+
BooleanArray,
16+
PandasArray,
17+
)
18+
19+
20+
def assert_cannot_add(left, right, msg="cannot add"):
21+
"""
22+
Helper to assert that left and right cannot be added.
23+
24+
Parameters
25+
----------
26+
left : object
27+
right : object
28+
msg : str, default "cannot add"
29+
"""
30+
with pytest.raises(TypeError, match=msg):
31+
left + right
32+
with pytest.raises(TypeError, match=msg):
33+
right + left
1534

1635

1736
def assert_invalid_addsub_type(left, right, msg=None):
@@ -79,21 +98,29 @@ def xbox2(x):
7998
# just exclude PandasArray[bool]
8099
if isinstance(x, PandasArray):
81100
return x._ndarray
101+
if isinstance(x, BooleanArray):
102+
# NB: we are assuming no pd.NAs for now
103+
return x.astype(bool)
82104
return x
83105

106+
# rev_box: box to use for reversed comparisons
107+
rev_box = xbox
108+
if isinstance(right, Index) and isinstance(left, Series):
109+
rev_box = np.array
110+
84111
result = xbox2(left == right)
85112
expected = xbox(np.zeros(result.shape, dtype=np.bool_))
86113

87114
tm.assert_equal(result, expected)
88115

89116
result = xbox2(right == left)
90-
tm.assert_equal(result, expected)
117+
tm.assert_equal(result, rev_box(expected))
91118

92119
result = xbox2(left != right)
93120
tm.assert_equal(result, ~expected)
94121

95122
result = xbox2(right != left)
96-
tm.assert_equal(result, ~expected)
123+
tm.assert_equal(result, rev_box(~expected))
97124

98125
msg = "|".join(
99126
[

pandas/tests/arithmetic/test_datetime64.py

+30-74
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
)
4242
from pandas.core.ops import roperator
4343
from pandas.tests.arithmetic.common import (
44+
assert_cannot_add,
4445
assert_invalid_addsub_type,
4546
assert_invalid_comparison,
4647
get_upcast_box,
@@ -99,6 +100,7 @@ def test_dt64arr_cmp_scalar_invalid(self, other, tz_naive_fixture, box_with_arra
99100
@pytest.mark.parametrize(
100101
"other",
101102
[
103+
# GH#4968 invalid date/int comparisons
102104
list(range(10)),
103105
np.arange(10),
104106
np.arange(10).astype(np.float32),
@@ -111,13 +113,14 @@ def test_dt64arr_cmp_scalar_invalid(self, other, tz_naive_fixture, box_with_arra
111113
pd.period_range("1971-01-01", freq="D", periods=10).astype(object),
112114
],
113115
)
114-
def test_dt64arr_cmp_arraylike_invalid(self, other, tz_naive_fixture):
115-
# We don't parametrize this over box_with_array because listlike
116-
# other plays poorly with assert_invalid_comparison reversed checks
116+
def test_dt64arr_cmp_arraylike_invalid(
117+
self, other, tz_naive_fixture, box_with_array
118+
):
117119
tz = tz_naive_fixture
118120

119121
dta = date_range("1970-01-01", freq="ns", periods=10, tz=tz)._data
120-
assert_invalid_comparison(dta, other, tm.to_array)
122+
obj = tm.box_expected(dta, box_with_array)
123+
assert_invalid_comparison(obj, other, box_with_array)
121124

122125
def test_dt64arr_cmp_mixed_invalid(self, tz_naive_fixture):
123126
tz = tz_naive_fixture
@@ -215,18 +218,6 @@ def test_nat_comparisons(
215218

216219
tm.assert_series_equal(result, expected)
217220

218-
def test_comparison_invalid(self, tz_naive_fixture, box_with_array):
219-
# GH#4968
220-
# invalid date/int comparisons
221-
tz = tz_naive_fixture
222-
ser = Series(range(5))
223-
ser2 = Series(date_range("20010101", periods=5, tz=tz))
224-
225-
ser = tm.box_expected(ser, box_with_array)
226-
ser2 = tm.box_expected(ser2, box_with_array)
227-
228-
assert_invalid_comparison(ser, ser2, box_with_array)
229-
230221
@pytest.mark.parametrize(
231222
"data",
232223
[
@@ -315,8 +306,8 @@ def test_timestamp_compare_series(self, left, right):
315306
tm.assert_series_equal(result, expected)
316307

317308
# Compare to NaT with series containing NaT
318-
expected = left_f(s_nat, Timestamp("nat"))
319-
result = right_f(Timestamp("nat"), s_nat)
309+
expected = left_f(s_nat, NaT)
310+
result = right_f(NaT, s_nat)
320311
tm.assert_series_equal(result, expected)
321312

322313
def test_dt64arr_timestamp_equality(self, box_with_array):
@@ -832,17 +823,6 @@ def test_dt64arr_add_timedeltalike_scalar(
832823
result = rng + two_hours
833824
tm.assert_equal(result, expected)
834825

835-
def test_dt64arr_iadd_timedeltalike_scalar(
836-
self, tz_naive_fixture, two_hours, box_with_array
837-
):
838-
tz = tz_naive_fixture
839-
840-
rng = date_range("2000-01-01", "2000-02-01", tz=tz)
841-
expected = date_range("2000-01-01 02:00", "2000-02-01 02:00", tz=tz)
842-
843-
rng = tm.box_expected(rng, box_with_array)
844-
expected = tm.box_expected(expected, box_with_array)
845-
846826
rng += two_hours
847827
tm.assert_equal(rng, expected)
848828

@@ -860,17 +840,6 @@ def test_dt64arr_sub_timedeltalike_scalar(
860840
result = rng - two_hours
861841
tm.assert_equal(result, expected)
862842

863-
def test_dt64arr_isub_timedeltalike_scalar(
864-
self, tz_naive_fixture, two_hours, box_with_array
865-
):
866-
tz = tz_naive_fixture
867-
868-
rng = date_range("2000-01-01", "2000-02-01", tz=tz)
869-
expected = date_range("1999-12-31 22:00", "2000-01-31 22:00", tz=tz)
870-
871-
rng = tm.box_expected(rng, box_with_array)
872-
expected = tm.box_expected(expected, box_with_array)
873-
874843
rng -= two_hours
875844
tm.assert_equal(rng, expected)
876845

@@ -1071,21 +1040,14 @@ def test_dt64arr_add_dt64ndarray_raises(self, tz_naive_fixture, box_with_array):
10711040
dt64vals = dti.values
10721041

10731042
dtarr = tm.box_expected(dti, box_with_array)
1074-
msg = "cannot add"
1075-
with pytest.raises(TypeError, match=msg):
1076-
dtarr + dt64vals
1077-
with pytest.raises(TypeError, match=msg):
1078-
dt64vals + dtarr
1043+
assert_cannot_add(dtarr, dt64vals)
10791044

10801045
def test_dt64arr_add_timestamp_raises(self, box_with_array):
10811046
# GH#22163 ensure DataFrame doesn't cast Timestamp to i8
10821047
idx = DatetimeIndex(["2011-01-01", "2011-01-02"])
1048+
ts = idx[0]
10831049
idx = tm.box_expected(idx, box_with_array)
1084-
msg = "cannot add"
1085-
with pytest.raises(TypeError, match=msg):
1086-
idx + Timestamp("2011-01-01")
1087-
with pytest.raises(TypeError, match=msg):
1088-
Timestamp("2011-01-01") + idx
1050+
assert_cannot_add(idx, ts)
10891051

10901052
# -------------------------------------------------------------
10911053
# Other Invalid Addition/Subtraction
@@ -1267,13 +1229,12 @@ def test_dti_add_tick_tzaware(self, tz_aware_fixture, box_with_array):
12671229
dates = tm.box_expected(dates, box_with_array)
12681230
expected = tm.box_expected(expected, box_with_array)
12691231

1270-
# TODO: parametrize over the scalar being added? radd? sub?
1271-
offset = dates + pd.offsets.Hour(5)
1272-
tm.assert_equal(offset, expected)
1273-
offset = dates + np.timedelta64(5, "h")
1274-
tm.assert_equal(offset, expected)
1275-
offset = dates + timedelta(hours=5)
1276-
tm.assert_equal(offset, expected)
1232+
# TODO: sub?
1233+
for scalar in [pd.offsets.Hour(5), np.timedelta64(5, "h"), timedelta(hours=5)]:
1234+
offset = dates + scalar
1235+
tm.assert_equal(offset, expected)
1236+
offset = scalar + dates
1237+
tm.assert_equal(offset, expected)
12771238

12781239
# -------------------------------------------------------------
12791240
# RelativeDelta DateOffsets
@@ -1941,30 +1902,24 @@ def test_dt64_mul_div_numeric_invalid(self, one, dt64_series):
19411902
one / dt64_series
19421903

19431904
# TODO: parametrize over box
1944-
@pytest.mark.parametrize("op", ["__add__", "__radd__", "__sub__", "__rsub__"])
1945-
def test_dt64_series_add_intlike(self, tz_naive_fixture, op):
1905+
def test_dt64_series_add_intlike(self, tz_naive_fixture):
19461906
# GH#19123
19471907
tz = tz_naive_fixture
19481908
dti = DatetimeIndex(["2016-01-02", "2016-02-03", "NaT"], tz=tz)
19491909
ser = Series(dti)
19501910

19511911
other = Series([20, 30, 40], dtype="uint8")
19521912

1953-
method = getattr(ser, op)
19541913
msg = "|".join(
19551914
[
19561915
"Addition/subtraction of integers and integer-arrays",
19571916
"cannot subtract .* from ndarray",
19581917
]
19591918
)
1960-
with pytest.raises(TypeError, match=msg):
1961-
method(1)
1962-
with pytest.raises(TypeError, match=msg):
1963-
method(other)
1964-
with pytest.raises(TypeError, match=msg):
1965-
method(np.array(other))
1966-
with pytest.raises(TypeError, match=msg):
1967-
method(pd.Index(other))
1919+
assert_invalid_addsub_type(ser, 1, msg)
1920+
assert_invalid_addsub_type(ser, other, msg)
1921+
assert_invalid_addsub_type(ser, np.array(other), msg)
1922+
assert_invalid_addsub_type(ser, pd.Index(other), msg)
19681923

19691924
# -------------------------------------------------------------
19701925
# Timezone-Centric Tests
@@ -2062,7 +2017,9 @@ def test_dti_add_intarray_tick(self, int_holder, freq):
20622017
dti = date_range("2016-01-01", periods=2, freq=freq)
20632018
other = int_holder([4, -1])
20642019

2065-
msg = "Addition/subtraction of integers|cannot subtract DatetimeArray from"
2020+
msg = "|".join(
2021+
["Addition/subtraction of integers", "cannot subtract DatetimeArray from"]
2022+
)
20662023
assert_invalid_addsub_type(dti, other, msg)
20672024

20682025
@pytest.mark.parametrize("freq", ["W", "M", "MS", "Q"])
@@ -2072,7 +2029,9 @@ def test_dti_add_intarray_non_tick(self, int_holder, freq):
20722029
dti = date_range("2016-01-01", periods=2, freq=freq)
20732030
other = int_holder([4, -1])
20742031

2075-
msg = "Addition/subtraction of integers|cannot subtract DatetimeArray from"
2032+
msg = "|".join(
2033+
["Addition/subtraction of integers", "cannot subtract DatetimeArray from"]
2034+
)
20762035
assert_invalid_addsub_type(dti, other, msg)
20772036

20782037
@pytest.mark.parametrize("int_holder", [np.array, pd.Index])
@@ -2222,10 +2181,7 @@ def test_add_datetimelike_and_dtarr(self, box_with_array, addend, tz):
22222181
dtarr = tm.box_expected(dti, box_with_array)
22232182
msg = "cannot add DatetimeArray and"
22242183

2225-
with pytest.raises(TypeError, match=msg):
2226-
dtarr + addend
2227-
with pytest.raises(TypeError, match=msg):
2228-
addend + dtarr
2184+
assert_cannot_add(dtarr, addend, msg)
22292185

22302186
# -------------------------------------------------------------
22312187

pandas/tests/arithmetic/test_numeric.py

+5-16
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
UInt64Index,
3030
)
3131
from pandas.core.computation import expressions as expr
32+
from pandas.tests.arithmetic.common import assert_invalid_comparison
3233

3334

3435
@pytest.fixture(params=[Index, Series, tm.to_array])
@@ -84,25 +85,13 @@ def test_operator_series_comparison_zerorank(self):
8485
expected = 0.0 > Series([1, 2, 3])
8586
tm.assert_series_equal(result, expected)
8687

87-
def test_df_numeric_cmp_dt64_raises(self):
88+
def test_df_numeric_cmp_dt64_raises(self, box_with_array):
8889
# GH#8932, GH#22163
8990
ts = pd.Timestamp.now()
90-
df = pd.DataFrame({"x": range(5)})
91+
obj = np.array(range(5))
92+
obj = tm.box_expected(obj, box_with_array)
9193

92-
msg = (
93-
"'[<>]' not supported between instances of 'numpy.ndarray' and 'Timestamp'"
94-
)
95-
with pytest.raises(TypeError, match=msg):
96-
df > ts
97-
with pytest.raises(TypeError, match=msg):
98-
df < ts
99-
with pytest.raises(TypeError, match=msg):
100-
ts < df
101-
with pytest.raises(TypeError, match=msg):
102-
ts > df
103-
104-
assert not (df == ts).any().any()
105-
assert (df != ts).all().all()
94+
assert_invalid_comparison(obj, ts, box_with_array)
10695

10796
def test_compare_invalid(self):
10897
# GH#8058

pandas/tests/arithmetic/test_object.py

+15-13
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,15 @@
2121

2222

2323
class TestObjectComparisons:
24-
def test_comparison_object_numeric_nas(self):
24+
def test_comparison_object_numeric_nas(self, comparison_op):
2525
ser = Series(np.random.randn(10), dtype=object)
2626
shifted = ser.shift(2)
2727

28-
ops = ["lt", "le", "gt", "ge", "eq", "ne"]
29-
for op in ops:
30-
func = getattr(operator, op)
28+
func = comparison_op
3129

32-
result = func(ser, shifted)
33-
expected = func(ser.astype(float), shifted.astype(float))
34-
tm.assert_series_equal(result, expected)
30+
result = func(ser, shifted)
31+
expected = func(ser.astype(float), shifted.astype(float))
32+
tm.assert_series_equal(result, expected)
3533

3634
def test_object_comparisons(self):
3735
ser = Series(["a", "b", np.nan, "c", "a"])
@@ -141,11 +139,13 @@ def test_objarr_radd_str_invalid(self, dtype, data, box_with_array):
141139
ser = Series(data, dtype=dtype)
142140

143141
ser = tm.box_expected(ser, box_with_array)
144-
msg = (
145-
"can only concatenate str|"
146-
"did not contain a loop with signature matching types|"
147-
"unsupported operand type|"
148-
"must be str"
142+
msg = "|".join(
143+
[
144+
"can only concatenate str",
145+
"did not contain a loop with signature matching types",
146+
"unsupported operand type",
147+
"must be str",
148+
]
149149
)
150150
with pytest.raises(TypeError, match=msg):
151151
"foo_" + ser
@@ -159,7 +159,9 @@ def test_objarr_add_invalid(self, op, box_with_array):
159159
obj_ser.name = "objects"
160160

161161
obj_ser = tm.box_expected(obj_ser, box)
162-
msg = "can only concatenate str|unsupported operand type|must be str"
162+
msg = "|".join(
163+
["can only concatenate str", "unsupported operand type", "must be str"]
164+
)
163165
with pytest.raises(Exception, match=msg):
164166
op(obj_ser, 1)
165167
with pytest.raises(Exception, match=msg):

0 commit comments

Comments
 (0)