Skip to content

Commit 3255b03

Browse files
jbrockmendelquintusdias
authored andcommitted
BUG: fix+test DTA/TDA/PA add/sub Index (pandas-dev#27726)
1 parent 6731d4b commit 3255b03

File tree

4 files changed

+52
-3
lines changed

4 files changed

+52
-3
lines changed

pandas/core/arrays/datetimelike.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1207,7 +1207,7 @@ def _time_shift(self, periods, freq=None):
12071207

12081208
def __add__(self, other):
12091209
other = lib.item_from_zerodim(other)
1210-
if isinstance(other, (ABCSeries, ABCDataFrame)):
1210+
if isinstance(other, (ABCSeries, ABCDataFrame, ABCIndexClass)):
12111211
return NotImplemented
12121212

12131213
# scalar others
@@ -1273,7 +1273,7 @@ def __radd__(self, other):
12731273

12741274
def __sub__(self, other):
12751275
other = lib.item_from_zerodim(other)
1276-
if isinstance(other, (ABCSeries, ABCDataFrame)):
1276+
if isinstance(other, (ABCSeries, ABCDataFrame, ABCIndexClass)):
12771277
return NotImplemented
12781278

12791279
# scalar others
@@ -1340,7 +1340,7 @@ def __sub__(self, other):
13401340
return result
13411341

13421342
def __rsub__(self, other):
1343-
if is_datetime64_dtype(other) and is_timedelta64_dtype(self):
1343+
if is_datetime64_any_dtype(other) and is_timedelta64_dtype(self):
13441344
# ndarray[datetime64] cannot be subtracted from self, so
13451345
# we need to wrap in DatetimeArray/Index and flip the operation
13461346
if not isinstance(other, DatetimeLikeArrayMixin):

pandas/tests/arithmetic/test_datetime64.py

+18
Original file line numberDiff line numberDiff line change
@@ -2249,6 +2249,23 @@ def test_add_datetimelike_and_dti(self, addend, tz):
22492249

22502250
# -------------------------------------------------------------
22512251

2252+
def test_dta_add_sub_index(self, tz_naive_fixture):
2253+
# Check that DatetimeArray defers to Index classes
2254+
dti = date_range("20130101", periods=3, tz=tz_naive_fixture)
2255+
dta = dti.array
2256+
result = dta - dti
2257+
expected = dti - dti
2258+
tm.assert_index_equal(result, expected)
2259+
2260+
tdi = result
2261+
result = dta + tdi
2262+
expected = dti + tdi
2263+
tm.assert_index_equal(result, expected)
2264+
2265+
result = dta - tdi
2266+
expected = dti - tdi
2267+
tm.assert_index_equal(result, expected)
2268+
22522269
def test_sub_dti_dti(self):
22532270
# previously performed setop (deprecated in 0.16.0), now changed to
22542271
# return subtraction -> TimeDeltaIndex (GH ...)
@@ -2554,6 +2571,7 @@ def test_shift_months(years, months):
25542571
tm.assert_index_equal(actual, expected)
25552572

25562573

2574+
# FIXME: this belongs in scalar tests
25572575
class SubDatetime(datetime):
25582576
pass
25592577

pandas/tests/arithmetic/test_period.py

+12
Original file line numberDiff line numberDiff line change
@@ -1041,6 +1041,18 @@ def test_parr_add_sub_tdt64_nat_array(self, box_df_fail, other):
10411041
with pytest.raises(TypeError):
10421042
other - obj
10431043

1044+
# ---------------------------------------------------------------
1045+
# Unsorted
1046+
1047+
def test_parr_add_sub_index(self):
1048+
# Check that PeriodArray defers to Index on arithmetic ops
1049+
pi = pd.period_range("2000-12-31", periods=3)
1050+
parr = pi.array
1051+
1052+
result = parr - pi
1053+
expected = pi - pi
1054+
tm.assert_index_equal(result, expected)
1055+
10441056

10451057
class TestPeriodSeriesArithmetic:
10461058
def test_ops_series_timedelta(self):

pandas/tests/arithmetic/test_timedelta64.py

+19
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,25 @@ def test_timedelta(self, freq):
480480
tm.assert_index_equal(result1, result4)
481481
tm.assert_index_equal(result2, result3)
482482

483+
def test_tda_add_sub_index(self):
484+
# Check that TimedeltaArray defers to Index on arithmetic ops
485+
tdi = TimedeltaIndex(["1 days", pd.NaT, "2 days"])
486+
tda = tdi.array
487+
488+
dti = pd.date_range("1999-12-31", periods=3, freq="D")
489+
490+
result = tda + dti
491+
expected = tdi + dti
492+
tm.assert_index_equal(result, expected)
493+
494+
result = tda + tdi
495+
expected = tdi + tdi
496+
tm.assert_index_equal(result, expected)
497+
498+
result = tda - tdi
499+
expected = tdi - tdi
500+
tm.assert_index_equal(result, expected)
501+
483502

484503
class TestAddSubNaTMasking:
485504
# TODO: parametrize over boxes

0 commit comments

Comments
 (0)