Skip to content

Commit b116c10

Browse files
authored
REF: helpers to de-duplicate datetimelike arithmetic (#48862)
1 parent 5344107 commit b116c10

File tree

1 file changed

+35
-47
lines changed

1 file changed

+35
-47
lines changed

pandas/core/arrays/datetimelike.py

+35-47
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
Period,
3737
Resolution,
3838
Tick,
39+
Timedelta,
3940
Timestamp,
4041
astype_overflowsafe,
4142
delta_to_nanoseconds,
@@ -1122,7 +1123,7 @@ def _get_i8_values_and_mask(
11221123
if isinstance(other, Period):
11231124
i8values = other.ordinal
11241125
mask = None
1125-
elif isinstance(other, Timestamp):
1126+
elif isinstance(other, (Timestamp, Timedelta)):
11261127
i8values = other.value
11271128
mask = None
11281129
else:
@@ -1203,33 +1204,12 @@ def _sub_datetimelike_scalar(self, other: datetime | np.datetime64):
12031204
self = cast("DatetimeArray", self)
12041205
# subtract a datetime from myself, yielding a ndarray[timedelta64[ns]]
12051206

1206-
# error: Non-overlapping identity check (left operand type: "Union[datetime,
1207-
# datetime64]", right operand type: "NaTType") [comparison-overlap]
1208-
assert other is not NaT # type: ignore[comparison-overlap]
1209-
other = Timestamp(other)
1210-
# error: Non-overlapping identity check (left operand type: "Timestamp",
1211-
# right operand type: "NaTType")
1212-
if other is NaT: # type: ignore[comparison-overlap]
1207+
if isna(other):
1208+
# i.e. np.datetime64("NaT")
12131209
return self - NaT
12141210

1215-
try:
1216-
self._assert_tzawareness_compat(other)
1217-
except TypeError as err:
1218-
new_message = str(err).replace("compare", "subtract")
1219-
raise type(err)(new_message) from err
1220-
1221-
i8 = self.asi8
1222-
result = checked_add_with_arr(i8, -other.value, arr_mask=self._isnan)
1223-
res_m8 = result.view(f"timedelta64[{self._unit}]")
1224-
1225-
new_freq = None
1226-
if isinstance(self.freq, Tick):
1227-
# adding a scalar preserves freq
1228-
new_freq = self.freq
1229-
1230-
from pandas.core.arrays import TimedeltaArray
1231-
1232-
return TimedeltaArray._simple_new(res_m8, dtype=res_m8.dtype, freq=new_freq)
1211+
other = Timestamp(other)
1212+
return self._sub_datetimelike(other)
12331213

12341214
@final
12351215
def _sub_datetime_arraylike(self, other):
@@ -1241,19 +1221,28 @@ def _sub_datetime_arraylike(self, other):
12411221

12421222
self = cast("DatetimeArray", self)
12431223
other = ensure_wrapped_if_datetimelike(other)
1224+
return self._sub_datetimelike(other)
1225+
1226+
@final
1227+
def _sub_datetimelike(self, other: Timestamp | DatetimeArray) -> TimedeltaArray:
1228+
self = cast("DatetimeArray", self)
1229+
1230+
from pandas.core.arrays import TimedeltaArray
12441231

12451232
try:
12461233
self._assert_tzawareness_compat(other)
12471234
except TypeError as err:
12481235
new_message = str(err).replace("compare", "subtract")
12491236
raise type(err)(new_message) from err
12501237

1251-
self_i8 = self.asi8
1252-
other_i8 = other.asi8
1253-
new_values = checked_add_with_arr(
1254-
self_i8, -other_i8, arr_mask=self._isnan, b_mask=other._isnan
1238+
other_i8, o_mask = self._get_i8_values_and_mask(other)
1239+
res_values = checked_add_with_arr(
1240+
self.asi8, -other_i8, arr_mask=self._isnan, b_mask=o_mask
12551241
)
1256-
return new_values.view("timedelta64[ns]")
1242+
res_m8 = res_values.view(f"timedelta64[{self._unit}]")
1243+
1244+
new_freq = self._get_arithmetic_result_freq(other)
1245+
return TimedeltaArray._simple_new(res_m8, dtype=res_m8.dtype, freq=new_freq)
12571246

12581247
@final
12591248
def _sub_period(self, other: Period) -> npt.NDArray[np.object_]:
@@ -1289,24 +1278,15 @@ def _add_timedeltalike_scalar(self, other):
12891278
Same type as self
12901279
"""
12911280
if isna(other):
1292-
# i.e np.timedelta64("NaT"), not recognized by delta_to_nanoseconds
1281+
# i.e np.timedelta64("NaT")
12931282
new_values = np.empty(self.shape, dtype="i8").view(self._ndarray.dtype)
12941283
new_values.fill(iNaT)
12951284
return type(self)._simple_new(new_values, dtype=self.dtype)
12961285

12971286
# PeriodArray overrides, so we only get here with DTA/TDA
12981287
self = cast("DatetimeArray | TimedeltaArray", self)
1299-
inc = delta_to_nanoseconds(other, reso=self._reso)
1300-
1301-
new_values = checked_add_with_arr(self.asi8, inc, arr_mask=self._isnan)
1302-
new_values = new_values.view(self._ndarray.dtype)
1303-
1304-
new_freq = None
1305-
if isinstance(self.freq, Tick) or is_period_dtype(self.dtype):
1306-
# adding a scalar preserves freq
1307-
new_freq = self.freq
1308-
1309-
return type(self)._simple_new(new_values, dtype=self.dtype, freq=new_freq)
1288+
other = Timedelta(other)._as_unit(self._unit)
1289+
return self._add_timedeltalike(other)
13101290

13111291
def _add_timedelta_arraylike(
13121292
self, other: TimedeltaArray | npt.NDArray[np.timedelta64]
@@ -1334,13 +1314,21 @@ def _add_timedelta_arraylike(
13341314
else:
13351315
other = other._as_unit(self._unit)
13361316

1337-
self_i8 = self.asi8
1338-
other_i8 = other.asi8
1317+
return self._add_timedeltalike(other)
1318+
1319+
@final
1320+
def _add_timedeltalike(self, other: Timedelta | TimedeltaArray):
1321+
self = cast("DatetimeArray | TimedeltaArray", self)
1322+
1323+
other_i8, o_mask = self._get_i8_values_and_mask(other)
13391324
new_values = checked_add_with_arr(
1340-
self_i8, other_i8, arr_mask=self._isnan, b_mask=other._isnan
1325+
self.asi8, other_i8, arr_mask=self._isnan, b_mask=o_mask
13411326
)
13421327
res_values = new_values.view(self._ndarray.dtype)
1343-
return type(self)._simple_new(res_values, dtype=self.dtype)
1328+
1329+
new_freq = self._get_arithmetic_result_freq(other)
1330+
1331+
return type(self)._simple_new(res_values, dtype=self.dtype, freq=new_freq)
13441332

13451333
@final
13461334
def _add_nat(self):

0 commit comments

Comments
 (0)