Skip to content

Commit fded332

Browse files
jbrockmendel authored and nickleus27 committed
TST: make get_upcast_box more flexible (pandas-dev#44385)
1 parent 5694f78 commit fded332

File tree

5 files changed

+44
-45
lines changed

5 files changed

+44
-45
lines changed

pandas/_testing/__init__.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -259,7 +259,7 @@ def box_expected(expected, box_cls, transpose=True):
259259
expected = DatetimeArray(expected)
260260
elif box_cls is TimedeltaArray:
261261
expected = TimedeltaArray(expected)
262-
elif box_cls is np.ndarray:
262+
elif box_cls is np.ndarray or box_cls is np.array:
263263
expected = np.array(expected)
264264
elif box_cls is to_array:
265265
expected = to_array(expected)

pandas/tests/arithmetic/common.py

+17-14
Original file line number | Diff line number | Diff line change
@@ -34,26 +34,29 @@ def assert_invalid_addsub_type(left, right, msg=None):
3434
right - left
3535

3636

37-
def get_expected_box(box):
37+
def get_upcast_box(left, right, is_cmp: bool = False):
3838
"""
39-
Get the box to use for 'expected' in a comparison operation.
40-
"""
41-
if box in [Index, array]:
42-
return np.ndarray
43-
return box
44-
39+
Get the box to use for 'expected' in an arithmetic or comparison operation.
4540
46-
def get_upcast_box(box, vector):
47-
"""
48-
Given two box-types, find the one that takes priority.
41+
Parameters
42+
left : Any
43+
right : Any
44+
is_cmp : bool, default False
45+
Whether the operation is a comparison method.
4946
"""
50-
if box is DataFrame or isinstance(vector, DataFrame):
47+
48+
if isinstance(left, DataFrame) or isinstance(right, DataFrame):
5149
return DataFrame
52-
if box is Series or isinstance(vector, Series):
50+
if isinstance(left, Series) or isinstance(right, Series):
51+
if is_cmp and isinstance(left, Index):
52+
# Index does not defer for comparisons
53+
return np.array
5354
return Series
54-
if box is Index or isinstance(vector, Index):
55+
if isinstance(left, Index) or isinstance(right, Index):
56+
if is_cmp:
57+
return np.array
5558
return Index
56-
return box
59+
return tm.to_array
5760

5861

5962
def assert_invalid_comparison(left, right, box):

pandas/tests/arithmetic/test_datetime64.py

+8-13
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,6 @@
4343
from pandas.tests.arithmetic.common import (
4444
assert_invalid_addsub_type,
4545
assert_invalid_comparison,
46-
get_expected_box,
4746
get_upcast_box,
4847
)
4948

@@ -60,12 +59,12 @@ def test_compare_zerodim(self, tz_naive_fixture, box_with_array):
6059
# Test comparison with zero-dimensional array is unboxed
6160
tz = tz_naive_fixture
6261
box = box_with_array
63-
xbox = get_expected_box(box)
6462
dti = date_range("20130101", periods=3, tz=tz)
6563

6664
other = np.array(dti.to_numpy()[0])
6765

6866
dtarr = tm.box_expected(dti, box)
67+
xbox = get_upcast_box(dtarr, other, True)
6968
result = dtarr <= other
7069
expected = np.array([True, False, False])
7170
expected = tm.box_expected(expected, xbox)
@@ -147,12 +146,12 @@ def test_dt64arr_nat_comparison(self, tz_naive_fixture, box_with_array):
147146
# GH#22242, GH#22163 DataFrame considered NaT == ts incorrectly
148147
tz = tz_naive_fixture
149148
box = box_with_array
150-
xbox = get_expected_box(box)
151149

152150
ts = Timestamp.now(tz)
153151
ser = Series([ts, NaT])
154152

155153
obj = tm.box_expected(ser, box)
154+
xbox = get_upcast_box(obj, ts, True)
156155

157156
expected = Series([True, False], dtype=np.bool_)
158157
expected = tm.box_expected(expected, xbox)
@@ -244,10 +243,9 @@ def test_nat_comparisons_scalar(self, dtype, data, box_with_array):
244243
# on older numpys (since they check object identity)
245244
return
246245

247-
xbox = get_expected_box(box)
248-
249246
left = Series(data, dtype=dtype)
250247
left = tm.box_expected(left, box)
248+
xbox = get_upcast_box(left, NaT, True)
251249

252250
expected = [False, False, False]
253251
expected = tm.box_expected(expected, xbox)
@@ -323,10 +321,10 @@ def test_timestamp_compare_series(self, left, right):
323321

324322
def test_dt64arr_timestamp_equality(self, box_with_array):
325323
# GH#11034
326-
xbox = get_expected_box(box_with_array)
327324

328325
ser = Series([Timestamp("2000-01-29 01:59:00"), Timestamp("2000-01-30"), NaT])
329326
ser = tm.box_expected(ser, box_with_array)
327+
xbox = get_upcast_box(ser, ser, True)
330328

331329
result = ser != ser
332330
expected = tm.box_expected([False, False, True], xbox)
@@ -417,13 +415,12 @@ def test_dti_cmp_nat(self, dtype, box_with_array):
417415
# on older numpys (since they check object identity)
418416
return
419417

420-
xbox = get_expected_box(box_with_array)
421-
422418
left = DatetimeIndex([Timestamp("2011-01-01"), NaT, Timestamp("2011-01-03")])
423419
right = DatetimeIndex([NaT, NaT, Timestamp("2011-01-03")])
424420

425421
left = tm.box_expected(left, box_with_array)
426422
right = tm.box_expected(right, box_with_array)
423+
xbox = get_upcast_box(left, right, True)
427424

428425
lhs, rhs = left, right
429426
if dtype is object:
@@ -642,12 +639,11 @@ def test_scalar_comparison_tzawareness(
642639
self, comparison_op, other, tz_aware_fixture, box_with_array
643640
):
644641
op = comparison_op
645-
box = box_with_array
646642
tz = tz_aware_fixture
647643
dti = date_range("2016-01-01", periods=2, tz=tz)
648-
xbox = get_expected_box(box)
649644

650645
dtarr = tm.box_expected(dti, box_with_array)
646+
xbox = get_upcast_box(dtarr, other, True)
651647
if op in [operator.eq, operator.ne]:
652648
exbool = op is operator.ne
653649
expected = np.array([exbool, exbool], dtype=bool)
@@ -2421,14 +2417,13 @@ def test_dti_addsub_offset_arraylike(
24212417
self, tz_naive_fixture, names, op, index_or_series
24222418
):
24232419
# GH#18849, GH#19744
2424-
box = pd.Index
24252420
other_box = index_or_series
24262421

24272422
tz = tz_naive_fixture
24282423
dti = date_range("2017-01-01", periods=2, tz=tz, name=names[0])
24292424
other = other_box([pd.offsets.MonthEnd(), pd.offsets.Day(n=2)], name=names[1])
24302425

2431-
xbox = get_upcast_box(box, other)
2426+
xbox = get_upcast_box(dti, other)
24322427

24332428
with tm.assert_produces_warning(PerformanceWarning):
24342429
res = op(dti, other)
@@ -2448,7 +2443,7 @@ def test_dti_addsub_object_arraylike(
24482443
dti = date_range("2017-01-01", periods=2, tz=tz)
24492444
dtarr = tm.box_expected(dti, box_with_array)
24502445
other = other_box([pd.offsets.MonthEnd(), Timedelta(days=4)])
2451-
xbox = get_upcast_box(box_with_array, other)
2446+
xbox = get_upcast_box(dtarr, other)
24522447

24532448
expected = DatetimeIndex(["2017-01-31", "2017-01-06"], tz=tz_naive_fixture)
24542449
expected = tm.box_expected(expected, xbox)

pandas/tests/arithmetic/test_period.py

+10-9
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
2727
from pandas.core.arrays import TimedeltaArray
2828
from pandas.tests.arithmetic.common import (
2929
assert_invalid_comparison,
30-
get_expected_box,
30+
get_upcast_box,
3131
)
3232

3333
# ------------------------------------------------------------------
@@ -41,12 +41,13 @@ class TestPeriodArrayLikeComparisons:
4141

4242
def test_compare_zerodim(self, box_with_array):
4343
# GH#26689 make sure we unbox zero-dimensional arrays
44-
xbox = get_expected_box(box_with_array)
4544

4645
pi = period_range("2000", periods=4)
4746
other = np.array(pi.to_numpy()[0])
4847

4948
pi = tm.box_expected(pi, box_with_array)
49+
xbox = get_upcast_box(pi, other, True)
50+
5051
result = pi <= other
5152
expected = np.array([True, False, False, False])
5253
expected = tm.box_expected(expected, xbox)
@@ -78,11 +79,11 @@ def test_compare_invalid_listlike(self, box_with_array, other):
7879

7980
@pytest.mark.parametrize("other_box", [list, np.array, lambda x: x.astype(object)])
8081
def test_compare_object_dtype(self, box_with_array, other_box):
81-
xbox = get_expected_box(box_with_array)
8282
pi = period_range("2000", periods=5)
8383
parr = tm.box_expected(pi, box_with_array)
8484

8585
other = other_box(pi)
86+
xbox = get_upcast_box(parr, other, True)
8687

8788
expected = np.array([True, True, True, True, True])
8889
expected = tm.box_expected(expected, xbox)
@@ -195,14 +196,15 @@ def test_pi_cmp_period(self):
195196

196197
# TODO: moved from test_datetime64; de-duplicate with version below
197198
def test_parr_cmp_period_scalar2(self, box_with_array):
198-
xbox = get_expected_box(box_with_array)
199-
200199
pi = period_range("2000-01-01", periods=10, freq="D")
201200

202201
val = Period("2000-01-04", freq="D")
202+
203203
expected = [x > val for x in pi]
204204

205205
ser = tm.box_expected(pi, box_with_array)
206+
xbox = get_upcast_box(ser, val, True)
207+
206208
expected = tm.box_expected(expected, xbox)
207209
result = ser > val
208210
tm.assert_equal(result, expected)
@@ -216,11 +218,10 @@ def test_parr_cmp_period_scalar2(self, box_with_array):
216218
@pytest.mark.parametrize("freq", ["M", "2M", "3M"])
217219
def test_parr_cmp_period_scalar(self, freq, box_with_array):
218220
# GH#13200
219-
xbox = get_expected_box(box_with_array)
220-
221221
base = PeriodIndex(["2011-01", "2011-02", "2011-03", "2011-04"], freq=freq)
222222
base = tm.box_expected(base, box_with_array)
223223
per = Period("2011-02", freq=freq)
224+
xbox = get_upcast_box(base, per, True)
224225

225226
exp = np.array([False, True, False, False])
226227
exp = tm.box_expected(exp, xbox)
@@ -255,14 +256,14 @@ def test_parr_cmp_period_scalar(self, freq, box_with_array):
255256
@pytest.mark.parametrize("freq", ["M", "2M", "3M"])
256257
def test_parr_cmp_pi(self, freq, box_with_array):
257258
# GH#13200
258-
xbox = get_expected_box(box_with_array)
259-
260259
base = PeriodIndex(["2011-01", "2011-02", "2011-03", "2011-04"], freq=freq)
261260
base = tm.box_expected(base, box_with_array)
262261

263262
# TODO: could also box idx?
264263
idx = PeriodIndex(["2011-02", "2011-01", "2011-03", "2011-05"], freq=freq)
265264

265+
xbox = get_upcast_box(base, idx, True)
266+
266267
exp = np.array([False, False, True, False])
267268
exp = tm.box_expected(exp, xbox)
268269
tm.assert_equal(base == idx, exp)

pandas/tests/arithmetic/test_timedelta64.py

+8-8
Original file line number | Diff line number | Diff line change
@@ -1542,13 +1542,13 @@ def test_tdi_mul_float_series(self, box_with_array):
15421542
)
15431543
def test_tdi_rmul_arraylike(self, other, box_with_array):
15441544
box = box_with_array
1545-
xbox = get_upcast_box(box, other)
15461545

15471546
tdi = TimedeltaIndex(["1 Day"] * 10)
1548-
expected = timedelta_range("1 days", "10 days")
1549-
expected._data.freq = None
1547+
expected = timedelta_range("1 days", "10 days")._with_freq(None)
15501548

15511549
tdi = tm.box_expected(tdi, box)
1550+
xbox = get_upcast_box(tdi, other)
1551+
15521552
expected = tm.box_expected(expected, xbox)
15531553

15541554
result = other * tdi
@@ -2000,14 +2000,15 @@ def test_td64arr_rmul_numeric_array(
20002000
):
20012001
# GH#4521
20022002
# divide/multiply by integers
2003-
xbox = get_upcast_box(box_with_array, vector)
20042003

20052004
tdser = Series(["59 Days", "59 Days", "NaT"], dtype="m8[ns]")
20062005
vector = vector.astype(any_real_numpy_dtype)
20072006

20082007
expected = Series(["1180 Days", "1770 Days", "NaT"], dtype="timedelta64[ns]")
20092008

20102009
tdser = tm.box_expected(tdser, box_with_array)
2010+
xbox = get_upcast_box(tdser, vector)
2011+
20112012
expected = tm.box_expected(expected, xbox)
20122013

20132014
result = tdser * vector
@@ -2026,14 +2027,14 @@ def test_td64arr_div_numeric_array(
20262027
):
20272028
# GH#4521
20282029
# divide/multiply by integers
2029-
xbox = get_upcast_box(box_with_array, vector)
20302030

20312031
tdser = Series(["59 Days", "59 Days", "NaT"], dtype="m8[ns]")
20322032
vector = vector.astype(any_real_numpy_dtype)
20332033

20342034
expected = Series(["2.95D", "1D 23H 12m", "NaT"], dtype="timedelta64[ns]")
20352035

20362036
tdser = tm.box_expected(tdser, box_with_array)
2037+
xbox = get_upcast_box(tdser, vector)
20372038
expected = tm.box_expected(expected, xbox)
20382039

20392040
result = tdser / vector
@@ -2085,7 +2086,7 @@ def test_td64arr_mul_int_series(self, box_with_array, names):
20852086
)
20862087

20872088
tdi = tm.box_expected(tdi, box)
2088-
xbox = get_upcast_box(box, ser)
2089+
xbox = get_upcast_box(tdi, ser)
20892090

20902091
expected = tm.box_expected(expected, xbox)
20912092

@@ -2117,9 +2118,8 @@ def test_float_series_rdiv_td64arr(self, box_with_array, names):
21172118
name=xname,
21182119
)
21192120

2120-
xbox = get_upcast_box(box, ser)
2121-
21222121
tdi = tm.box_expected(tdi, box)
2122+
xbox = get_upcast_box(tdi, ser)
21232123
expected = tm.box_expected(expected, xbox)
21242124

21252125
result = ser.__rtruediv__(tdi)

0 commit comments

Comments
 (0)