Skip to content

Commit 3d2cbfe

Browse files
jbrockmendel authored and yehoshuadimarsky committed
ENH: TimedeltaArray add/sub with NaT preserve reso (pandas-dev#47522)
* ENH: TimedeltaArray add/sub with NaT preserve reso
* mypy fixup
* use datetime_data
1 parent 3f97fbd commit 3d2cbfe

File tree

2 files changed

+73
-26
lines changed

2 files changed

+73
-26
lines changed

pandas/core/arrays/datetimelike.py

+33-10
Original file line numberDiff line numberDiff line change
@@ -1087,7 +1087,7 @@ def _cmp_method(self, other, op):
10871087
__rdivmod__ = make_invalid_op("__rdivmod__")
10881088

10891089
@final
1090-
def _add_datetimelike_scalar(self, other):
1090+
def _add_datetimelike_scalar(self, other) -> DatetimeArray:
10911091
if not is_timedelta64_dtype(self.dtype):
10921092
raise TypeError(
10931093
f"cannot add {type(self).__name__} and {type(other).__name__}"
@@ -1102,16 +1102,12 @@ def _add_datetimelike_scalar(self, other):
11021102
if other is NaT:
11031103
# In this case we specifically interpret NaT as a datetime, not
11041104
# the timedelta interpretation we would get by returning self + NaT
1105-
result = self.asi8.view("m8[ms]") + NaT.to_datetime64()
1106-
return DatetimeArray(result)
1105+
result = self._ndarray + NaT.to_datetime64().astype(f"M8[{self._unit}]")
1106+
# Preserve our resolution
1107+
return DatetimeArray._simple_new(result, dtype=result.dtype)
11071108

11081109
i8 = self.asi8
1109-
# Incompatible types in assignment (expression has type "ndarray[Any,
1110-
# dtype[signedinteger[_64Bit]]]", variable has type
1111-
# "ndarray[Any, dtype[datetime64]]")
1112-
result = checked_add_with_arr( # type: ignore[assignment]
1113-
i8, other.value, arr_mask=self._isnan
1114-
)
1110+
result = checked_add_with_arr(i8, other.value, arr_mask=self._isnan)
11151111
dtype = DatetimeTZDtype(tz=other.tz) if other.tz else DT64NS_DTYPE
11161112
return DatetimeArray(result, dtype=dtype, freq=self.freq)
11171113

@@ -1275,12 +1271,14 @@ def _add_nat(self):
12751271
raise TypeError(
12761272
f"Cannot add {type(self).__name__} and {type(NaT).__name__}"
12771273
)
1274+
self = cast("TimedeltaArray | DatetimeArray", self)
12781275

12791276
# GH#19124 pd.NaT is treated like a timedelta for both timedelta
12801277
# and datetime dtypes
12811278
result = np.empty(self.shape, dtype=np.int64)
12821279
result.fill(iNaT)
1283-
return type(self)(result, dtype=self.dtype, freq=None)
1280+
result = result.view(self._ndarray.dtype) # preserve reso
1281+
return type(self)._simple_new(result, dtype=self.dtype, freq=None)
12841282

12851283
@final
12861284
def _sub_nat(self):
@@ -1905,6 +1903,13 @@ class TimelikeOps(DatetimeLikeArrayMixin):
19051903
def _reso(self) -> int:
19061904
return get_unit_from_dtype(self._ndarray.dtype)
19071905

1906+
@cache_readonly
1907+
def _unit(self) -> str:
1908+
# Unit string for the resolution of this array's dtype, e.g. "ns", "us", "ms"
1909+
# error: Argument 1 to "dtype_to_unit" has incompatible type
1910+
# "ExtensionDtype"; expected "Union[DatetimeTZDtype, dtype[Any]]"
1911+
return dtype_to_unit(self.dtype) # type: ignore[arg-type]
1912+
19081913
def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
19091914
if (
19101915
ufunc in [np.isnan, np.isinf, np.isfinite]
@@ -2105,3 +2110,21 @@ def maybe_infer_freq(freq):
21052110
freq_infer = True
21062111
freq = None
21072112
return freq, freq_infer
2113+
2114+
2115+
def dtype_to_unit(dtype: DatetimeTZDtype | np.dtype) -> str:
2116+
"""
2117+
Return the unit str corresponding to the dtype's resolution.
2118+
2119+
Parameters
2120+
----------
2121+
dtype : DatetimeTZDtype or np.dtype
2122+
If np.dtype, we assume it is a datetime64 or timedelta64 dtype.
2123+
2124+
Returns
2125+
-------
2126+
str
2127+
"""
2128+
if isinstance(dtype, DatetimeTZDtype):
2129+
return dtype.unit
2130+
return np.datetime_data(dtype)[0]

pandas/tests/arrays/test_timedeltas.py

+40-16
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
import pandas as pd
77
from pandas import Timedelta
88
import pandas._testing as tm
9-
from pandas.core.arrays import TimedeltaArray
9+
from pandas.core.arrays import (
10+
DatetimeArray,
11+
TimedeltaArray,
12+
)
1013

1114

1215
class TestNonNano:
@@ -25,6 +28,11 @@ def reso(self, unit):
2528
else:
2629
raise NotImplementedError(unit)
2730

31+
@pytest.fixture
32+
def tda(self, unit):
33+
arr = np.arange(5, dtype=np.int64).view(f"m8[{unit}]")  # 0..4 viewed as timedelta64 in the requested unit
34+
return TimedeltaArray._simple_new(arr, dtype=arr.dtype)
35+
2836
def test_non_nano(self, unit, reso):
2937
arr = np.arange(5, dtype=np.int64).view(f"m8[{unit}]")
3038
tda = TimedeltaArray._simple_new(arr, dtype=arr.dtype)
@@ -33,39 +41,55 @@ def test_non_nano(self, unit, reso):
3341
assert tda[0]._reso == reso
3442

3543
@pytest.mark.parametrize("field", TimedeltaArray._field_ops)
36-
def test_fields(self, unit, field):
37-
arr = np.arange(5, dtype=np.int64).view(f"m8[{unit}]")
38-
tda = TimedeltaArray._simple_new(arr, dtype=arr.dtype)
39-
40-
as_nano = arr.astype("m8[ns]")
44+
def test_fields(self, tda, field):
45+
as_nano = tda._ndarray.astype("m8[ns]")
4146
tda_nano = TimedeltaArray._simple_new(as_nano, dtype=as_nano.dtype)
4247

4348
result = getattr(tda, field)
4449
expected = getattr(tda_nano, field)
4550
tm.assert_numpy_array_equal(result, expected)
4651

47-
def test_to_pytimedelta(self, unit):
48-
arr = np.arange(5, dtype=np.int64).view(f"m8[{unit}]")
49-
tda = TimedeltaArray._simple_new(arr, dtype=arr.dtype)
50-
51-
as_nano = arr.astype("m8[ns]")
52+
def test_to_pytimedelta(self, tda):
53+
as_nano = tda._ndarray.astype("m8[ns]")
5254
tda_nano = TimedeltaArray._simple_new(as_nano, dtype=as_nano.dtype)
5355

5456
result = tda.to_pytimedelta()
5557
expected = tda_nano.to_pytimedelta()
5658
tm.assert_numpy_array_equal(result, expected)
5759

58-
def test_total_seconds(self, unit):
59-
arr = np.arange(5, dtype=np.int64).view(f"m8[{unit}]")
60-
tda = TimedeltaArray._simple_new(arr, dtype=arr.dtype)
61-
62-
as_nano = arr.astype("m8[ns]")
60+
def test_total_seconds(self, unit, tda):
61+
as_nano = tda._ndarray.astype("m8[ns]")
6362
tda_nano = TimedeltaArray._simple_new(as_nano, dtype=as_nano.dtype)
6463

6564
result = tda.total_seconds()
6665
expected = tda_nano.total_seconds()
6766
tm.assert_numpy_array_equal(result, expected)
6867

68+
@pytest.mark.parametrize(
69+
"nat", [np.datetime64("NaT", "ns"), np.datetime64("NaT", "us")]
70+
)
71+
def test_add_nat_datetimelike_scalar(self, nat, tda):
72+
result = tda + nat  # datetime64 NaT scalar: result is expected to be a DatetimeArray
73+
assert isinstance(result, DatetimeArray)
74+
assert result._reso == tda._reso  # resolution of the timedelta array is preserved
75+
assert result.isna().all()
76+
77+
result = nat + tda  # reflected operand order must behave the same
78+
assert isinstance(result, DatetimeArray)
79+
assert result._reso == tda._reso
80+
assert result.isna().all()
81+
82+
def test_add_pdnat(self, tda):
83+
result = tda + pd.NaT  # pd.NaT: result is expected to stay a TimedeltaArray
84+
assert isinstance(result, TimedeltaArray)
85+
assert result._reso == tda._reso  # resolution of the timedelta array is preserved
86+
assert result.isna().all()
87+
88+
result = pd.NaT + tda  # reflected operand order must behave the same
89+
assert isinstance(result, TimedeltaArray)
90+
assert result._reso == tda._reso
91+
assert result.isna().all()
92+
6993

7094
class TestTimedeltaArray:
7195
@pytest.mark.parametrize("dtype", [int, np.int32, np.int64, "uint32", "uint64"])

0 commit comments

Comments
 (0)