From fa536d35b6b7a13f9a78f64a843f16317387674a Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 5 Oct 2022 12:34:13 -0700 Subject: [PATCH] REF: tighter typing in datetimelike arith methods --- pandas/core/arrays/datetimelike.py | 62 ++++++++++--------------- pandas/tests/arithmetic/test_numeric.py | 2 +- pandas/tests/arithmetic/test_period.py | 2 +- 3 files changed, 26 insertions(+), 40 deletions(-) diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index d1c793dc6f152..7571202f02665 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -1160,33 +1160,36 @@ def _add_datetimelike_scalar(self, other) -> DatetimeArray: from pandas.core.arrays.datetimes import tz_to_dtype assert other is not NaT - other = Timestamp(other) - if other is NaT: + if isna(other): + # i.e. np.datetime64("NaT") # In this case we specifically interpret NaT as a datetime, not # the timedelta interpretation we would get by returning self + NaT result = self._ndarray + NaT.to_datetime64().astype(f"M8[{self._unit}]") # Preserve our resolution return DatetimeArray._simple_new(result, dtype=result.dtype) + other = Timestamp(other) self, other = self._ensure_matching_resos(other) self = cast("TimedeltaArray", self) - i8 = self.asi8 - result = checked_add_with_arr(i8, other.value, arr_mask=self._isnan) + other_i8, o_mask = self._get_i8_values_and_mask(other) + result = checked_add_with_arr( + self.asi8, other_i8, arr_mask=self._isnan, b_mask=o_mask + ) + res_values = result.view(f"M8[{self._unit}]") dtype = tz_to_dtype(tz=other.tz, unit=self._unit) res_values = result.view(f"M8[{self._unit}]") - return DatetimeArray._simple_new(res_values, dtype=dtype, freq=self.freq) + new_freq = self._get_arithmetic_result_freq(other) + return DatetimeArray._simple_new(res_values, dtype=dtype, freq=new_freq) @final - def _add_datetime_arraylike(self, other) -> DatetimeArray: + def _add_datetime_arraylike(self, other: DatetimeArray) -> DatetimeArray: if not is_timedelta64_dtype(self.dtype): raise TypeError( f"cannot add {type(self).__name__} and {type(other).__name__}" ) - # At this point we have already checked that other.dtype is datetime64 - other = ensure_wrapped_if_datetimelike(other) # defer to DatetimeArray.__add__ return other + self @@ -1208,7 +1211,7 @@ def _sub_datetimelike_scalar(self, other: datetime | np.datetime64): return self._sub_datetimelike(ts) @final - def _sub_datetime_arraylike(self, other): + def _sub_datetime_arraylike(self, other: DatetimeArray): if self.dtype.kind != "M": raise TypeError(f"cannot subtract a datelike from a {type(self).__name__}") @@ -1216,7 +1219,6 @@ def _sub_datetime_arraylike(self, other): raise ValueError("cannot add indices of unequal length") self = cast("DatetimeArray", self) - other = ensure_wrapped_if_datetimelike(other) self, other = self._ensure_matching_resos(other) return self._sub_datetimelike(other) @@ -1242,16 +1244,6 @@ def _sub_datetimelike(self, other: Timestamp | DatetimeArray) -> TimedeltaArray: new_freq = self._get_arithmetic_result_freq(other) return TimedeltaArray._simple_new(res_m8, dtype=res_m8.dtype, freq=new_freq) - @final - def _sub_period(self, other: Period) -> npt.NDArray[np.object_]: - if not is_period_dtype(self.dtype): - raise TypeError(f"cannot subtract Period from a {type(self).__name__}") - - # If the operation is well-defined, we return an object-dtype ndarray - # of DateOffsets. Null entries are filled with pd.NaT - self._check_compatible_with(other) - return self._sub_periodlike(other) - @final def _add_period(self, other: Period) -> PeriodArray: if not is_timedelta64_dtype(self.dtype): @@ -1286,9 +1278,7 @@ def _add_timedeltalike_scalar(self, other): other = Timedelta(other)._as_unit(self._unit) return self._add_timedeltalike(other) - def _add_timedelta_arraylike( - self, other: TimedeltaArray | npt.NDArray[np.timedelta64] - ): + def _add_timedelta_arraylike(self, other: TimedeltaArray): """ Add a delta of a TimedeltaIndex @@ -1301,12 +1291,10 @@ def _add_timedelta_arraylike( if len(self) != len(other): raise ValueError("cannot add indices of unequal length") - other = ensure_wrapped_if_datetimelike(other) - tda = cast("TimedeltaArray", other) self = cast("DatetimeArray | TimedeltaArray", self) - self, tda = self._ensure_matching_resos(tda) - return self._add_timedeltalike(tda) + self, other = self._ensure_matching_resos(other) + return self._add_timedeltalike(other) @final def _add_timedeltalike(self, other: Timedelta | TimedeltaArray): @@ -1356,21 +1344,17 @@ def _sub_nat(self): return result.view("timedelta64[ns]") @final - def _sub_period_array(self, other: PeriodArray) -> npt.NDArray[np.object_]: + def _sub_periodlike(self, other: Period | PeriodArray) -> npt.NDArray[np.object_]: + # If the operation is well-defined, we return an object-dtype ndarray + # of DateOffsets. Null entries are filled with pd.NaT if not is_period_dtype(self.dtype): raise TypeError( - f"cannot subtract {other.dtype}-dtype from {type(self).__name__}" + f"cannot subtract {type(other).__name__} from {type(self).__name__}" ) self = cast("PeriodArray", self) - self._require_matching_freq(other) - - return self._sub_periodlike(other) + self._check_compatible_with(other) - @final - def _sub_periodlike(self, other: Period | PeriodArray) -> npt.NDArray[np.object_]: - # caller is responsible for calling - # require_matching_freq/check_compatible_with other_i8, o_mask = self._get_i8_values_and_mask(other) new_i8_data = checked_add_with_arr( self.asi8, -other_i8, arr_mask=self._isnan, b_mask=o_mask @@ -1465,6 +1449,7 @@ def _time_shift( @unpack_zerodim_and_defer("__add__") def __add__(self, other): other_dtype = getattr(other, "dtype", None) + other = ensure_wrapped_if_datetimelike(other) # scalar others if other is NaT: @@ -1525,6 +1510,7 @@ def __radd__(self, other): def __sub__(self, other): other_dtype = getattr(other, "dtype", None) + other = ensure_wrapped_if_datetimelike(other) # scalar others if other is NaT: @@ -1546,7 +1532,7 @@ def __sub__(self, other): ) elif isinstance(other, Period): - result = self._sub_period(other) + result = self._sub_periodlike(other) # array-like others elif is_timedelta64_dtype(other_dtype): @@ -1560,7 +1546,7 @@ def __sub__(self, other): result = self._sub_datetime_arraylike(other) elif is_period_dtype(other_dtype): # PeriodIndex - result = self._sub_period_array(other) + result = self._sub_periodlike(other) elif is_integer_dtype(other_dtype): if not is_period_dtype(self.dtype): raise integer_op_not_supported(self) diff --git a/pandas/tests/arithmetic/test_numeric.py b/pandas/tests/arithmetic/test_numeric.py index 881a5f1de1c60..0cb09ba6a4dfb 100644 --- a/pandas/tests/arithmetic/test_numeric.py +++ b/pandas/tests/arithmetic/test_numeric.py @@ -911,7 +911,7 @@ def test_datetime64_with_index(self): result = ser - ser.index tm.assert_series_equal(result, expected) - msg = "cannot subtract period" + msg = "cannot subtract PeriodArray from DatetimeArray" with pytest.raises(TypeError, match=msg): # GH#18850 result = ser - ser.index.to_period() diff --git a/pandas/tests/arithmetic/test_period.py b/pandas/tests/arithmetic/test_period.py index b03ac26a4b74d..56ad0d622cfb6 100644 --- a/pandas/tests/arithmetic/test_period.py +++ b/pandas/tests/arithmetic/test_period.py @@ -745,7 +745,7 @@ def test_pi_add_sub_td64_array_non_tick_raises(self): with pytest.raises(TypeError, match=msg): rng - tdarr - msg = r"cannot subtract period\[Q-DEC\]-dtype from TimedeltaArray" + msg = r"cannot subtract PeriodArray from TimedeltaArray" with pytest.raises(TypeError, match=msg): tdarr - rng