Skip to content

REF: tighter typing in datetimelike arith methods #48962

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 24 additions & 38 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -1160,33 +1160,36 @@ def _add_datetimelike_scalar(self, other) -> DatetimeArray:
from pandas.core.arrays.datetimes import tz_to_dtype

assert other is not NaT
other = Timestamp(other)
if other is NaT:
if isna(other):
# i.e. np.datetime64("NaT")
# In this case we specifically interpret NaT as a datetime, not
# the timedelta interpretation we would get by returning self + NaT
result = self._ndarray + NaT.to_datetime64().astype(f"M8[{self._unit}]")
# Preserve our resolution
return DatetimeArray._simple_new(result, dtype=result.dtype)

other = Timestamp(other)
self, other = self._ensure_matching_resos(other)
self = cast("TimedeltaArray", self)

i8 = self.asi8
result = checked_add_with_arr(i8, other.value, arr_mask=self._isnan)
other_i8, o_mask = self._get_i8_values_and_mask(other)
result = checked_add_with_arr(
self.asi8, other_i8, arr_mask=self._isnan, b_mask=o_mask
)
res_values = result.view(f"M8[{self._unit}]")

dtype = tz_to_dtype(tz=other.tz, unit=self._unit)
res_values = result.view(f"M8[{self._unit}]")
return DatetimeArray._simple_new(res_values, dtype=dtype, freq=self.freq)
new_freq = self._get_arithmetic_result_freq(other)
return DatetimeArray._simple_new(res_values, dtype=dtype, freq=new_freq)

@final
def _add_datetime_arraylike(self, other) -> DatetimeArray:
def _add_datetime_arraylike(self, other: DatetimeArray) -> DatetimeArray:
if not is_timedelta64_dtype(self.dtype):
raise TypeError(
f"cannot add {type(self).__name__} and {type(other).__name__}"
)

# At this point we have already checked that other.dtype is datetime64
other = ensure_wrapped_if_datetimelike(other)
# defer to DatetimeArray.__add__
return other + self

Expand All @@ -1208,15 +1211,14 @@ def _sub_datetimelike_scalar(self, other: datetime | np.datetime64):
return self._sub_datetimelike(ts)

@final
def _sub_datetime_arraylike(self, other):
def _sub_datetime_arraylike(self, other: DatetimeArray):
if self.dtype.kind != "M":
raise TypeError(f"cannot subtract a datelike from a {type(self).__name__}")

if len(self) != len(other):
raise ValueError("cannot add indices of unequal length")

self = cast("DatetimeArray", self)
other = ensure_wrapped_if_datetimelike(other)

self, other = self._ensure_matching_resos(other)
return self._sub_datetimelike(other)
Expand All @@ -1242,16 +1244,6 @@ def _sub_datetimelike(self, other: Timestamp | DatetimeArray) -> TimedeltaArray:
new_freq = self._get_arithmetic_result_freq(other)
return TimedeltaArray._simple_new(res_m8, dtype=res_m8.dtype, freq=new_freq)

@final
def _sub_period(self, other: Period) -> npt.NDArray[np.object_]:
if not is_period_dtype(self.dtype):
raise TypeError(f"cannot subtract Period from a {type(self).__name__}")

# If the operation is well-defined, we return an object-dtype ndarray
# of DateOffsets. Null entries are filled with pd.NaT
self._check_compatible_with(other)
return self._sub_periodlike(other)

@final
def _add_period(self, other: Period) -> PeriodArray:
if not is_timedelta64_dtype(self.dtype):
Expand Down Expand Up @@ -1286,9 +1278,7 @@ def _add_timedeltalike_scalar(self, other):
other = Timedelta(other)._as_unit(self._unit)
return self._add_timedeltalike(other)

def _add_timedelta_arraylike(
self, other: TimedeltaArray | npt.NDArray[np.timedelta64]
):
def _add_timedelta_arraylike(self, other: TimedeltaArray):
"""
Add a delta of a TimedeltaIndex

Expand All @@ -1301,12 +1291,10 @@ def _add_timedelta_arraylike(
if len(self) != len(other):
raise ValueError("cannot add indices of unequal length")

other = ensure_wrapped_if_datetimelike(other)
tda = cast("TimedeltaArray", other)
self = cast("DatetimeArray | TimedeltaArray", self)

self, tda = self._ensure_matching_resos(tda)
return self._add_timedeltalike(tda)
self, other = self._ensure_matching_resos(other)
return self._add_timedeltalike(other)

@final
def _add_timedeltalike(self, other: Timedelta | TimedeltaArray):
Expand Down Expand Up @@ -1356,21 +1344,17 @@ def _sub_nat(self):
return result.view("timedelta64[ns]")

@final
def _sub_period_array(self, other: PeriodArray) -> npt.NDArray[np.object_]:
def _sub_periodlike(self, other: Period | PeriodArray) -> npt.NDArray[np.object_]:
# If the operation is well-defined, we return an object-dtype ndarray
# of DateOffsets. Null entries are filled with pd.NaT
if not is_period_dtype(self.dtype):
raise TypeError(
f"cannot subtract {other.dtype}-dtype from {type(self).__name__}"
f"cannot subtract {type(other).__name__} from {type(self).__name__}"
)

self = cast("PeriodArray", self)
self._require_matching_freq(other)

return self._sub_periodlike(other)
self._check_compatible_with(other)

@final
def _sub_periodlike(self, other: Period | PeriodArray) -> npt.NDArray[np.object_]:
# caller is responsible for calling
# require_matching_freq/check_compatible_with
other_i8, o_mask = self._get_i8_values_and_mask(other)
new_i8_data = checked_add_with_arr(
self.asi8, -other_i8, arr_mask=self._isnan, b_mask=o_mask
Expand Down Expand Up @@ -1465,6 +1449,7 @@ def _time_shift(
@unpack_zerodim_and_defer("__add__")
def __add__(self, other):
other_dtype = getattr(other, "dtype", None)
other = ensure_wrapped_if_datetimelike(other)

# scalar others
if other is NaT:
Expand Down Expand Up @@ -1525,6 +1510,7 @@ def __radd__(self, other):
def __sub__(self, other):

other_dtype = getattr(other, "dtype", None)
other = ensure_wrapped_if_datetimelike(other)

# scalar others
if other is NaT:
Expand All @@ -1546,7 +1532,7 @@ def __sub__(self, other):
)

elif isinstance(other, Period):
result = self._sub_period(other)
result = self._sub_periodlike(other)

# array-like others
elif is_timedelta64_dtype(other_dtype):
Expand All @@ -1560,7 +1546,7 @@ def __sub__(self, other):
result = self._sub_datetime_arraylike(other)
elif is_period_dtype(other_dtype):
# PeriodIndex
result = self._sub_period_array(other)
result = self._sub_periodlike(other)
elif is_integer_dtype(other_dtype):
if not is_period_dtype(self.dtype):
raise integer_op_not_supported(self)
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/arithmetic/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,7 @@ def test_datetime64_with_index(self):
result = ser - ser.index
tm.assert_series_equal(result, expected)

msg = "cannot subtract period"
msg = "cannot subtract PeriodArray from DatetimeArray"
with pytest.raises(TypeError, match=msg):
# GH#18850
result = ser - ser.index.to_period()
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/arithmetic/test_period.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,7 @@ def test_pi_add_sub_td64_array_non_tick_raises(self):

with pytest.raises(TypeError, match=msg):
rng - tdarr
msg = r"cannot subtract period\[Q-DEC\]-dtype from TimedeltaArray"
msg = r"cannot subtract PeriodArray from TimedeltaArray"
with pytest.raises(TypeError, match=msg):
tdarr - rng

Expand Down