Skip to content

Commit 5620f0e

Browse files
authored
API: preserve reso in Timelta(td64_obj) (pandas-dev#48910)
* BUG: Timedelta.__new__ * remove assertion * GH refs * API: Timedelta(td64_obj) retain resolution * revert extraneous * remove debugging variable
1 parent 4ceb5d9 commit 5620f0e

File tree

8 files changed

+190
-85
lines changed

8 files changed

+190
-85
lines changed

pandas/_libs/tslibs/timedeltas.pyx

+42-51
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@ from pandas._libs.tslibs.conversion cimport (
3838
cast_from_unit,
3939
precision_from_unit,
4040
)
41-
from pandas._libs.tslibs.dtypes cimport npy_unit_to_abbrev
41+
from pandas._libs.tslibs.dtypes cimport (
42+
get_supported_reso,
43+
npy_unit_to_abbrev,
44+
)
4245
from pandas._libs.tslibs.nattype cimport (
4346
NPY_NAT,
4447
c_NaT as NaT,
@@ -939,6 +942,7 @@ cdef _timedelta_from_value_and_reso(int64_t value, NPY_DATETIMEUNIT reso):
939942
cdef:
940943
_Timedelta td_base
941944

945+
assert value != NPY_NAT
942946
# For millisecond and second resos, we cannot actually pass int(value) because
943947
# many cases would fall outside of the pytimedelta implementation bounds.
944948
# We pass 0 instead, and override seconds, microseconds, days.
@@ -1704,10 +1708,27 @@ class Timedelta(_Timedelta):
17041708
elif PyDelta_Check(value):
17051709
value = convert_to_timedelta64(value, 'ns')
17061710
elif is_timedelta64_object(value):
1707-
if get_timedelta64_value(value) == NPY_NAT:
1711+
# Retain the resolution if possible, otherwise cast to the nearest
1712+
# supported resolution.
1713+
new_value = get_timedelta64_value(value)
1714+
if new_value == NPY_NAT:
17081715
# i.e. np.timedelta64("NaT")
17091716
return NaT
1710-
value = ensure_td64ns(value)
1717+
1718+
reso = get_datetime64_unit(value)
1719+
new_reso = get_supported_reso(reso)
1720+
if reso != NPY_DATETIMEUNIT.NPY_FR_GENERIC:
1721+
try:
1722+
new_value = convert_reso(
1723+
get_timedelta64_value(value),
1724+
reso,
1725+
new_reso,
1726+
round_ok=True,
1727+
)
1728+
except (OverflowError, OutOfBoundsDatetime) as err:
1729+
raise OutOfBoundsTimedelta(value) from err
1730+
return cls._from_value_and_reso(new_value, reso=new_reso)
1731+
17111732
elif is_tick_object(value):
17121733
value = np.timedelta64(value.nanos, 'ns')
17131734
elif is_integer_object(value) or is_float_object(value):
@@ -1917,9 +1938,15 @@ class Timedelta(_Timedelta):
19171938

19181939
if other.dtype.kind == 'm':
19191940
# also timedelta-like
1920-
if self._reso != NPY_FR_ns:
1921-
raise NotImplementedError
1922-
return _broadcast_floordiv_td64(self.value, other, _floordiv)
1941+
# TODO: could suppress
1942+
# RuntimeWarning: invalid value encountered in floor_divide
1943+
result = self.asm8 // other
1944+
mask = other.view("i8") == NPY_NAT
1945+
if mask.any():
1946+
# We differ from numpy here
1947+
result = result.astype("f8")
1948+
result[mask] = np.nan
1949+
return result
19231950

19241951
elif other.dtype.kind in ['i', 'u', 'f']:
19251952
if other.ndim == 0:
@@ -1951,9 +1978,15 @@ class Timedelta(_Timedelta):
19511978

19521979
if other.dtype.kind == 'm':
19531980
# also timedelta-like
1954-
if self._reso != NPY_FR_ns:
1955-
raise NotImplementedError
1956-
return _broadcast_floordiv_td64(self.value, other, _rfloordiv)
1981+
# TODO: could suppress
1982+
# RuntimeWarning: invalid value encountered in floor_divide
1983+
result = other // self.asm8
1984+
mask = other.view("i8") == NPY_NAT
1985+
if mask.any():
1986+
# We differ from numpy here
1987+
result = result.astype("f8")
1988+
result[mask] = np.nan
1989+
return result
19571990

19581991
# Includes integer array // Timedelta, disallowed in GH#19761
19591992
raise TypeError(f'Invalid dtype {other.dtype} for __floordiv__')
@@ -2003,45 +2036,3 @@ cdef bint _should_cast_to_timedelta(object obj):
20032036
return (
20042037
is_any_td_scalar(obj) or obj is None or obj is NaT or isinstance(obj, str)
20052038
)
2006-
2007-
2008-
cdef _floordiv(int64_t value, right):
2009-
return value // right
2010-
2011-
2012-
cdef _rfloordiv(int64_t value, right):
2013-
# analogous to referencing operator.div, but there is no operator.rfloordiv
2014-
return right // value
2015-
2016-
2017-
cdef _broadcast_floordiv_td64(
2018-
int64_t value,
2019-
ndarray other,
2020-
object (*operation)(int64_t value, object right)
2021-
):
2022-
"""
2023-
Boilerplate code shared by Timedelta.__floordiv__ and
2024-
Timedelta.__rfloordiv__ because np.timedelta64 does not implement these.
2025-
2026-
Parameters
2027-
----------
2028-
value : int64_t; `self.value` from a Timedelta object
2029-
other : ndarray[timedelta64[ns]]
2030-
operation : function, either _floordiv or _rfloordiv
2031-
2032-
Returns
2033-
-------
2034-
result : varies based on `other`
2035-
"""
2036-
# assumes other.dtype.kind == 'm', i.e. other is timedelta-like
2037-
# assumes other.ndim != 0
2038-
2039-
# We need to watch out for np.timedelta64('NaT').
2040-
mask = other.view('i8') == NPY_NAT
2041-
2042-
res = operation(value, other.astype('m8[ns]', copy=False).astype('i8'))
2043-
2044-
if mask.any():
2045-
res = res.astype('f8')
2046-
res[mask] = np.nan
2047-
return res

pandas/core/arrays/numpy_.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
import numpy as np
44

55
from pandas._libs import lib
6+
from pandas._libs.tslibs import (
7+
get_unit_from_dtype,
8+
is_supported_unit,
9+
)
610
from pandas._typing import (
711
AxisInt,
812
Dtype,
@@ -439,10 +443,12 @@ def _cmp_method(self, other, op):
439443
def _wrap_ndarray_result(self, result: np.ndarray):
440444
# If we have timedelta64[ns] result, return a TimedeltaArray instead
441445
# of a PandasArray
442-
if result.dtype == "timedelta64[ns]":
446+
if result.dtype.kind == "m" and is_supported_unit(
447+
get_unit_from_dtype(result.dtype)
448+
):
443449
from pandas.core.arrays import TimedeltaArray
444450

445-
return TimedeltaArray._simple_new(result)
451+
return TimedeltaArray._simple_new(result, dtype=result.dtype)
446452
return type(self)(result)
447453

448454
# ------------------------------------------------------------------------

pandas/core/arrays/timedeltas.py

+4
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,10 @@ def _unbox_scalar(self, value, setitem: bool = False) -> np.timedelta64:
284284
if not isinstance(value, self._scalar_type) and value is not NaT:
285285
raise ValueError("'value' should be a Timedelta.")
286286
self._check_compatible_with(value, setitem=setitem)
287+
if value is NaT:
288+
return np.timedelta64(value.value, "ns")
289+
else:
290+
return value._as_unit(self._unit).asm8
287291
return np.timedelta64(value.value, "ns")
288292

289293
def _scalar_from_string(self, value) -> Timedelta | NaTType:

pandas/core/window/ewm.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ def _calculate_deltas(
134134
_times = np.asarray(
135135
times.view(np.int64), dtype=np.float64 # type: ignore[union-attr]
136136
)
137-
_halflife = float(Timedelta(halflife).value)
137+
# TODO: generalize to non-nano?
138+
_halflife = float(Timedelta(halflife)._as_unit("ns").value)
138139
return np.diff(_times) / _halflife
139140

140141

pandas/tests/arithmetic/test_numeric.py

+13
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,11 @@ def test_numeric_arr_mul_tdscalar(self, scalar_td, numeric_idx, box_with_array):
204204
box = box_with_array
205205
index = numeric_idx
206206
expected = TimedeltaIndex([Timedelta(days=n) for n in range(len(index))])
207+
if isinstance(scalar_td, np.timedelta64) and box not in [Index, Series]:
208+
# TODO(2.0): once TDA.astype converts to m8, just do expected.astype
209+
tda = expected._data
210+
dtype = scalar_td.dtype
211+
expected = type(tda)._simple_new(tda._ndarray.astype(dtype), dtype=dtype)
207212

208213
index = tm.box_expected(index, box)
209214
expected = tm.box_expected(expected, box)
@@ -249,6 +254,14 @@ def test_numeric_arr_rdiv_tdscalar(self, three_days, numeric_idx, box_with_array
249254
index = numeric_idx[1:3]
250255

251256
expected = TimedeltaIndex(["3 Days", "36 Hours"])
257+
if isinstance(three_days, np.timedelta64) and box not in [Index, Series]:
258+
# TODO(2.0): just use expected.astype
259+
tda = expected._data
260+
dtype = three_days.dtype
261+
if dtype < np.dtype("m8[s]"):
262+
# i.e. resolution is lower -> use lowest supported resolution
263+
dtype = np.dtype("m8[s]")
264+
expected = type(tda)._simple_new(tda._ndarray.astype(dtype), dtype=dtype)
252265

253266
index = tm.box_expected(index, box)
254267
expected = tm.box_expected(expected, box)

pandas/tests/dtypes/cast/test_promote.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -463,14 +463,23 @@ def test_maybe_promote_timedelta64_with_any(timedelta64_dtype, any_numpy_dtype_r
463463
[pd.Timedelta(days=1), np.timedelta64(24, "h"), datetime.timedelta(1)],
464464
ids=["pd.Timedelta", "np.timedelta64", "datetime.timedelta"],
465465
)
466-
def test_maybe_promote_any_with_timedelta64(any_numpy_dtype_reduced, fill_value):
466+
def test_maybe_promote_any_with_timedelta64(
467+
any_numpy_dtype_reduced, fill_value, request
468+
):
467469
dtype = np.dtype(any_numpy_dtype_reduced)
468470

469471
# filling anything but timedelta with timedelta casts to object
470472
if is_timedelta64_dtype(dtype):
471473
expected_dtype = dtype
472474
# for timedelta dtypes, scalar values get cast to pd.Timedelta.value
473475
exp_val_for_scalar = pd.Timedelta(fill_value).to_timedelta64()
476+
477+
if isinstance(fill_value, np.timedelta64) and fill_value.dtype != "m8[ns]":
478+
mark = pytest.mark.xfail(
479+
reason="maybe_promote not yet updated to handle non-nano "
480+
"Timedelta scalar"
481+
)
482+
request.node.add_marker(mark)
474483
else:
475484
expected_dtype = np.dtype(object)
476485
exp_val_for_scalar = fill_value

pandas/tests/frame/test_constructors.py

+45-14
Original file line numberDiff line numberDiff line change
@@ -856,16 +856,31 @@ def create_data(constructor):
856856
tm.assert_frame_equal(result_datetime, expected)
857857
tm.assert_frame_equal(result_Timestamp, expected)
858858

859-
def test_constructor_dict_timedelta64_index(self):
859+
@pytest.mark.parametrize(
860+
"klass",
861+
[
862+
pytest.param(
863+
np.timedelta64,
864+
marks=pytest.mark.xfail(
865+
reason="hash mismatch (GH#44504) causes lib.fast_multiget "
866+
"to mess up on dict lookups with equal Timedeltas with "
867+
"mismatched resos"
868+
),
869+
),
870+
timedelta,
871+
Timedelta,
872+
],
873+
)
874+
def test_constructor_dict_timedelta64_index(self, klass):
860875
# GH 10160
861876
td_as_int = [1, 2, 3, 4]
862877

863-
def create_data(constructor):
864-
return {i: {constructor(s): 2 * i} for i, s in enumerate(td_as_int)}
878+
if klass is timedelta:
879+
constructor = lambda x: timedelta(days=x)
880+
else:
881+
constructor = lambda x: klass(x, "D")
865882

866-
data_timedelta64 = create_data(lambda x: np.timedelta64(x, "D"))
867-
data_timedelta = create_data(lambda x: timedelta(days=x))
868-
data_Timedelta = create_data(lambda x: Timedelta(x, "D"))
883+
data = {i: {constructor(s): 2 * i} for i, s in enumerate(td_as_int)}
869884

870885
expected = DataFrame(
871886
[
@@ -877,12 +892,8 @@ def create_data(constructor):
877892
index=[Timedelta(td, "D") for td in td_as_int],
878893
)
879894

880-
result_timedelta64 = DataFrame(data_timedelta64)
881-
result_timedelta = DataFrame(data_timedelta)
882-
result_Timedelta = DataFrame(data_Timedelta)
883-
tm.assert_frame_equal(result_timedelta64, expected)
884-
tm.assert_frame_equal(result_timedelta, expected)
885-
tm.assert_frame_equal(result_Timedelta, expected)
895+
result = DataFrame(data)
896+
tm.assert_frame_equal(result, expected)
886897

887898
def test_constructor_period_dict(self):
888899
# PeriodIndex
@@ -3111,14 +3122,34 @@ def test_from_out_of_bounds_datetime(self, constructor, cls):
31113122

31123123
assert type(get1(result)) is cls
31133124

3125+
@pytest.mark.xfail(
3126+
reason="TimedeltaArray constructor has been updated to cast td64 to non-nano, "
3127+
"but TimedeltaArray._from_sequence has not"
3128+
)
31143129
@pytest.mark.parametrize("cls", [timedelta, np.timedelta64])
3115-
def test_from_out_of_bounds_timedelta(self, constructor, cls):
3130+
def test_from_out_of_bounds_ns_timedelta(self, constructor, cls):
3131+
# scalar that won't fit in nanosecond td64, but will fit in microsecond
31163132
scalar = datetime(9999, 1, 1) - datetime(1970, 1, 1)
3133+
exp_dtype = "m8[us]" # smallest reso that fits
31173134
if cls is np.timedelta64:
31183135
scalar = np.timedelta64(scalar, "D")
3136+
exp_dtype = "m8[s]" # closest reso to input
31193137
result = constructor(scalar)
31203138

3121-
assert type(get1(result)) is cls
3139+
item = get1(result)
3140+
dtype = result.dtype if isinstance(result, Series) else result.dtypes.iloc[0]
3141+
3142+
assert type(item) is Timedelta
3143+
assert item.asm8.dtype == exp_dtype
3144+
assert dtype == exp_dtype
3145+
3146+
def test_out_of_s_bounds_timedelta64(self, constructor):
3147+
scalar = np.timedelta64(np.iinfo(np.int64).max, "D")
3148+
result = constructor(scalar)
3149+
item = get1(result)
3150+
assert type(item) is np.timedelta64
3151+
dtype = result.dtype if isinstance(result, Series) else result.dtypes.iloc[0]
3152+
assert dtype == object
31223153

31233154
def test_tzaware_data_tznaive_dtype(self, constructor):
31243155
tz = "US/Eastern"

0 commit comments

Comments
 (0)