From 14e80326d4514158210be26fb1a13c02e5e778d6 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 27 Jun 2022 16:51:00 -0700 Subject: [PATCH 1/3] ENH: TimedeltaArray add/sub with NaT preserve reso --- pandas/core/arrays/datetimelike.py | 33 +++++++++++++-- pandas/tests/arrays/test_timedeltas.py | 56 ++++++++++++++++++-------- 2 files changed, 69 insertions(+), 20 deletions(-) diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index d354d28d0f46f..cb857cc1aa6be 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -1087,7 +1087,7 @@ def _cmp_method(self, other, op): __rdivmod__ = make_invalid_op("__rdivmod__") @final - def _add_datetimelike_scalar(self, other): + def _add_datetimelike_scalar(self, other) -> DatetimeArray: if not is_timedelta64_dtype(self.dtype): raise TypeError( f"cannot add {type(self).__name__} and {type(other).__name__}" @@ -1102,8 +1102,9 @@ def _add_datetimelike_scalar(self, other): if other is NaT: # In this case we specifically interpret NaT as a datetime, not # the timedelta interpretation we would get by returning self + NaT - result = self.asi8.view("m8[ms]") + NaT.to_datetime64() - return DatetimeArray(result) + result = self._ndarray + NaT.to_datetime64().astype(f"M8[{self._unit}]") + # Preserve our resolution + return DatetimeArray._simple_new(result, dtype=result.dtype) i8 = self.asi8 # Incompatible types in assignment (expression has type "ndarray[Any, @@ -1280,7 +1281,8 @@ def _add_nat(self): # and datetime dtypes result = np.empty(self.shape, dtype=np.int64) result.fill(iNaT) - return type(self)(result, dtype=self.dtype, freq=None) + result = result.view(self._ndarray.dtype) # preserve reso + return type(self)._simple_new(result, dtype=self.dtype, freq=None) @final def _sub_nat(self): @@ -1905,6 +1907,11 @@ class TimelikeOps(DatetimeLikeArrayMixin): def _reso(self) -> int: return get_unit_from_dtype(self._ndarray.dtype) + @cache_readonly + def _unit(self) -> str: + # e.g. "ns", "us", "ms" + return dtype_to_unit(self.dtype) + def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs): if ( ufunc in [np.isnan, np.isinf, np.isfinite] @@ -2105,3 +2112,21 @@ def maybe_infer_freq(freq): freq_infer = True freq = None return freq, freq_infer + + +def dtype_to_unit(dtype: DatetimeTZDtype | np.dtype) -> str: + """ + Return the unit str corresponding to the dtype's resolution. + + Parameters + ---------- + dtype : DatetimeTZDtype or np.dtype + If np.dtype, we assume it is a datetime64 dtype. + + Returns + ------- + str + """ + if isinstance(dtype, DatetimeTZDtype): + return dtype.unit + return str(dtype).split("[")[-1][:-1] diff --git a/pandas/tests/arrays/test_timedeltas.py b/pandas/tests/arrays/test_timedeltas.py index 5983c2f644949..36acb8f0fe389 100644 --- a/pandas/tests/arrays/test_timedeltas.py +++ b/pandas/tests/arrays/test_timedeltas.py @@ -6,7 +6,10 @@ import pandas as pd from pandas import Timedelta import pandas._testing as tm -from pandas.core.arrays import TimedeltaArray +from pandas.core.arrays import ( + DatetimeArray, + TimedeltaArray, +) class TestNonNano: @@ -25,6 +28,11 @@ def reso(self, unit): else: raise NotImplementedError(unit) + @pytest.fixture + def tda(self, unit): + arr = np.arange(5, dtype=np.int64).view(f"m8[{unit}]") + return TimedeltaArray._simple_new(arr, dtype=arr.dtype) + def test_non_nano(self, unit, reso): arr = np.arange(5, dtype=np.int64).view(f"m8[{unit}]") tda = TimedeltaArray._simple_new(arr, dtype=arr.dtype) @@ -33,39 +41,55 @@ def test_non_nano(self, unit, reso): assert tda[0]._reso == reso @pytest.mark.parametrize("field", TimedeltaArray._field_ops) - def test_fields(self, unit, field): - arr = np.arange(5, dtype=np.int64).view(f"m8[{unit}]") - tda = TimedeltaArray._simple_new(arr, dtype=arr.dtype) - - as_nano = arr.astype("m8[ns]") + def test_fields(self, tda, field): + as_nano = tda._ndarray.astype("m8[ns]") tda_nano = TimedeltaArray._simple_new(as_nano, dtype=as_nano.dtype) result = getattr(tda, field) expected = getattr(tda_nano, field) tm.assert_numpy_array_equal(result, expected) - def test_to_pytimedelta(self, unit): - arr = np.arange(5, dtype=np.int64).view(f"m8[{unit}]") - tda = TimedeltaArray._simple_new(arr, dtype=arr.dtype) - - as_nano = arr.astype("m8[ns]") + def test_to_pytimedelta(self, tda): + as_nano = tda._ndarray.astype("m8[ns]") tda_nano = TimedeltaArray._simple_new(as_nano, dtype=as_nano.dtype) result = tda.to_pytimedelta() expected = tda_nano.to_pytimedelta() tm.assert_numpy_array_equal(result, expected) - def test_total_seconds(self, unit): - arr = np.arange(5, dtype=np.int64).view(f"m8[{unit}]") - tda = TimedeltaArray._simple_new(arr, dtype=arr.dtype) - - as_nano = arr.astype("m8[ns]") + def test_total_seconds(self, unit, tda): + as_nano = tda._ndarray.astype("m8[ns]") tda_nano = TimedeltaArray._simple_new(as_nano, dtype=as_nano.dtype) result = tda.total_seconds() expected = tda_nano.total_seconds() tm.assert_numpy_array_equal(result, expected) + @pytest.mark.parametrize( + "nat", [np.datetime64("NaT", "ns"), np.datetime64("NaT", "us")] + ) + def test_add_nat_datetimelike_scalar(self, nat, tda): + result = tda + nat + assert isinstance(result, DatetimeArray) + assert result._reso == tda._reso + assert result.isna().all() + + result = nat + tda + assert isinstance(result, DatetimeArray) + assert result._reso == tda._reso + assert result.isna().all() + + def test_add_pdnat(self, tda): + result = tda + pd.NaT + assert isinstance(result, TimedeltaArray) + assert result._reso == tda._reso + assert result.isna().all() + + result = pd.NaT + tda + assert isinstance(result, TimedeltaArray) + assert result._reso == tda._reso + assert result.isna().all() + class TestTimedeltaArray: @pytest.mark.parametrize("dtype", [int, np.int32, np.int64, "uint32", "uint64"]) From be148287db5a2c4e95c09df689d46dbb4a5f35f4 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 27 Jun 2022 18:56:45 -0700 Subject: [PATCH 2/3] mypy fixup --- pandas/core/arrays/datetimelike.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index cb857cc1aa6be..6e3d4e3d22a61 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -1107,12 +1107,7 @@ def _add_datetimelike_scalar(self, other) -> DatetimeArray: return DatetimeArray._simple_new(result, dtype=result.dtype) i8 = self.asi8 - # Incompatible types in assignment (expression has type "ndarray[Any, - # dtype[signedinteger[_64Bit]]]", variable has type - # "ndarray[Any, dtype[datetime64]]") - result = checked_add_with_arr( # type: ignore[assignment] - i8, other.value, arr_mask=self._isnan - ) + result = checked_add_with_arr(i8, other.value, arr_mask=self._isnan) dtype = DatetimeTZDtype(tz=other.tz) if other.tz else DT64NS_DTYPE return DatetimeArray(result, dtype=dtype, freq=self.freq) @@ -1276,6 +1271,7 @@ def _add_nat(self): raise TypeError( f"Cannot add {type(self).__name__} and {type(NaT).__name__}" ) + self = cast("TimedeltaArray | DatetimeArray", self) # GH#19124 pd.NaT is treated like a timedelta for both timedelta # and datetime dtypes @@ -1910,7 +1906,9 @@ def _reso(self) -> int: @cache_readonly def _unit(self) -> str: # e.g. "ns", "us", "ms" - return dtype_to_unit(self.dtype) + # error: Argument 1 to "dtype_to_unit" has incompatible type + # "ExtensionDtype"; expected "Union[DatetimeTZDtype, dtype[Any]]" + return dtype_to_unit(self.dtype) # type: ignore[arg-type] def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs): if ( From 499cb6f520cd367c570151c1d89e0ad833efe861 Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 28 Jun 2022 09:50:08 -0700 Subject: [PATCH 3/3] use datetime_data --- pandas/core/arrays/datetimelike.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 6e3d4e3d22a61..5e65b124ae0f4 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -2127,4 +2127,4 @@ def dtype_to_unit(dtype: DatetimeTZDtype | np.dtype) -> str: """ if isinstance(dtype, DatetimeTZDtype): return dtype.unit - return str(dtype).split("[")[-1][:-1] + return np.datetime_data(dtype)[0]