Skip to content

Commit a656d24

Browse files
jbrockmendeljreback
authored andcommitted
TST: parametrize arithmetic tests (#27847)
1 parent 6813d77 commit a656d24

File tree

4 files changed

+154
-230
lines changed

4 files changed

+154
-230
lines changed

pandas/tests/arithmetic/test_datetime64.py

+100-165
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,54 @@
3030
import pandas.util.testing as tm
3131

3232

33+
def assert_invalid_comparison(left, right, box):
34+
"""
35+
Assert that comparison operations with mismatched types behave correctly.
36+
37+
Parameters
38+
----------
39+
left : np.ndarray, ExtensionArray, Index, or Series
40+
right : object
41+
box : {pd.DataFrame, pd.Series, pd.Index, tm.to_array}
42+
"""
43+
# Not for tznaive-tzaware comparison
44+
45+
# Note: not quite the same as how we do this for tm.box_expected
46+
xbox = box if box is not pd.Index else np.array
47+
48+
result = left == right
49+
expected = xbox(np.zeros(result.shape, dtype=np.bool_))
50+
51+
tm.assert_equal(result, expected)
52+
53+
result = right == left
54+
tm.assert_equal(result, expected)
55+
56+
result = left != right
57+
tm.assert_equal(result, ~expected)
58+
59+
result = right != left
60+
tm.assert_equal(result, ~expected)
61+
62+
msg = "Invalid comparison between"
63+
with pytest.raises(TypeError, match=msg):
64+
left < right
65+
with pytest.raises(TypeError, match=msg):
66+
left <= right
67+
with pytest.raises(TypeError, match=msg):
68+
left > right
69+
with pytest.raises(TypeError, match=msg):
70+
left >= right
71+
with pytest.raises(TypeError, match=msg):
72+
right < left
73+
with pytest.raises(TypeError, match=msg):
74+
right <= left
75+
with pytest.raises(TypeError, match=msg):
76+
right > left
77+
with pytest.raises(TypeError, match=msg):
78+
right >= left
79+
80+
3381
def assert_all(obj):
3482
"""
3583
Test helper to call call obj.all() the appropriate number of times on
@@ -47,7 +95,7 @@ def assert_all(obj):
4795

4896
class TestDatetime64ArrayLikeComparisons:
4997
# Comparison tests for datetime64 vectors fully parametrized over
50-
# DataFrame/Series/DatetimeIndex/DateteimeArray. Ideally all comparison
98+
# DataFrame/Series/DatetimeIndex/DatetimeArray. Ideally all comparison
5199
# tests will eventually end up here.
52100

53101
def test_compare_zerodim(self, tz_naive_fixture, box_with_array):
@@ -59,36 +107,61 @@ def test_compare_zerodim(self, tz_naive_fixture, box_with_array):
59107

60108
other = np.array(dti.to_numpy()[0])
61109

62-
# FIXME: ValueError with transpose on tzaware
63-
dtarr = tm.box_expected(dti, box, transpose=False)
110+
dtarr = tm.box_expected(dti, box)
64111
result = dtarr <= other
65112
expected = np.array([True, False, False])
66-
expected = tm.box_expected(expected, xbox, transpose=False)
113+
expected = tm.box_expected(expected, xbox)
67114
tm.assert_equal(result, expected)
68115

116+
def test_dt64arr_cmp_date_invalid(self, tz_naive_fixture, box_with_array):
117+
# GH#19800, GH#19301 datetime.date comparison raises to
118+
# match DatetimeIndex/Timestamp. This also matches the behavior
119+
# of stdlib datetime.datetime
120+
tz = tz_naive_fixture
69121

70-
class TestDatetime64DataFrameComparison:
71-
@pytest.mark.parametrize(
72-
"timestamps",
73-
[
74-
[pd.Timestamp("2012-01-01 13:00:00+00:00")] * 2,
75-
[pd.Timestamp("2012-01-01 13:00:00")] * 2,
76-
],
77-
)
78-
def test_tz_aware_scalar_comparison(self, timestamps):
79-
# GH#15966
80-
df = pd.DataFrame({"test": timestamps})
81-
expected = pd.DataFrame({"test": [False, False]})
82-
tm.assert_frame_equal(df == -1, expected)
122+
dti = pd.date_range("20010101", periods=10, tz=tz)
123+
date = dti[0].to_pydatetime().date()
124+
125+
dtarr = tm.box_expected(dti, box_with_array)
126+
assert_invalid_comparison(dtarr, date, box_with_array)
83127

84-
def test_dt64_nat_comparison(self):
128+
@pytest.mark.parametrize("other", ["foo", -1, 99, 4.0, object(), timedelta(days=2)])
129+
def test_dt64arr_cmp_scalar_invalid(self, other, tz_naive_fixture, box_with_array):
130+
# GH#22074, GH#15966
131+
tz = tz_naive_fixture
132+
133+
rng = date_range("1/1/2000", periods=10, tz=tz)
134+
dtarr = tm.box_expected(rng, box_with_array)
135+
assert_invalid_comparison(dtarr, other, box_with_array)
136+
137+
@pytest.mark.parametrize("other", [None, np.nan])
138+
def test_dt64arr_cmp_na_scalar_invalid(
139+
self, other, tz_naive_fixture, box_with_array
140+
):
141+
# GH#19301
142+
tz = tz_naive_fixture
143+
dti = pd.date_range("2016-01-01", periods=2, tz=tz)
144+
dtarr = tm.box_expected(dti, box_with_array)
145+
assert_invalid_comparison(dtarr, other, box_with_array)
146+
147+
def test_dt64arr_nat_comparison(self, tz_naive_fixture, box_with_array):
85148
# GH#22242, GH#22163 DataFrame considered NaT == ts incorrectly
86-
ts = pd.Timestamp.now()
87-
df = pd.DataFrame([ts, pd.NaT])
88-
expected = pd.DataFrame([True, False])
149+
tz = tz_naive_fixture
150+
box = box_with_array
151+
xbox = box if box is not pd.Index else np.ndarray
152+
153+
ts = pd.Timestamp.now(tz)
154+
ser = pd.Series([ts, pd.NaT])
155+
156+
# FIXME: Can't transpose because that loses the tz dtype on
157+
# the NaT column
158+
obj = tm.box_expected(ser, box, transpose=False)
89159

90-
result = df == ts
91-
tm.assert_frame_equal(result, expected)
160+
expected = pd.Series([True, False], dtype=np.bool_)
161+
expected = tm.box_expected(expected, xbox, transpose=False)
162+
163+
result = obj == ts
164+
tm.assert_equal(result, expected)
92165

93166

94167
class TestDatetime64SeriesComparison:
@@ -142,35 +215,17 @@ def test_nat_comparisons(self, dtype, box, reverse, pair):
142215
expected = Series([False, False, True])
143216
tm.assert_series_equal(left <= right, expected)
144217

145-
def test_comparison_invalid(self, box_with_array):
218+
def test_comparison_invalid(self, tz_naive_fixture, box_with_array):
146219
# GH#4968
147220
# invalid date/int comparisons
148-
xbox = box_with_array if box_with_array is not pd.Index else np.ndarray
149-
221+
tz = tz_naive_fixture
150222
ser = Series(range(5))
151-
ser2 = Series(pd.date_range("20010101", periods=5))
223+
ser2 = Series(pd.date_range("20010101", periods=5, tz=tz))
152224

153225
ser = tm.box_expected(ser, box_with_array)
154226
ser2 = tm.box_expected(ser2, box_with_array)
155227

156-
for (x, y) in [(ser, ser2), (ser2, ser)]:
157-
158-
result = x == y
159-
expected = tm.box_expected([False] * 5, xbox)
160-
tm.assert_equal(result, expected)
161-
162-
result = x != y
163-
expected = tm.box_expected([True] * 5, xbox)
164-
tm.assert_equal(result, expected)
165-
msg = "Invalid comparison between"
166-
with pytest.raises(TypeError, match=msg):
167-
x >= y
168-
with pytest.raises(TypeError, match=msg):
169-
x > y
170-
with pytest.raises(TypeError, match=msg):
171-
x < y
172-
with pytest.raises(TypeError, match=msg):
173-
x <= y
228+
assert_invalid_comparison(ser, ser2, box_with_array)
174229

175230
@pytest.mark.parametrize(
176231
"data",
@@ -227,26 +282,6 @@ def test_series_comparison_scalars(self):
227282
expected = Series([x > val for x in series])
228283
tm.assert_series_equal(result, expected)
229284

230-
def test_dt64ser_cmp_date_invalid(self, box_with_array):
231-
# GH#19800 datetime.date comparison raises to
232-
# match DatetimeIndex/Timestamp. This also matches the behavior
233-
# of stdlib datetime.datetime
234-
235-
ser = pd.date_range("20010101", periods=10)
236-
date = ser[0].to_pydatetime().date()
237-
238-
ser = tm.box_expected(ser, box_with_array)
239-
assert_all(~(ser == date))
240-
assert_all(ser != date)
241-
with pytest.raises(TypeError):
242-
ser > date
243-
with pytest.raises(TypeError):
244-
ser < date
245-
with pytest.raises(TypeError):
246-
ser >= date
247-
with pytest.raises(TypeError):
248-
ser <= date
249-
250285
@pytest.mark.parametrize(
251286
"left,right", [("lt", "gt"), ("le", "ge"), ("eq", "eq"), ("ne", "ne")]
252287
)
@@ -388,57 +423,6 @@ def test_dti_cmp_datetimelike(self, other, tz_naive_fixture):
388423
expected = np.array([True, False])
389424
tm.assert_numpy_array_equal(result, expected)
390425

391-
def dt64arr_cmp_non_datetime(self, tz_naive_fixture, box_with_array):
392-
# GH#19301 by convention datetime.date is not considered comparable
393-
# to Timestamp or DatetimeIndex. This may change in the future.
394-
tz = tz_naive_fixture
395-
dti = pd.date_range("2016-01-01", periods=2, tz=tz)
396-
dtarr = tm.box_expected(dti, box_with_array)
397-
398-
other = datetime(2016, 1, 1).date()
399-
assert not (dtarr == other).any()
400-
assert (dtarr != other).all()
401-
with pytest.raises(TypeError):
402-
dtarr < other
403-
with pytest.raises(TypeError):
404-
dtarr <= other
405-
with pytest.raises(TypeError):
406-
dtarr > other
407-
with pytest.raises(TypeError):
408-
dtarr >= other
409-
410-
@pytest.mark.parametrize("other", [None, np.nan, pd.NaT])
411-
def test_dti_eq_null_scalar(self, other, tz_naive_fixture):
412-
# GH#19301
413-
tz = tz_naive_fixture
414-
dti = pd.date_range("2016-01-01", periods=2, tz=tz)
415-
assert not (dti == other).any()
416-
417-
@pytest.mark.parametrize("other", [None, np.nan, pd.NaT])
418-
def test_dti_ne_null_scalar(self, other, tz_naive_fixture):
419-
# GH#19301
420-
tz = tz_naive_fixture
421-
dti = pd.date_range("2016-01-01", periods=2, tz=tz)
422-
assert (dti != other).all()
423-
424-
@pytest.mark.parametrize("other", [None, np.nan])
425-
def test_dti_cmp_null_scalar_inequality(
426-
self, tz_naive_fixture, other, box_with_array
427-
):
428-
# GH#19301
429-
tz = tz_naive_fixture
430-
dti = pd.date_range("2016-01-01", periods=2, tz=tz)
431-
dtarr = tm.box_expected(dti, box_with_array)
432-
msg = "Invalid comparison between"
433-
with pytest.raises(TypeError, match=msg):
434-
dtarr < other
435-
with pytest.raises(TypeError, match=msg):
436-
dtarr <= other
437-
with pytest.raises(TypeError, match=msg):
438-
dtarr > other
439-
with pytest.raises(TypeError, match=msg):
440-
dtarr >= other
441-
442426
@pytest.mark.parametrize("dtype", [None, object])
443427
def test_dti_cmp_nat(self, dtype, box_with_array):
444428
if box_with_array is tm.to_array and dtype is object:
@@ -728,34 +712,6 @@ def test_dti_cmp_str(self, tz_naive_fixture):
728712
expected = np.array([True] * 10)
729713
tm.assert_numpy_array_equal(result, expected)
730714

731-
@pytest.mark.parametrize("other", ["foo", 99, 4.0, object(), timedelta(days=2)])
732-
def test_dt64arr_cmp_scalar_invalid(self, other, tz_naive_fixture, box_with_array):
733-
# GH#22074
734-
tz = tz_naive_fixture
735-
xbox = box_with_array if box_with_array is not pd.Index else np.ndarray
736-
737-
rng = date_range("1/1/2000", periods=10, tz=tz)
738-
rng = tm.box_expected(rng, box_with_array)
739-
740-
result = rng == other
741-
expected = np.array([False] * 10)
742-
expected = tm.box_expected(expected, xbox)
743-
tm.assert_equal(result, expected)
744-
745-
result = rng != other
746-
expected = np.array([True] * 10)
747-
expected = tm.box_expected(expected, xbox)
748-
tm.assert_equal(result, expected)
749-
msg = "Invalid comparison between"
750-
with pytest.raises(TypeError, match=msg):
751-
rng < other
752-
with pytest.raises(TypeError, match=msg):
753-
rng <= other
754-
with pytest.raises(TypeError, match=msg):
755-
rng > other
756-
with pytest.raises(TypeError, match=msg):
757-
rng >= other
758-
759715
def test_dti_cmp_list(self):
760716
rng = date_range("1/1/2000", periods=10)
761717

@@ -2576,24 +2532,3 @@ def test_shift_months(years, months):
25762532
raw = [x + pd.offsets.DateOffset(years=years, months=months) for x in dti]
25772533
expected = DatetimeIndex(raw)
25782534
tm.assert_index_equal(actual, expected)
2579-
2580-
2581-
# FIXME: this belongs in scalar tests
2582-
class SubDatetime(datetime):
2583-
pass
2584-
2585-
2586-
@pytest.mark.parametrize(
2587-
"lh,rh",
2588-
[
2589-
(SubDatetime(2000, 1, 1), Timedelta(hours=1)),
2590-
(Timedelta(hours=1), SubDatetime(2000, 1, 1)),
2591-
],
2592-
)
2593-
def test_dt_subclass_add_timedelta(lh, rh):
2594-
# GH 25851
2595-
# ensure that subclassed datetime works for
2596-
# Timedelta operations
2597-
result = lh + rh
2598-
expected = SubDatetime(2000, 1, 1, 1)
2599-
assert result == expected

pandas/tests/arithmetic/test_period.py

+8-18
Original file line numberDiff line numberDiff line change
@@ -573,12 +573,19 @@ def test_parr_add_sub_float_raises(self, op, other, box_with_array):
573573
@pytest.mark.parametrize(
574574
"other",
575575
[
576+
# datetime scalars
576577
pd.Timestamp.now(),
577578
pd.Timestamp.now().to_pydatetime(),
578579
pd.Timestamp.now().to_datetime64(),
580+
# datetime-like arrays
581+
pd.date_range("2016-01-01", periods=3, freq="H"),
582+
pd.date_range("2016-01-01", periods=3, tz="Europe/Brussels"),
583+
pd.date_range("2016-01-01", periods=3, freq="S")._data,
584+
pd.date_range("2016-01-01", periods=3, tz="Asia/Tokyo")._data,
585+
# Miscellaneous invalid types
579586
],
580587
)
581-
def test_parr_add_sub_datetime_scalar(self, other, box_with_array):
588+
def test_parr_add_sub_invalid(self, other, box_with_array):
582589
# GH#23215
583590
rng = pd.period_range("1/1/2000", freq="D", periods=3)
584591
rng = tm.box_expected(rng, box_with_array)
@@ -595,23 +602,6 @@ def test_parr_add_sub_datetime_scalar(self, other, box_with_array):
595602
# -----------------------------------------------------------------
596603
# __add__/__sub__ with ndarray[datetime64] and ndarray[timedelta64]
597604

598-
def test_parr_add_sub_dt64_array_raises(self, box_with_array):
599-
rng = pd.period_range("1/1/2000", freq="D", periods=3)
600-
dti = pd.date_range("2016-01-01", periods=3)
601-
dtarr = dti.values
602-
603-
rng = tm.box_expected(rng, box_with_array)
604-
605-
with pytest.raises(TypeError):
606-
rng + dtarr
607-
with pytest.raises(TypeError):
608-
dtarr + rng
609-
610-
with pytest.raises(TypeError):
611-
rng - dtarr
612-
with pytest.raises(TypeError):
613-
dtarr - rng
614-
615605
def test_pi_add_sub_td64_array_non_tick_raises(self):
616606
rng = pd.period_range("1/1/2000", freq="Q", periods=3)
617607
tdi = pd.TimedeltaIndex(["-1 Day", "-1 Day", "-1 Day"])

0 commit comments

Comments
 (0)