Skip to content

Commit 9004a5c

Browse files
jbrockmendelnoatamir
authored andcommitted
REF: tighter typing in datetimelike arith methods (pandas-dev#48962)
1 parent d3ec846 commit 9004a5c

File tree

3 files changed

+26
-40
lines changed

3 files changed

+26
-40
lines changed

pandas/core/arrays/datetimelike.py

+24-38
Original file line numberDiff line numberDiff line change
@@ -1160,33 +1160,36 @@ def _add_datetimelike_scalar(self, other) -> DatetimeArray:
11601160
from pandas.core.arrays.datetimes import tz_to_dtype
11611161

11621162
assert other is not NaT
1163-
other = Timestamp(other)
1164-
if other is NaT:
1163+
if isna(other):
1164+
# i.e. np.datetime64("NaT")
11651165
# In this case we specifically interpret NaT as a datetime, not
11661166
# the timedelta interpretation we would get by returning self + NaT
11671167
result = self._ndarray + NaT.to_datetime64().astype(f"M8[{self._unit}]")
11681168
# Preserve our resolution
11691169
return DatetimeArray._simple_new(result, dtype=result.dtype)
11701170

1171+
other = Timestamp(other)
11711172
self, other = self._ensure_matching_resos(other)
11721173
self = cast("TimedeltaArray", self)
11731174

1174-
i8 = self.asi8
1175-
result = checked_add_with_arr(i8, other.value, arr_mask=self._isnan)
1175+
other_i8, o_mask = self._get_i8_values_and_mask(other)
1176+
result = checked_add_with_arr(
1177+
self.asi8, other_i8, arr_mask=self._isnan, b_mask=o_mask
1178+
)
1179+
res_values = result.view(f"M8[{self._unit}]")
11761180

11771181
dtype = tz_to_dtype(tz=other.tz, unit=self._unit)
11781182
res_values = result.view(f"M8[{self._unit}]")
1179-
return DatetimeArray._simple_new(res_values, dtype=dtype, freq=self.freq)
1183+
new_freq = self._get_arithmetic_result_freq(other)
1184+
return DatetimeArray._simple_new(res_values, dtype=dtype, freq=new_freq)
11801185

11811186
@final
1182-
def _add_datetime_arraylike(self, other) -> DatetimeArray:
1187+
def _add_datetime_arraylike(self, other: DatetimeArray) -> DatetimeArray:
11831188
if not is_timedelta64_dtype(self.dtype):
11841189
raise TypeError(
11851190
f"cannot add {type(self).__name__} and {type(other).__name__}"
11861191
)
11871192

1188-
# At this point we have already checked that other.dtype is datetime64
1189-
other = ensure_wrapped_if_datetimelike(other)
11901193
# defer to DatetimeArray.__add__
11911194
return other + self
11921195

@@ -1208,15 +1211,14 @@ def _sub_datetimelike_scalar(self, other: datetime | np.datetime64):
12081211
return self._sub_datetimelike(ts)
12091212

12101213
@final
1211-
def _sub_datetime_arraylike(self, other):
1214+
def _sub_datetime_arraylike(self, other: DatetimeArray):
12121215
if self.dtype.kind != "M":
12131216
raise TypeError(f"cannot subtract a datelike from a {type(self).__name__}")
12141217

12151218
if len(self) != len(other):
12161219
raise ValueError("cannot add indices of unequal length")
12171220

12181221
self = cast("DatetimeArray", self)
1219-
other = ensure_wrapped_if_datetimelike(other)
12201222

12211223
self, other = self._ensure_matching_resos(other)
12221224
return self._sub_datetimelike(other)
@@ -1242,16 +1244,6 @@ def _sub_datetimelike(self, other: Timestamp | DatetimeArray) -> TimedeltaArray:
12421244
new_freq = self._get_arithmetic_result_freq(other)
12431245
return TimedeltaArray._simple_new(res_m8, dtype=res_m8.dtype, freq=new_freq)
12441246

1245-
@final
1246-
def _sub_period(self, other: Period) -> npt.NDArray[np.object_]:
1247-
if not is_period_dtype(self.dtype):
1248-
raise TypeError(f"cannot subtract Period from a {type(self).__name__}")
1249-
1250-
# If the operation is well-defined, we return an object-dtype ndarray
1251-
# of DateOffsets. Null entries are filled with pd.NaT
1252-
self._check_compatible_with(other)
1253-
return self._sub_periodlike(other)
1254-
12551247
@final
12561248
def _add_period(self, other: Period) -> PeriodArray:
12571249
if not is_timedelta64_dtype(self.dtype):
@@ -1286,9 +1278,7 @@ def _add_timedeltalike_scalar(self, other):
12861278
other = Timedelta(other)._as_unit(self._unit)
12871279
return self._add_timedeltalike(other)
12881280

1289-
def _add_timedelta_arraylike(
1290-
self, other: TimedeltaArray | npt.NDArray[np.timedelta64]
1291-
):
1281+
def _add_timedelta_arraylike(self, other: TimedeltaArray):
12921282
"""
12931283
Add a delta of a TimedeltaIndex
12941284
@@ -1301,12 +1291,10 @@ def _add_timedelta_arraylike(
13011291
if len(self) != len(other):
13021292
raise ValueError("cannot add indices of unequal length")
13031293

1304-
other = ensure_wrapped_if_datetimelike(other)
1305-
tda = cast("TimedeltaArray", other)
13061294
self = cast("DatetimeArray | TimedeltaArray", self)
13071295

1308-
self, tda = self._ensure_matching_resos(tda)
1309-
return self._add_timedeltalike(tda)
1296+
self, other = self._ensure_matching_resos(other)
1297+
return self._add_timedeltalike(other)
13101298

13111299
@final
13121300
def _add_timedeltalike(self, other: Timedelta | TimedeltaArray):
@@ -1356,21 +1344,17 @@ def _sub_nat(self):
13561344
return result.view("timedelta64[ns]")
13571345

13581346
@final
1359-
def _sub_period_array(self, other: PeriodArray) -> npt.NDArray[np.object_]:
1347+
def _sub_periodlike(self, other: Period | PeriodArray) -> npt.NDArray[np.object_]:
1348+
# If the operation is well-defined, we return an object-dtype ndarray
1349+
# of DateOffsets. Null entries are filled with pd.NaT
13601350
if not is_period_dtype(self.dtype):
13611351
raise TypeError(
1362-
f"cannot subtract {other.dtype}-dtype from {type(self).__name__}"
1352+
f"cannot subtract {type(other).__name__} from {type(self).__name__}"
13631353
)
13641354

13651355
self = cast("PeriodArray", self)
1366-
self._require_matching_freq(other)
1367-
1368-
return self._sub_periodlike(other)
1356+
self._check_compatible_with(other)
13691357

1370-
@final
1371-
def _sub_periodlike(self, other: Period | PeriodArray) -> npt.NDArray[np.object_]:
1372-
# caller is responsible for calling
1373-
# require_matching_freq/check_compatible_with
13741358
other_i8, o_mask = self._get_i8_values_and_mask(other)
13751359
new_i8_data = checked_add_with_arr(
13761360
self.asi8, -other_i8, arr_mask=self._isnan, b_mask=o_mask
@@ -1465,6 +1449,7 @@ def _time_shift(
14651449
@unpack_zerodim_and_defer("__add__")
14661450
def __add__(self, other):
14671451
other_dtype = getattr(other, "dtype", None)
1452+
other = ensure_wrapped_if_datetimelike(other)
14681453

14691454
# scalar others
14701455
if other is NaT:
@@ -1525,6 +1510,7 @@ def __radd__(self, other):
15251510
def __sub__(self, other):
15261511

15271512
other_dtype = getattr(other, "dtype", None)
1513+
other = ensure_wrapped_if_datetimelike(other)
15281514

15291515
# scalar others
15301516
if other is NaT:
@@ -1546,7 +1532,7 @@ def __sub__(self, other):
15461532
)
15471533

15481534
elif isinstance(other, Period):
1549-
result = self._sub_period(other)
1535+
result = self._sub_periodlike(other)
15501536

15511537
# array-like others
15521538
elif is_timedelta64_dtype(other_dtype):
@@ -1560,7 +1546,7 @@ def __sub__(self, other):
15601546
result = self._sub_datetime_arraylike(other)
15611547
elif is_period_dtype(other_dtype):
15621548
# PeriodIndex
1563-
result = self._sub_period_array(other)
1549+
result = self._sub_periodlike(other)
15641550
elif is_integer_dtype(other_dtype):
15651551
if not is_period_dtype(self.dtype):
15661552
raise integer_op_not_supported(self)

pandas/tests/arithmetic/test_numeric.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -911,7 +911,7 @@ def test_datetime64_with_index(self):
911911
result = ser - ser.index
912912
tm.assert_series_equal(result, expected)
913913

914-
msg = "cannot subtract period"
914+
msg = "cannot subtract PeriodArray from DatetimeArray"
915915
with pytest.raises(TypeError, match=msg):
916916
# GH#18850
917917
result = ser - ser.index.to_period()

pandas/tests/arithmetic/test_period.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,7 @@ def test_pi_add_sub_td64_array_non_tick_raises(self):
745745

746746
with pytest.raises(TypeError, match=msg):
747747
rng - tdarr
748-
msg = r"cannot subtract period\[Q-DEC\]-dtype from TimedeltaArray"
748+
msg = r"cannot subtract PeriodArray from TimedeltaArray"
749749
with pytest.raises(TypeError, match=msg):
750750
tdarr - rng
751751

0 commit comments

Comments
 (0)