Skip to content

Commit e8a4ce2

Browse files
authored
ENH: DTA/TDA add datetimelike scalar with mismatched reso (#48669)
* ENH: DTA/TDA add datetimelike scalar with mismatched reso * mypy fixup
1 parent f490cb9 commit e8a4ce2

File tree

4 files changed

+63
-15
lines changed

4 files changed

+63
-15
lines changed

pandas/_libs/tslibs/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,15 @@
3131
"periods_per_day",
3232
"periods_per_second",
3333
"is_supported_unit",
34+
"npy_unit_to_abbrev",
3435
]
3536

3637
from pandas._libs.tslibs import dtypes
3738
from pandas._libs.tslibs.conversion import localize_pydatetime
3839
from pandas._libs.tslibs.dtypes import (
3940
Resolution,
4041
is_supported_unit,
42+
npy_unit_to_abbrev,
4143
periods_per_day,
4244
periods_per_second,
4345
)

pandas/core/arrays/datetimelike.py

+31-8
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,13 @@
3737
Resolution,
3838
Tick,
3939
Timestamp,
40+
astype_overflowsafe,
4041
delta_to_nanoseconds,
4142
get_unit_from_dtype,
4243
iNaT,
4344
ints_to_pydatetime,
4445
ints_to_pytimedelta,
46+
npy_unit_to_abbrev,
4547
to_offset,
4648
)
4749
from pandas._libs.tslibs.fields import (
@@ -1130,10 +1132,13 @@ def _add_datetimelike_scalar(self, other) -> DatetimeArray:
11301132
return DatetimeArray._simple_new(result, dtype=result.dtype)
11311133

11321134
if self._reso != other._reso:
1133-
raise NotImplementedError(
1134-
"Addition between TimedeltaArray and Timestamp with mis-matched "
1135-
"resolutions is not yet supported."
1136-
)
1135+
# Just as with Timestamp/Timedelta, we cast to the lower resolution
1136+
# so long as doing so is lossless.
1137+
if self._reso < other._reso:
1138+
other = other._as_unit(self._unit, round_ok=False)
1139+
else:
1140+
unit = npy_unit_to_abbrev(other._reso)
1141+
self = self._as_unit(unit)
11371142

11381143
i8 = self.asi8
11391144
result = checked_add_with_arr(i8, other.value, arr_mask=self._isnan)
@@ -1289,10 +1294,12 @@ def _add_timedelta_arraylike(
12891294
self = cast("DatetimeArray | TimedeltaArray", self)
12901295

12911296
if self._reso != other._reso:
1292-
raise NotImplementedError(
1293-
f"Addition of {type(self).__name__} with TimedeltaArray with "
1294-
"mis-matched resolutions is not yet supported."
1295-
)
1297+
# Just as with Timestamp/Timedelta, we cast to the lower resolution
1298+
# so long as doing so is lossless.
1299+
if self._reso < other._reso:
1300+
other = other._as_unit(self._unit)
1301+
else:
1302+
self = self._as_unit(other._unit)
12961303

12971304
self_i8 = self.asi8
12981305
other_i8 = other.asi8
@@ -2028,6 +2035,22 @@ def _unit(self) -> str:
20282035
# "ExtensionDtype"; expected "Union[DatetimeTZDtype, dtype[Any]]"
20292036
return dtype_to_unit(self.dtype) # type: ignore[arg-type]
20302037

2038+
def _as_unit(self: TimelikeOpsT, unit: str) -> TimelikeOpsT:
2039+
dtype = np.dtype(f"{self.dtype.kind}8[{unit}]")
2040+
new_values = astype_overflowsafe(self._ndarray, dtype, round_ok=False)
2041+
2042+
if isinstance(self.dtype, np.dtype):
2043+
new_dtype = new_values.dtype
2044+
else:
2045+
tz = cast("DatetimeArray", self).tz
2046+
new_dtype = DatetimeTZDtype(tz=tz, unit=unit)
2047+
2048+
# error: Unexpected keyword argument "freq" for "_simple_new" of
2049+
# "NDArrayBacked" [call-arg]
2050+
return type(self)._simple_new(
2051+
new_values, dtype=new_dtype, freq=self.freq # type: ignore[call-arg]
2052+
)
2053+
20312054
# --------------------------------------------------------------
20322055

20332056
def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):

pandas/tests/arrays/test_timedeltas.py

+29-7
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,18 @@ def test_add_pdnat(self, tda):
104104
def test_add_datetimelike_scalar(self, tda, tz_naive_fixture):
105105
ts = pd.Timestamp("2016-01-01", tz=tz_naive_fixture)
106106

107-
msg = "with mis-matched resolutions"
108-
with pytest.raises(NotImplementedError, match=msg):
107+
expected = tda + ts._as_unit(tda._unit)
108+
res = tda + ts
109+
tm.assert_extension_array_equal(res, expected)
110+
res = ts + tda
111+
tm.assert_extension_array_equal(res, expected)
112+
113+
ts += Timedelta(1) # so we can't cast losslessly
114+
msg = "Cannot losslessly convert units"
115+
with pytest.raises(ValueError, match=msg):
109116
# mismatched reso -> check that we don't give an incorrect result
110117
tda + ts
111-
with pytest.raises(NotImplementedError, match=msg):
118+
with pytest.raises(ValueError, match=msg):
112119
# mismatched reso -> check that we don't give an incorrect result
113120
ts + tda
114121

@@ -179,13 +186,28 @@ def test_add_timedeltaarraylike(self, tda):
179186
tda_nano = TimedeltaArray(tda._ndarray.astype("m8[ns]"))
180187

181188
msg = "mis-matched resolutions is not yet supported"
182-
with pytest.raises(NotImplementedError, match=msg):
189+
expected = tda * 2
190+
res = tda_nano + tda
191+
tm.assert_extension_array_equal(res, expected)
192+
res = tda + tda_nano
193+
tm.assert_extension_array_equal(res, expected)
194+
195+
expected = tda * 0
196+
res = tda - tda_nano
197+
tm.assert_extension_array_equal(res, expected)
198+
199+
res = tda_nano - tda
200+
tm.assert_extension_array_equal(res, expected)
201+
202+
tda_nano[:] = np.timedelta64(1, "ns") # can't round losslessly
203+
msg = "Cannot losslessly cast '-?1 ns' to"
204+
with pytest.raises(ValueError, match=msg):
183205
tda_nano + tda
184-
with pytest.raises(NotImplementedError, match=msg):
206+
with pytest.raises(ValueError, match=msg):
185207
tda + tda_nano
186-
with pytest.raises(NotImplementedError, match=msg):
208+
with pytest.raises(ValueError, match=msg):
187209
tda - tda_nano
188-
with pytest.raises(NotImplementedError, match=msg):
210+
with pytest.raises(ValueError, match=msg):
189211
tda_nano - tda
190212

191213
result = tda_nano + tda_nano

pandas/tests/tslibs/test_api.py

+1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def test_namespace():
5656
"periods_per_day",
5757
"periods_per_second",
5858
"is_supported_unit",
59+
"npy_unit_to_abbrev",
5960
]
6061

6162
expected = set(submodules + api)

0 commit comments

Comments
 (0)