Skip to content

Commit 9c7bb7f

Browse files
committed
REF: tighter typing in datetimelike arith methods
1 parent d8e7651 commit 9c7bb7f

File tree

3 files changed

+25
-38
lines changed

3 files changed

+25
-38
lines changed

pandas/core/arrays/datetimelike.py

+23-36
Original file line numberDiff line numberDiff line change
@@ -1161,14 +1161,16 @@ def _add_datetimelike_scalar(self, other) -> DatetimeArray:
11611161
from pandas.core.arrays.datetimes import tz_to_dtype
11621162

11631163
assert other is not NaT
1164-
other = Timestamp(other)
1165-
if other is NaT:
1164+
if isna(other):
1165+
# i.e. np.datetime64("NaT")
11661166
# In this case we specifically interpret NaT as a datetime, not
11671167
# the timedelta interpretation we would get by returning self + NaT
11681168
result = self._ndarray + NaT.to_datetime64().astype(f"M8[{self._unit}]")
11691169
# Preserve our resolution
11701170
return DatetimeArray._simple_new(result, dtype=result.dtype)
11711171

1172+
other = Timestamp(other)
1173+
11721174
if self._reso != other._reso:
11731175
# Just as with Timestamp/Timedelta, we cast to the higher resolution
11741176
if self._reso < other._reso:
@@ -1177,22 +1179,24 @@ def _add_datetimelike_scalar(self, other) -> DatetimeArray:
11771179
else:
11781180
other = other._as_unit(self._unit)
11791181

1180-
i8 = self.asi8
1181-
result = checked_add_with_arr(i8, other.value, arr_mask=self._isnan)
1182+
other_i8, o_mask = self._get_i8_values_and_mask(other)
1183+
result = checked_add_with_arr(
1184+
self.asi8, other_i8, arr_mask=self._isnan, b_mask=o_mask
1185+
)
1186+
res_values = result.view(f"M8[{self._unit}]")
11821187

11831188
dtype = tz_to_dtype(tz=other.tz, unit=self._unit)
11841189
res_values = result.view(f"M8[{self._unit}]")
1185-
return DatetimeArray._simple_new(res_values, dtype=dtype, freq=self.freq)
1190+
new_freq = self._get_arithmetic_result_freq(other)
1191+
return DatetimeArray._simple_new(res_values, dtype=dtype, freq=new_freq)
11861192

11871193
@final
1188-
def _add_datetime_arraylike(self, other) -> DatetimeArray:
1194+
def _add_datetime_arraylike(self, other: DatetimeArray) -> DatetimeArray:
11891195
if not is_timedelta64_dtype(self.dtype):
11901196
raise TypeError(
11911197
f"cannot add {type(self).__name__} and {type(other).__name__}"
11921198
)
11931199

1194-
# At this point we have already checked that other.dtype is datetime64
1195-
other = ensure_wrapped_if_datetimelike(other)
11961200
# defer to DatetimeArray.__add__
11971201
return other + self
11981202

@@ -1220,15 +1224,14 @@ def _sub_datetimelike_scalar(self, other: datetime | np.datetime64):
12201224
return self._sub_datetimelike(other)
12211225

12221226
@final
1223-
def _sub_datetime_arraylike(self, other):
1227+
def _sub_datetime_arraylike(self, other: DatetimeArray):
12241228
if self.dtype.kind != "M":
12251229
raise TypeError(f"cannot subtract a datelike from a {type(self).__name__}")
12261230

12271231
if len(self) != len(other):
12281232
raise ValueError("cannot add indices of unequal length")
12291233

12301234
self = cast("DatetimeArray", self)
1231-
other = ensure_wrapped_if_datetimelike(other)
12321235

12331236
if other._reso != self._reso:
12341237
if other._reso < self._reso:
@@ -1259,16 +1262,6 @@ def _sub_datetimelike(self, other: Timestamp | DatetimeArray) -> TimedeltaArray:
12591262
new_freq = self._get_arithmetic_result_freq(other)
12601263
return TimedeltaArray._simple_new(res_m8, dtype=res_m8.dtype, freq=new_freq)
12611264

1262-
@final
1263-
def _sub_period(self, other: Period) -> npt.NDArray[np.object_]:
1264-
if not is_period_dtype(self.dtype):
1265-
raise TypeError(f"cannot subtract Period from a {type(self).__name__}")
1266-
1267-
# If the operation is well-defined, we return an object-dtype ndarray
1268-
# of DateOffsets. Null entries are filled with pd.NaT
1269-
self._check_compatible_with(other)
1270-
return self._sub_periodlike(other)
1271-
12721265
@final
12731266
def _add_period(self, other: Period) -> PeriodArray:
12741267
if not is_timedelta64_dtype(self.dtype):
@@ -1303,9 +1296,7 @@ def _add_timedeltalike_scalar(self, other):
13031296
other = Timedelta(other)._as_unit(self._unit)
13041297
return self._add_timedeltalike(other)
13051298

1306-
def _add_timedelta_arraylike(
1307-
self, other: TimedeltaArray | npt.NDArray[np.timedelta64]
1308-
):
1299+
def _add_timedelta_arraylike(self, other: TimedeltaArray):
13091300
"""
13101301
Add a delta of a TimedeltaIndex
13111302
@@ -1318,8 +1309,6 @@ def _add_timedelta_arraylike(
13181309
if len(self) != len(other):
13191310
raise ValueError("cannot add indices of unequal length")
13201311

1321-
other = ensure_wrapped_if_datetimelike(other)
1322-
other = cast("TimedeltaArray", other)
13231312
self = cast("DatetimeArray | TimedeltaArray", self)
13241313

13251314
if self._reso != other._reso:
@@ -1379,21 +1368,17 @@ def _sub_nat(self):
13791368
return result.view("timedelta64[ns]")
13801369

13811370
@final
1382-
def _sub_period_array(self, other: PeriodArray) -> npt.NDArray[np.object_]:
1371+
def _sub_periodlike(self, other: Period | PeriodArray) -> npt.NDArray[np.object_]:
1372+
# If the operation is well-defined, we return an object-dtype ndarray
1373+
# of DateOffsets. Null entries are filled with pd.NaT
13831374
if not is_period_dtype(self.dtype):
13841375
raise TypeError(
1385-
f"cannot subtract {other.dtype}-dtype from {type(self).__name__}"
1376+
f"cannot subtract {type(other).__name__} from {type(self).__name__}"
13861377
)
13871378

13881379
self = cast("PeriodArray", self)
1389-
self._require_matching_freq(other)
1390-
1391-
return self._sub_periodlike(other)
1380+
self._check_compatible_with(other)
13921381

1393-
@final
1394-
def _sub_periodlike(self, other: Period | PeriodArray) -> npt.NDArray[np.object_]:
1395-
# caller is responsible for calling
1396-
# require_matching_freq/check_compatible_with
13971382
other_i8, o_mask = self._get_i8_values_and_mask(other)
13981383
new_i8_data = checked_add_with_arr(
13991384
self.asi8, -other_i8, arr_mask=self._isnan, b_mask=o_mask
@@ -1488,6 +1473,7 @@ def _time_shift(
14881473
@unpack_zerodim_and_defer("__add__")
14891474
def __add__(self, other):
14901475
other_dtype = getattr(other, "dtype", None)
1476+
other = ensure_wrapped_if_datetimelike(other)
14911477

14921478
# scalar others
14931479
if other is NaT:
@@ -1548,6 +1534,7 @@ def __radd__(self, other):
15481534
def __sub__(self, other):
15491535

15501536
other_dtype = getattr(other, "dtype", None)
1537+
other = ensure_wrapped_if_datetimelike(other)
15511538

15521539
# scalar others
15531540
if other is NaT:
@@ -1569,7 +1556,7 @@ def __sub__(self, other):
15691556
)
15701557

15711558
elif isinstance(other, Period):
1572-
result = self._sub_period(other)
1559+
result = self._sub_periodlike(other)
15731560

15741561
# array-like others
15751562
elif is_timedelta64_dtype(other_dtype):
@@ -1583,7 +1570,7 @@ def __sub__(self, other):
15831570
result = self._sub_datetime_arraylike(other)
15841571
elif is_period_dtype(other_dtype):
15851572
# PeriodIndex
1586-
result = self._sub_period_array(other)
1573+
result = self._sub_periodlike(other)
15871574
elif is_integer_dtype(other_dtype):
15881575
if not is_period_dtype(self.dtype):
15891576
raise integer_op_not_supported(self)

pandas/tests/arithmetic/test_numeric.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -911,7 +911,7 @@ def test_datetime64_with_index(self):
911911
result = ser - ser.index
912912
tm.assert_series_equal(result, expected)
913913

914-
msg = "cannot subtract period"
914+
msg = "cannot subtract PeriodArray from DatetimeArray"
915915
with pytest.raises(TypeError, match=msg):
916916
# GH#18850
917917
result = ser - ser.index.to_period()

pandas/tests/arithmetic/test_period.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,7 @@ def test_pi_add_sub_td64_array_non_tick_raises(self):
745745

746746
with pytest.raises(TypeError, match=msg):
747747
rng - tdarr
748-
msg = r"cannot subtract period\[Q-DEC\]-dtype from TimedeltaArray"
748+
msg = r"cannot subtract PeriodArray from TimedeltaArray"
749749
with pytest.raises(TypeError, match=msg):
750750
tdarr - rng
751751

0 commit comments

Comments
 (0)