Skip to content

Commit 0dadc71

Browse files
authored
API: Change DTA/TDA add/sub to match numpy (#48748)
* API: Change Timestamp/Timedelta arithmetic to match numpy * fix interval test * DTA/TDA add/sub upcast instead of downcast * test for no-downcasting * add test
1 parent 07023aa commit 0dadc71

File tree

4 files changed

+29
-38
lines changed

4 files changed

+29
-38
lines changed

pandas/core/arrays/datetimelike.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -1134,13 +1134,12 @@ def _add_datetimelike_scalar(self, other) -> DatetimeArray:
11341134
return DatetimeArray._simple_new(result, dtype=result.dtype)
11351135

11361136
if self._reso != other._reso:
1137-
# Just as with Timestamp/Timedelta, we cast to the lower resolution
1138-
# so long as doing so is lossless.
1137+
# Just as with Timestamp/Timedelta, we cast to the higher resolution
11391138
if self._reso < other._reso:
1140-
other = other._as_unit(self._unit, round_ok=False)
1141-
else:
11421139
unit = npy_unit_to_abbrev(other._reso)
11431140
self = self._as_unit(unit)
1141+
else:
1142+
other = other._as_unit(self._unit)
11441143

11451144
i8 = self.asi8
11461145
result = checked_add_with_arr(i8, other.value, arr_mask=self._isnan)
@@ -1296,12 +1295,11 @@ def _add_timedelta_arraylike(
12961295
self = cast("DatetimeArray | TimedeltaArray", self)
12971296

12981297
if self._reso != other._reso:
1299-
# Just as with Timestamp/Timedelta, we cast to the lower resolution
1300-
# so long as doing so is lossless.
1298+
# Just as with Timestamp/Timedelta, we cast to the higher resolution
13011299
if self._reso < other._reso:
1302-
other = other._as_unit(self._unit)
1303-
else:
13041300
self = self._as_unit(other._unit)
1301+
else:
1302+
other = other._as_unit(self._unit)
13051303

13061304
self_i8 = self.asi8
13071305
other_i8 = other.asi8
@@ -2039,7 +2037,7 @@ def _unit(self) -> str:
20392037

20402038
def _as_unit(self: TimelikeOpsT, unit: str) -> TimelikeOpsT:
20412039
dtype = np.dtype(f"{self.dtype.kind}8[{unit}]")
2042-
new_values = astype_overflowsafe(self._ndarray, dtype, round_ok=False)
2040+
new_values = astype_overflowsafe(self._ndarray, dtype, round_ok=True)
20432041

20442042
if isinstance(self.dtype, np.dtype):
20452043
new_dtype = new_values.dtype

pandas/tests/arrays/test_datetimes.py

+11
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,17 @@ def test_compare_mismatched_resolutions(self, comparison_op):
207207
np_res = op(left._ndarray, right._ndarray)
208208
tm.assert_numpy_array_equal(np_res[1:], ~expected[1:])
209209

210+
def test_add_mismatched_reso_doesnt_downcast(self):
211+
# https://github.com/pandas-dev/pandas/pull/48748#issuecomment-1260181008
212+
td = pd.Timedelta(microseconds=1)
213+
dti = pd.date_range("2016-01-01", periods=3) - td
214+
dta = dti._data._as_unit("us")
215+
216+
res = dta + td._as_unit("us")
217+
# even though the result is an even number of days
218+
# (so we _could_ downcast to unit="s"), we do not.
219+
assert res._unit == "us"
220+
210221

211222
class TestDatetimeArrayComparisons:
212223
# TODO: merge this into tests/arithmetic/test_datetime64 once it is

pandas/tests/arrays/test_timedeltas.py

+4-29
Original file line numberDiff line numberDiff line change
@@ -104,22 +104,13 @@ def test_add_pdnat(self, tda):
104104
def test_add_datetimelike_scalar(self, tda, tz_naive_fixture):
105105
ts = pd.Timestamp("2016-01-01", tz=tz_naive_fixture)
106106

107-
expected = tda + ts._as_unit(tda._unit)
107+
expected = tda._as_unit("ns") + ts
108108
res = tda + ts
109109
tm.assert_extension_array_equal(res, expected)
110110
res = ts + tda
111111
tm.assert_extension_array_equal(res, expected)
112112

113-
ts += Timedelta(1) # so we can't cast losslessly
114-
msg = "Cannot losslessly convert units"
115-
with pytest.raises(ValueError, match=msg):
116-
# mismatched reso -> check that we don't give an incorrect result
117-
tda + ts
118-
with pytest.raises(ValueError, match=msg):
119-
# mismatched reso -> check that we don't give an incorrect result
120-
ts + tda
121-
122-
ts = ts._as_unit(tda._unit)
113+
ts += Timedelta(1) # case where we can't cast losslessly
123114

124115
exp_values = tda._ndarray + ts.asm8
125116
expected = (
@@ -185,35 +176,19 @@ def test_add_timedeltaarraylike(self, tda):
185176
# TODO(2.0): just do `tda_nano = tda.astype("m8[ns]")`
186177
tda_nano = TimedeltaArray(tda._ndarray.astype("m8[ns]"))
187178

188-
msg = "mis-matched resolutions is not yet supported"
189-
expected = tda * 2
179+
expected = tda_nano * 2
190180
res = tda_nano + tda
191181
tm.assert_extension_array_equal(res, expected)
192182
res = tda + tda_nano
193183
tm.assert_extension_array_equal(res, expected)
194184

195-
expected = tda * 0
185+
expected = tda_nano * 0
196186
res = tda - tda_nano
197187
tm.assert_extension_array_equal(res, expected)
198188

199189
res = tda_nano - tda
200190
tm.assert_extension_array_equal(res, expected)
201191

202-
tda_nano[:] = np.timedelta64(1, "ns") # can't round losslessly
203-
msg = "Cannot losslessly cast '-?1 ns' to"
204-
with pytest.raises(ValueError, match=msg):
205-
tda_nano + tda
206-
with pytest.raises(ValueError, match=msg):
207-
tda + tda_nano
208-
with pytest.raises(ValueError, match=msg):
209-
tda - tda_nano
210-
with pytest.raises(ValueError, match=msg):
211-
tda_nano - tda
212-
213-
result = tda_nano + tda_nano
214-
expected = tda_nano * 2
215-
tm.assert_extension_array_equal(result, expected)
216-
217192

218193
class TestTimedeltaArray:
219194
@pytest.mark.parametrize("dtype", [int, np.int32, np.int64, "uint32", "uint64"])

pandas/tests/scalar/timestamp/test_timestamp.py

+7
Original file line numberDiff line numberDiff line change
@@ -1030,6 +1030,13 @@ def test_sub_timedeltalike_mismatched_reso(self, ts_tz):
10301030
assert res == exp
10311031
assert res._reso == max(ts._reso, other._reso)
10321032

1033+
def test_addition_doesnt_downcast_reso(self):
1034+
# https://github.com/pandas-dev/pandas/pull/48748#pullrequestreview-1122635413
1035+
ts = Timestamp(year=2022, month=1, day=1, microsecond=999999)._as_unit("us")
1036+
td = Timedelta(microseconds=1)._as_unit("us")
1037+
res = ts + td
1038+
assert res._reso == ts._reso
1039+
10331040
def test_sub_timedelta64_mismatched_reso(self, ts_tz):
10341041
ts = ts_tz
10351042

0 commit comments

Comments
 (0)