Skip to content

Commit e05b362

Browse files
jbrockmendelyehoshuadimarsky
authored andcommitted
ENH: TDA fields support non-nano (pandas-dev#47278)
1 parent 3f51046 commit e05b362

File tree

7 files changed

+72
-47
lines changed

7 files changed

+72
-47
lines changed

pandas/_libs/tslibs/fields.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def get_date_field(
2828
def get_timedelta_field(
2929
tdindex: npt.NDArray[np.int64], # const int64_t[:]
3030
field: str,
31+
reso: int = ..., # NPY_DATETIMEUNIT
3132
) -> npt.NDArray[np.int32]: ...
3233
def isleapyear_arr(
3334
years: np.ndarray,

pandas/_libs/tslibs/fields.pyx

+10-6
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ from pandas._libs.tslibs.np_datetime cimport (
4848
get_unit_from_dtype,
4949
npy_datetimestruct,
5050
pandas_datetime_to_datetimestruct,
51+
pandas_timedelta_to_timedeltastruct,
5152
pandas_timedeltastruct,
52-
td64_to_tdstruct,
5353
)
5454

5555

@@ -491,7 +491,11 @@ def get_date_field(const int64_t[:] dtindex, str field, NPY_DATETIMEUNIT reso=NP
491491

492492
@cython.wraparound(False)
493493
@cython.boundscheck(False)
494-
def get_timedelta_field(const int64_t[:] tdindex, str field):
494+
def get_timedelta_field(
495+
const int64_t[:] tdindex,
496+
str field,
497+
NPY_DATETIMEUNIT reso=NPY_FR_ns,
498+
):
495499
"""
496500
Given a int64-based timedelta index, extract the days, hrs, sec.,
497501
field and return an array of these values.
@@ -510,7 +514,7 @@ def get_timedelta_field(const int64_t[:] tdindex, str field):
510514
out[i] = -1
511515
continue
512516

513-
td64_to_tdstruct(tdindex[i], &tds)
517+
pandas_timedelta_to_timedeltastruct(tdindex[i], reso, &tds)
514518
out[i] = tds.days
515519
return out
516520

@@ -521,7 +525,7 @@ def get_timedelta_field(const int64_t[:] tdindex, str field):
521525
out[i] = -1
522526
continue
523527

524-
td64_to_tdstruct(tdindex[i], &tds)
528+
pandas_timedelta_to_timedeltastruct(tdindex[i], reso, &tds)
525529
out[i] = tds.seconds
526530
return out
527531

@@ -532,7 +536,7 @@ def get_timedelta_field(const int64_t[:] tdindex, str field):
532536
out[i] = -1
533537
continue
534538

535-
td64_to_tdstruct(tdindex[i], &tds)
539+
pandas_timedelta_to_timedeltastruct(tdindex[i], reso, &tds)
536540
out[i] = tds.microseconds
537541
return out
538542

@@ -543,7 +547,7 @@ def get_timedelta_field(const int64_t[:] tdindex, str field):
543547
out[i] = -1
544548
continue
545549

546-
td64_to_tdstruct(tdindex[i], &tds)
550+
pandas_timedelta_to_timedeltastruct(tdindex[i], reso, &tds)
547551
out[i] = tds.nanoseconds
548552
return out
549553

pandas/_libs/tslibs/np_datetime.pxd

-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ cdef check_dts_bounds(npy_datetimestruct *dts, NPY_DATETIMEUNIT unit=?)
7777

7878
cdef int64_t dtstruct_to_dt64(npy_datetimestruct* dts) nogil
7979
cdef void dt64_to_dtstruct(int64_t dt64, npy_datetimestruct* out) nogil
80-
cdef void td64_to_tdstruct(int64_t td64, pandas_timedeltastruct* out) nogil
8180

8281
cdef int64_t pydatetime_to_dt64(datetime val, npy_datetimestruct *dts)
8382
cdef int64_t pydate_to_dt64(date val, npy_datetimestruct *dts)

pandas/_libs/tslibs/np_datetime.pyx

-8
Original file line numberDiff line numberDiff line change
@@ -221,14 +221,6 @@ cdef inline void dt64_to_dtstruct(int64_t dt64,
221221
return
222222

223223

224-
cdef inline void td64_to_tdstruct(int64_t td64,
225-
pandas_timedeltastruct* out) nogil:
226-
"""Convenience function to call pandas_timedelta_to_timedeltastruct
227-
with the by-far-most-common frequency NPY_FR_ns"""
228-
pandas_timedelta_to_timedeltastruct(td64, NPY_FR_ns, out)
229-
return
230-
231-
232224
# just exposed for testing at the moment
233225
def py_td64_to_tdstruct(int64_t td64, NPY_DATETIMEUNIT unit):
234226
cdef:

pandas/core/arrays/timedeltas.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
def _field_accessor(name: str, alias: str, docstring: str):
7575
def f(self) -> np.ndarray:
7676
values = self.asi8
77-
result = get_timedelta_field(values, alias)
77+
result = get_timedelta_field(values, alias, reso=self._reso)
7878
if self._hasna:
7979
result = self._maybe_mask_results(
8080
result, fill_value=None, convert="float64"

pandas/tests/arrays/test_timedeltas.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,49 @@
11
import numpy as np
22
import pytest
33

4+
from pandas._libs.tslibs.dtypes import NpyDatetimeUnit
5+
46
import pandas as pd
57
from pandas import Timedelta
68
import pandas._testing as tm
79
from pandas.core.arrays import TimedeltaArray
810

911

1012
class TestNonNano:
11-
@pytest.mark.parametrize("unit,reso", [("s", 7), ("ms", 8), ("us", 9)])
13+
@pytest.fixture(params=["s", "ms", "us"])
14+
def unit(self, request):
15+
return request.param
16+
17+
@pytest.fixture
18+
def reso(self, unit):
19+
if unit == "s":
20+
return NpyDatetimeUnit.NPY_FR_s.value
21+
elif unit == "ms":
22+
return NpyDatetimeUnit.NPY_FR_ms.value
23+
elif unit == "us":
24+
return NpyDatetimeUnit.NPY_FR_us.value
25+
else:
26+
raise NotImplementedError(unit)
27+
1228
def test_non_nano(self, unit, reso):
1329
arr = np.arange(5, dtype=np.int64).view(f"m8[{unit}]")
1430
tda = TimedeltaArray._simple_new(arr, dtype=arr.dtype)
1531

1632
assert tda.dtype == arr.dtype
1733
assert tda[0]._reso == reso
1834

35+
@pytest.mark.parametrize("field", TimedeltaArray._field_ops)
36+
def test_fields(self, unit, reso, field):
37+
arr = np.arange(5, dtype=np.int64).view(f"m8[{unit}]")
38+
tda = TimedeltaArray._simple_new(arr, dtype=arr.dtype)
39+
40+
as_nano = arr.astype("m8[ns]")
41+
tda_nano = TimedeltaArray._simple_new(as_nano, dtype=as_nano.dtype)
42+
43+
result = getattr(tda, field)
44+
expected = getattr(tda_nano, field)
45+
tm.assert_numpy_array_equal(result, expected)
46+
1947

2048
class TestTimedeltaArray:
2149
@pytest.mark.parametrize("dtype", [int, np.int32, np.int64, "uint32", "uint64"])

pandas/tests/tslibs/test_np_datetime.py

+31-30
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import pytest
33

4+
from pandas._libs.tslibs.dtypes import NpyDatetimeUnit
45
from pandas._libs.tslibs.np_datetime import (
56
OutOfBoundsDatetime,
67
OutOfBoundsTimedelta,
@@ -37,42 +38,42 @@ def test_is_unitless():
3738

3839
def test_get_unit_from_dtype():
3940
# datetime64
40-
assert py_get_unit_from_dtype(np.dtype("M8[Y]")) == 0
41-
assert py_get_unit_from_dtype(np.dtype("M8[M]")) == 1
42-
assert py_get_unit_from_dtype(np.dtype("M8[W]")) == 2
41+
assert py_get_unit_from_dtype(np.dtype("M8[Y]")) == NpyDatetimeUnit.NPY_FR_Y.value
42+
assert py_get_unit_from_dtype(np.dtype("M8[M]")) == NpyDatetimeUnit.NPY_FR_M.value
43+
assert py_get_unit_from_dtype(np.dtype("M8[W]")) == NpyDatetimeUnit.NPY_FR_W.value
4344
# B has been deprecated and removed -> no 3
44-
assert py_get_unit_from_dtype(np.dtype("M8[D]")) == 4
45-
assert py_get_unit_from_dtype(np.dtype("M8[h]")) == 5
46-
assert py_get_unit_from_dtype(np.dtype("M8[m]")) == 6
47-
assert py_get_unit_from_dtype(np.dtype("M8[s]")) == 7
48-
assert py_get_unit_from_dtype(np.dtype("M8[ms]")) == 8
49-
assert py_get_unit_from_dtype(np.dtype("M8[us]")) == 9
50-
assert py_get_unit_from_dtype(np.dtype("M8[ns]")) == 10
51-
assert py_get_unit_from_dtype(np.dtype("M8[ps]")) == 11
52-
assert py_get_unit_from_dtype(np.dtype("M8[fs]")) == 12
53-
assert py_get_unit_from_dtype(np.dtype("M8[as]")) == 13
45+
assert py_get_unit_from_dtype(np.dtype("M8[D]")) == NpyDatetimeUnit.NPY_FR_D.value
46+
assert py_get_unit_from_dtype(np.dtype("M8[h]")) == NpyDatetimeUnit.NPY_FR_h.value
47+
assert py_get_unit_from_dtype(np.dtype("M8[m]")) == NpyDatetimeUnit.NPY_FR_m.value
48+
assert py_get_unit_from_dtype(np.dtype("M8[s]")) == NpyDatetimeUnit.NPY_FR_s.value
49+
assert py_get_unit_from_dtype(np.dtype("M8[ms]")) == NpyDatetimeUnit.NPY_FR_ms.value
50+
assert py_get_unit_from_dtype(np.dtype("M8[us]")) == NpyDatetimeUnit.NPY_FR_us.value
51+
assert py_get_unit_from_dtype(np.dtype("M8[ns]")) == NpyDatetimeUnit.NPY_FR_ns.value
52+
assert py_get_unit_from_dtype(np.dtype("M8[ps]")) == NpyDatetimeUnit.NPY_FR_ps.value
53+
assert py_get_unit_from_dtype(np.dtype("M8[fs]")) == NpyDatetimeUnit.NPY_FR_fs.value
54+
assert py_get_unit_from_dtype(np.dtype("M8[as]")) == NpyDatetimeUnit.NPY_FR_as.value
5455

5556
# timedelta64
56-
assert py_get_unit_from_dtype(np.dtype("m8[Y]")) == 0
57-
assert py_get_unit_from_dtype(np.dtype("m8[M]")) == 1
58-
assert py_get_unit_from_dtype(np.dtype("m8[W]")) == 2
57+
assert py_get_unit_from_dtype(np.dtype("m8[Y]")) == NpyDatetimeUnit.NPY_FR_Y.value
58+
assert py_get_unit_from_dtype(np.dtype("m8[M]")) == NpyDatetimeUnit.NPY_FR_M.value
59+
assert py_get_unit_from_dtype(np.dtype("m8[W]")) == NpyDatetimeUnit.NPY_FR_W.value
5960
# B has been deprecated and removed -> no 3
60-
assert py_get_unit_from_dtype(np.dtype("m8[D]")) == 4
61-
assert py_get_unit_from_dtype(np.dtype("m8[h]")) == 5
62-
assert py_get_unit_from_dtype(np.dtype("m8[m]")) == 6
63-
assert py_get_unit_from_dtype(np.dtype("m8[s]")) == 7
64-
assert py_get_unit_from_dtype(np.dtype("m8[ms]")) == 8
65-
assert py_get_unit_from_dtype(np.dtype("m8[us]")) == 9
66-
assert py_get_unit_from_dtype(np.dtype("m8[ns]")) == 10
67-
assert py_get_unit_from_dtype(np.dtype("m8[ps]")) == 11
68-
assert py_get_unit_from_dtype(np.dtype("m8[fs]")) == 12
69-
assert py_get_unit_from_dtype(np.dtype("m8[as]")) == 13
61+
assert py_get_unit_from_dtype(np.dtype("m8[D]")) == NpyDatetimeUnit.NPY_FR_D.value
62+
assert py_get_unit_from_dtype(np.dtype("m8[h]")) == NpyDatetimeUnit.NPY_FR_h.value
63+
assert py_get_unit_from_dtype(np.dtype("m8[m]")) == NpyDatetimeUnit.NPY_FR_m.value
64+
assert py_get_unit_from_dtype(np.dtype("m8[s]")) == NpyDatetimeUnit.NPY_FR_s.value
65+
assert py_get_unit_from_dtype(np.dtype("m8[ms]")) == NpyDatetimeUnit.NPY_FR_ms.value
66+
assert py_get_unit_from_dtype(np.dtype("m8[us]")) == NpyDatetimeUnit.NPY_FR_us.value
67+
assert py_get_unit_from_dtype(np.dtype("m8[ns]")) == NpyDatetimeUnit.NPY_FR_ns.value
68+
assert py_get_unit_from_dtype(np.dtype("m8[ps]")) == NpyDatetimeUnit.NPY_FR_ps.value
69+
assert py_get_unit_from_dtype(np.dtype("m8[fs]")) == NpyDatetimeUnit.NPY_FR_fs.value
70+
assert py_get_unit_from_dtype(np.dtype("m8[as]")) == NpyDatetimeUnit.NPY_FR_as.value
7071

7172

7273
def test_td64_to_tdstruct():
7374
val = 12454636234 # arbitrary value
7475

75-
res1 = py_td64_to_tdstruct(val, 10) # ns
76+
res1 = py_td64_to_tdstruct(val, NpyDatetimeUnit.NPY_FR_ns.value)
7677
exp1 = {
7778
"days": 0,
7879
"hrs": 0,
@@ -87,7 +88,7 @@ def test_td64_to_tdstruct():
8788
}
8889
assert res1 == exp1
8990

90-
res2 = py_td64_to_tdstruct(val, 9) # us
91+
res2 = py_td64_to_tdstruct(val, NpyDatetimeUnit.NPY_FR_us.value)
9192
exp2 = {
9293
"days": 0,
9394
"hrs": 3,
@@ -102,7 +103,7 @@ def test_td64_to_tdstruct():
102103
}
103104
assert res2 == exp2
104105

105-
res3 = py_td64_to_tdstruct(val, 8) # ms
106+
res3 = py_td64_to_tdstruct(val, NpyDatetimeUnit.NPY_FR_ms.value)
106107
exp3 = {
107108
"days": 144,
108109
"hrs": 3,
@@ -118,7 +119,7 @@ def test_td64_to_tdstruct():
118119
assert res3 == exp3
119120

120121
# Note this out of bounds for nanosecond Timedelta
121-
res4 = py_td64_to_tdstruct(val, 7) # s
122+
res4 = py_td64_to_tdstruct(val, NpyDatetimeUnit.NPY_FR_s.value)
122123
exp4 = {
123124
"days": 144150,
124125
"hrs": 21,

0 commit comments

Comments
 (0)