Skip to content

Commit 2eca7e1

Browse files
authored
ENH: TDA+datetime_scalar support non-nano (#47675)
1 parent 669f21f commit 2eca7e1

File tree

3 files changed

+50
-10
lines changed

3 files changed

+50
-10
lines changed

pandas/core/arrays/datetimelike.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@
7373
from pandas.util._exceptions import find_stack_level
7474

7575
from pandas.core.dtypes.common import (
76-
DT64NS_DTYPE,
7776
is_all_strings,
7877
is_categorical_dtype,
7978
is_datetime64_any_dtype,
@@ -1103,6 +1102,7 @@ def _add_datetimelike_scalar(self, other) -> DatetimeArray:
11031102
self = cast("TimedeltaArray", self)
11041103

11051104
from pandas.core.arrays import DatetimeArray
1105+
from pandas.core.arrays.datetimes import tz_to_dtype
11061106

11071107
assert other is not NaT
11081108
other = Timestamp(other)
@@ -1113,10 +1113,17 @@ def _add_datetimelike_scalar(self, other) -> DatetimeArray:
11131113
# Preserve our resolution
11141114
return DatetimeArray._simple_new(result, dtype=result.dtype)
11151115

1116+
if self._reso != other._reso:
1117+
raise NotImplementedError(
1118+
"Addition between TimedeltaArray and Timestamp with mis-matched "
1119+
"resolutions is not yet supported."
1120+
)
1121+
11161122
i8 = self.asi8
11171123
result = checked_add_with_arr(i8, other.value, arr_mask=self._isnan)
1118-
dtype = DatetimeTZDtype(tz=other.tz) if other.tz else DT64NS_DTYPE
1119-
return DatetimeArray(result, dtype=dtype, freq=self.freq)
1124+
dtype = tz_to_dtype(tz=other.tz, unit=self._unit)
1125+
res_values = result.view(f"M8[{self._unit}]")
1126+
return DatetimeArray._simple_new(res_values, dtype=dtype, freq=self.freq)
11201127

11211128
@final
11221129
def _add_datetime_arraylike(self, other) -> DatetimeArray:

pandas/core/arrays/datetimes.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -91,22 +91,23 @@
9191
_midnight = time(0, 0)
9292

9393

94-
def tz_to_dtype(tz):
94+
def tz_to_dtype(tz: tzinfo | None, unit: str = "ns"):
9595
"""
9696
Return a datetime64[ns] dtype appropriate for the given timezone.
9797
9898
Parameters
9999
----------
100100
tz : tzinfo or None
101+
unit : str, default "ns"
101102
102103
Returns
103104
-------
104105
np.dtype or Datetime64TZDType
105106
"""
106107
if tz is None:
107-
return DT64NS_DTYPE
108+
return np.dtype(f"M8[{unit}]")
108109
else:
109-
return DatetimeTZDtype(tz=tz)
110+
return DatetimeTZDtype(tz=tz, unit=unit)
110111

111112

112113
def _field_accessor(name: str, field: str, docstring=None):
@@ -800,7 +801,7 @@ def tz_convert(self, tz) -> DatetimeArray:
800801
)
801802

802803
# No conversion since timestamps are all UTC to begin with
803-
dtype = tz_to_dtype(tz)
804+
dtype = tz_to_dtype(tz, unit=self._unit)
804805
return self._simple_new(self._ndarray, dtype=dtype, freq=self.freq)
805806

806807
@dtl.ravel_compat
@@ -965,10 +966,14 @@ def tz_localize(self, tz, ambiguous="raise", nonexistent="raise") -> DatetimeArr
965966
# Convert to UTC
966967

967968
new_dates = tzconversion.tz_localize_to_utc(
968-
self.asi8, tz, ambiguous=ambiguous, nonexistent=nonexistent
969+
self.asi8,
970+
tz,
971+
ambiguous=ambiguous,
972+
nonexistent=nonexistent,
973+
reso=self._reso,
969974
)
970-
new_dates = new_dates.view(DT64NS_DTYPE)
971-
dtype = tz_to_dtype(tz)
975+
new_dates = new_dates.view(f"M8[{self._unit}]")
976+
dtype = tz_to_dtype(tz, unit=self._unit)
972977

973978
freq = None
974979
if timezones.is_utc(tz) or (len(self) == 1 and not isna(new_dates[0])):

pandas/tests/arrays/test_timedeltas.py

+28
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,34 @@ def test_add_pdnat(self, tda):
9292
assert result._reso == tda._reso
9393
assert result.isna().all()
9494

95+
# TODO: 2022-07-11 this is the only test that gets to DTA.tz_convert
96+
# or tz_localize with non-nano; implement tests specific to that.
97+
def test_add_datetimelike_scalar(self, tda, tz_naive_fixture):
98+
ts = pd.Timestamp("2016-01-01", tz=tz_naive_fixture)
99+
100+
msg = "with mis-matched resolutions"
101+
with pytest.raises(NotImplementedError, match=msg):
102+
# mismatched reso -> check that we don't give an incorrect result
103+
tda + ts
104+
with pytest.raises(NotImplementedError, match=msg):
105+
# mismatched reso -> check that we don't give an incorrect result
106+
ts + tda
107+
108+
ts = ts._as_unit(tda._unit)
109+
110+
exp_values = tda._ndarray + ts.asm8
111+
expected = (
112+
DatetimeArray._simple_new(exp_values, dtype=exp_values.dtype)
113+
.tz_localize("UTC")
114+
.tz_convert(ts.tz)
115+
)
116+
117+
result = tda + ts
118+
tm.assert_extension_array_equal(result, expected)
119+
120+
result = ts + tda
121+
tm.assert_extension_array_equal(result, expected)
122+
95123
def test_mul_scalar(self, tda):
96124
other = 2
97125
result = tda * other

0 commit comments

Comments
 (0)