Skip to content

Commit 5a7b5c9

Browse files
jbrockmendeljreback
authored andcommitted
TST: collect arithmetic test helpers (#30354)
1 parent 00a6c9e commit 5a7b5c9

File tree

3 files changed

+97
-81
lines changed

3 files changed

+97
-81
lines changed

pandas/tests/arithmetic/common.py

+89
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
"""
2+
Assertion helpers for arithmetic tests.
3+
"""
4+
import numpy as np
5+
import pytest
6+
7+
from pandas import DataFrame, Index, Series
8+
import pandas.util.testing as tm
9+
10+
11+
def assert_invalid_addsub_type(left, right, msg=None):
12+
"""
13+
Helper to assert that left and right can be neither added nor subtracted.
14+
15+
Parameters
16+
---------
17+
left : object
18+
right : object
19+
msg : str or None, default None
20+
"""
21+
with pytest.raises(TypeError, match=msg):
22+
left + right
23+
with pytest.raises(TypeError, match=msg):
24+
right + left
25+
with pytest.raises(TypeError, match=msg):
26+
left - right
27+
with pytest.raises(TypeError, match=msg):
28+
right - left
29+
30+
31+
def get_upcast_box(box, vector):
32+
"""
33+
Given two box-types, find the one that takes priority
34+
"""
35+
if box is DataFrame or isinstance(vector, DataFrame):
36+
return DataFrame
37+
if box is Series or isinstance(vector, Series):
38+
return Series
39+
if box is Index or isinstance(vector, Index):
40+
return Index
41+
return box
42+
43+
44+
def assert_invalid_comparison(left, right, box):
45+
"""
46+
Assert that comparison operations with mismatched types behave correctly.
47+
48+
Parameters
49+
----------
50+
left : np.ndarray, ExtensionArray, Index, or Series
51+
right : object
52+
box : {pd.DataFrame, pd.Series, pd.Index, tm.to_array}
53+
"""
54+
# Not for tznaive-tzaware comparison
55+
56+
# Note: not quite the same as how we do this for tm.box_expected
57+
xbox = box if box is not Index else np.array
58+
59+
result = left == right
60+
expected = xbox(np.zeros(result.shape, dtype=np.bool_))
61+
62+
tm.assert_equal(result, expected)
63+
64+
result = right == left
65+
tm.assert_equal(result, expected)
66+
67+
result = left != right
68+
tm.assert_equal(result, ~expected)
69+
70+
result = right != left
71+
tm.assert_equal(result, ~expected)
72+
73+
msg = "Invalid comparison between"
74+
with pytest.raises(TypeError, match=msg):
75+
left < right
76+
with pytest.raises(TypeError, match=msg):
77+
left <= right
78+
with pytest.raises(TypeError, match=msg):
79+
left > right
80+
with pytest.raises(TypeError, match=msg):
81+
left >= right
82+
with pytest.raises(TypeError, match=msg):
83+
right < left
84+
with pytest.raises(TypeError, match=msg):
85+
right <= left
86+
with pytest.raises(TypeError, match=msg):
87+
right > left
88+
with pytest.raises(TypeError, match=msg):
89+
right >= left

pandas/tests/arithmetic/test_datetime64.py

+7-66
Original file line numberDiff line numberDiff line change
@@ -29,57 +29,13 @@
2929
import pandas.core.arrays.datetimelike as dtl
3030
from pandas.core.indexes.datetimes import _to_M8
3131
from pandas.core.ops import roperator
32+
from pandas.tests.arithmetic.common import (
33+
assert_invalid_addsub_type,
34+
assert_invalid_comparison,
35+
get_upcast_box,
36+
)
3237
import pandas.util.testing as tm
3338

34-
35-
def assert_invalid_comparison(left, right, box):
36-
"""
37-
Assert that comparison operations with mismatched types behave correctly.
38-
39-
Parameters
40-
----------
41-
left : np.ndarray, ExtensionArray, Index, or Series
42-
right : object
43-
box : {pd.DataFrame, pd.Series, pd.Index, tm.to_array}
44-
"""
45-
# Not for tznaive-tzaware comparison
46-
47-
# Note: not quite the same as how we do this for tm.box_expected
48-
xbox = box if box is not pd.Index else np.array
49-
50-
result = left == right
51-
expected = xbox(np.zeros(result.shape, dtype=np.bool_))
52-
53-
tm.assert_equal(result, expected)
54-
55-
result = right == left
56-
tm.assert_equal(result, expected)
57-
58-
result = left != right
59-
tm.assert_equal(result, ~expected)
60-
61-
result = right != left
62-
tm.assert_equal(result, ~expected)
63-
64-
msg = "Invalid comparison between"
65-
with pytest.raises(TypeError, match=msg):
66-
left < right
67-
with pytest.raises(TypeError, match=msg):
68-
left <= right
69-
with pytest.raises(TypeError, match=msg):
70-
left > right
71-
with pytest.raises(TypeError, match=msg):
72-
left >= right
73-
with pytest.raises(TypeError, match=msg):
74-
right < left
75-
with pytest.raises(TypeError, match=msg):
76-
right <= left
77-
with pytest.raises(TypeError, match=msg):
78-
right > left
79-
with pytest.raises(TypeError, match=msg):
80-
right >= left
81-
82-
8339
# ------------------------------------------------------------------
8440
# Comparisons
8541

@@ -1033,14 +989,7 @@ def test_dt64arr_add_sub_invalid(self, dti_freq, other, box_with_array):
1033989
"ufunc '?(add|subtract)'? cannot use operands with types",
1034990
]
1035991
)
1036-
with pytest.raises(TypeError, match=msg):
1037-
dtarr + other
1038-
with pytest.raises(TypeError, match=msg):
1039-
other + dtarr
1040-
with pytest.raises(TypeError, match=msg):
1041-
dtarr - other
1042-
with pytest.raises(TypeError, match=msg):
1043-
other - dtarr
992+
assert_invalid_addsub_type(dtarr, other, msg)
1044993

1045994
@pytest.mark.parametrize("pi_freq", ["D", "W", "Q", "H"])
1046995
@pytest.mark.parametrize("dti_freq", [None, "D"])
@@ -1061,14 +1010,7 @@ def test_dt64arr_add_sub_parr(
10611010
"ufunc.*cannot use operands",
10621011
]
10631012
)
1064-
with pytest.raises(TypeError, match=msg):
1065-
dtarr + parr
1066-
with pytest.raises(TypeError, match=msg):
1067-
parr + dtarr
1068-
with pytest.raises(TypeError, match=msg):
1069-
dtarr - parr
1070-
with pytest.raises(TypeError, match=msg):
1071-
parr - dtarr
1013+
assert_invalid_addsub_type(dtarr, parr, msg)
10721014

10731015

10741016
class TestDatetime64DateOffsetArithmetic:
@@ -2368,7 +2310,6 @@ def test_dti_addsub_offset_arraylike(
23682310
# GH#18849, GH#19744
23692311
box = pd.Index
23702312
other_box = index_or_series
2371-
from .test_timedelta64 import get_upcast_box
23722313

23732314
tz = tz_naive_fixture
23742315
dti = pd.date_range("2017-01-01", periods=2, tz=tz, name=names[0])

pandas/tests/arithmetic/test_timedelta64.py

+1-15
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,9 @@
1818
Timestamp,
1919
timedelta_range,
2020
)
21-
from pandas.tests.arithmetic.test_datetime64 import assert_invalid_comparison
21+
from pandas.tests.arithmetic.common import assert_invalid_comparison, get_upcast_box
2222
import pandas.util.testing as tm
2323

24-
25-
def get_upcast_box(box, vector):
26-
"""
27-
Given two box-types, find the one that takes priority
28-
"""
29-
if box is DataFrame or isinstance(vector, DataFrame):
30-
return DataFrame
31-
if box is Series or isinstance(vector, Series):
32-
return Series
33-
if box is pd.Index or isinstance(vector, pd.Index):
34-
return pd.Index
35-
return box
36-
37-
3824
# ------------------------------------------------------------------
3925
# Timedelta64[ns] dtype Comparisons
4026

0 commit comments

Comments
 (0)