From d9b4714466f5153dc0f607a4803a8d70ef504fd8 Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 7 Jun 2022 14:50:34 -0700 Subject: [PATCH] ENH: TDA fields support non-nano --- pandas/_libs/tslibs/fields.pyi | 1 + pandas/_libs/tslibs/fields.pyx | 16 ++++--- pandas/_libs/tslibs/np_datetime.pxd | 1 - pandas/_libs/tslibs/np_datetime.pyx | 8 ---- pandas/core/arrays/timedeltas.py | 2 +- pandas/tests/arrays/test_timedeltas.py | 30 +++++++++++- pandas/tests/tslibs/test_np_datetime.py | 61 +++++++++++++------------ 7 files changed, 72 insertions(+), 47 deletions(-) diff --git a/pandas/_libs/tslibs/fields.pyi b/pandas/_libs/tslibs/fields.pyi index 71363ad836370..8b4bc1a31a1aa 100644 --- a/pandas/_libs/tslibs/fields.pyi +++ b/pandas/_libs/tslibs/fields.pyi @@ -28,6 +28,7 @@ def get_date_field( def get_timedelta_field( tdindex: npt.NDArray[np.int64], # const int64_t[:] field: str, + reso: int = ..., # NPY_DATETIMEUNIT ) -> npt.NDArray[np.int32]: ... def isleapyear_arr( years: np.ndarray, diff --git a/pandas/_libs/tslibs/fields.pyx b/pandas/_libs/tslibs/fields.pyx index bc5e5b37b9a76..71a0f2727445f 100644 --- a/pandas/_libs/tslibs/fields.pyx +++ b/pandas/_libs/tslibs/fields.pyx @@ -48,8 +48,8 @@ from pandas._libs.tslibs.np_datetime cimport ( get_unit_from_dtype, npy_datetimestruct, pandas_datetime_to_datetimestruct, + pandas_timedelta_to_timedeltastruct, pandas_timedeltastruct, - td64_to_tdstruct, ) @@ -491,7 +491,11 @@ def get_date_field(const int64_t[:] dtindex, str field, NPY_DATETIMEUNIT reso=NP @cython.wraparound(False) @cython.boundscheck(False) -def get_timedelta_field(const int64_t[:] tdindex, str field): +def get_timedelta_field( + const int64_t[:] tdindex, + str field, + NPY_DATETIMEUNIT reso=NPY_FR_ns, +): """ Given a int64-based timedelta index, extract the days, hrs, sec., field and return an array of these values. @@ -510,7 +514,7 @@ def get_timedelta_field(const int64_t[:] tdindex, str field): out[i] = -1 continue - td64_to_tdstruct(tdindex[i], &tds) + pandas_timedelta_to_timedeltastruct(tdindex[i], reso, &tds) out[i] = tds.days return out @@ -521,7 +525,7 @@ def get_timedelta_field(const int64_t[:] tdindex, str field): out[i] = -1 continue - td64_to_tdstruct(tdindex[i], &tds) + pandas_timedelta_to_timedeltastruct(tdindex[i], reso, &tds) out[i] = tds.seconds return out @@ -532,7 +536,7 @@ def get_timedelta_field(const int64_t[:] tdindex, str field): out[i] = -1 continue - td64_to_tdstruct(tdindex[i], &tds) + pandas_timedelta_to_timedeltastruct(tdindex[i], reso, &tds) out[i] = tds.microseconds return out @@ -543,7 +547,7 @@ def get_timedelta_field(const int64_t[:] tdindex, str field): out[i] = -1 continue - td64_to_tdstruct(tdindex[i], &tds) + pandas_timedelta_to_timedeltastruct(tdindex[i], reso, &tds) out[i] = tds.nanoseconds return out diff --git a/pandas/_libs/tslibs/np_datetime.pxd b/pandas/_libs/tslibs/np_datetime.pxd index f072dab3763aa..d4dbcbe2acd6e 100644 --- a/pandas/_libs/tslibs/np_datetime.pxd +++ b/pandas/_libs/tslibs/np_datetime.pxd @@ -77,7 +77,6 @@ cdef check_dts_bounds(npy_datetimestruct *dts, NPY_DATETIMEUNIT unit=?) cdef int64_t dtstruct_to_dt64(npy_datetimestruct* dts) nogil cdef void dt64_to_dtstruct(int64_t dt64, npy_datetimestruct* out) nogil -cdef void td64_to_tdstruct(int64_t td64, pandas_timedeltastruct* out) nogil cdef int64_t pydatetime_to_dt64(datetime val, npy_datetimestruct *dts) cdef int64_t pydate_to_dt64(date val, npy_datetimestruct *dts) diff --git a/pandas/_libs/tslibs/np_datetime.pyx b/pandas/_libs/tslibs/np_datetime.pyx index bab4743dce38c..cf967509a84c0 100644 --- a/pandas/_libs/tslibs/np_datetime.pyx +++ b/pandas/_libs/tslibs/np_datetime.pyx @@ -221,14 +221,6 @@ cdef inline void dt64_to_dtstruct(int64_t dt64, return -cdef inline void td64_to_tdstruct(int64_t td64, - pandas_timedeltastruct* out) nogil: - """Convenience function to call pandas_timedelta_to_timedeltastruct - with the by-far-most-common frequency NPY_FR_ns""" - pandas_timedelta_to_timedeltastruct(td64, NPY_FR_ns, out) - return - - # just exposed for testing at the moment def py_td64_to_tdstruct(int64_t td64, NPY_DATETIMEUNIT unit): cdef: diff --git a/pandas/core/arrays/timedeltas.py b/pandas/core/arrays/timedeltas.py index 36e8e44e2034f..3bbb03d88e38d 100644 --- a/pandas/core/arrays/timedeltas.py +++ b/pandas/core/arrays/timedeltas.py @@ -74,7 +74,7 @@ def _field_accessor(name: str, alias: str, docstring: str): def f(self) -> np.ndarray: values = self.asi8 - result = get_timedelta_field(values, alias) + result = get_timedelta_field(values, alias, reso=self._reso) if self._hasna: result = self._maybe_mask_results( result, fill_value=None, convert="float64" diff --git a/pandas/tests/arrays/test_timedeltas.py b/pandas/tests/arrays/test_timedeltas.py index 46306167878f6..c8b850d35035a 100644 --- a/pandas/tests/arrays/test_timedeltas.py +++ b/pandas/tests/arrays/test_timedeltas.py @@ -1,6 +1,8 @@ import numpy as np import pytest +from pandas._libs.tslibs.dtypes import NpyDatetimeUnit + import pandas as pd from pandas import Timedelta import pandas._testing as tm @@ -8,7 +10,21 @@ class TestNonNano: - @pytest.mark.parametrize("unit,reso", [("s", 7), ("ms", 8), ("us", 9)]) + @pytest.fixture(params=["s", "ms", "us"]) + def unit(self, request): + return request.param + + @pytest.fixture + def reso(self, unit): + if unit == "s": + return NpyDatetimeUnit.NPY_FR_s.value + elif unit == "ms": + return NpyDatetimeUnit.NPY_FR_ms.value + elif unit == "us": + return NpyDatetimeUnit.NPY_FR_us.value + else: + raise NotImplementedError(unit) + def test_non_nano(self, unit, reso): arr = np.arange(5, dtype=np.int64).view(f"m8[{unit}]") tda = TimedeltaArray._simple_new(arr, dtype=arr.dtype) @@ -16,6 +32,18 @@ def test_non_nano(self, unit, reso): assert tda.dtype == arr.dtype assert tda[0]._reso == reso + @pytest.mark.parametrize("field", TimedeltaArray._field_ops) + def test_fields(self, unit, reso, field): + arr = np.arange(5, dtype=np.int64).view(f"m8[{unit}]") + tda = TimedeltaArray._simple_new(arr, dtype=arr.dtype) + + as_nano = arr.astype("m8[ns]") + tda_nano = TimedeltaArray._simple_new(as_nano, dtype=as_nano.dtype) + + result = getattr(tda, field) + expected = getattr(tda_nano, field) + tm.assert_numpy_array_equal(result, expected) + class TestTimedeltaArray: @pytest.mark.parametrize("dtype", [int, np.int32, np.int64, "uint32", "uint64"]) diff --git a/pandas/tests/tslibs/test_np_datetime.py b/pandas/tests/tslibs/test_np_datetime.py index 9ae491f1618f6..cc09f0fc77039 100644 --- a/pandas/tests/tslibs/test_np_datetime.py +++ b/pandas/tests/tslibs/test_np_datetime.py @@ -1,6 +1,7 @@ import numpy as np import pytest +from pandas._libs.tslibs.dtypes import NpyDatetimeUnit from pandas._libs.tslibs.np_datetime import ( OutOfBoundsDatetime, OutOfBoundsTimedelta, @@ -37,42 +38,42 @@ def test_is_unitless(): def test_get_unit_from_dtype(): # datetime64 - assert py_get_unit_from_dtype(np.dtype("M8[Y]")) == 0 - assert py_get_unit_from_dtype(np.dtype("M8[M]")) == 1 - assert py_get_unit_from_dtype(np.dtype("M8[W]")) == 2 + assert py_get_unit_from_dtype(np.dtype("M8[Y]")) == NpyDatetimeUnit.NPY_FR_Y.value + assert py_get_unit_from_dtype(np.dtype("M8[M]")) == NpyDatetimeUnit.NPY_FR_M.value + assert py_get_unit_from_dtype(np.dtype("M8[W]")) == NpyDatetimeUnit.NPY_FR_W.value # B has been deprecated and removed -> no 3 - assert py_get_unit_from_dtype(np.dtype("M8[D]")) == 4 - assert py_get_unit_from_dtype(np.dtype("M8[h]")) == 5 - assert py_get_unit_from_dtype(np.dtype("M8[m]")) == 6 - assert py_get_unit_from_dtype(np.dtype("M8[s]")) == 7 - assert py_get_unit_from_dtype(np.dtype("M8[ms]")) == 8 - assert py_get_unit_from_dtype(np.dtype("M8[us]")) == 9 - assert py_get_unit_from_dtype(np.dtype("M8[ns]")) == 10 - assert py_get_unit_from_dtype(np.dtype("M8[ps]")) == 11 - assert py_get_unit_from_dtype(np.dtype("M8[fs]")) == 12 - assert py_get_unit_from_dtype(np.dtype("M8[as]")) == 13 + assert py_get_unit_from_dtype(np.dtype("M8[D]")) == NpyDatetimeUnit.NPY_FR_D.value + assert py_get_unit_from_dtype(np.dtype("M8[h]")) == NpyDatetimeUnit.NPY_FR_h.value + assert py_get_unit_from_dtype(np.dtype("M8[m]")) == NpyDatetimeUnit.NPY_FR_m.value + assert py_get_unit_from_dtype(np.dtype("M8[s]")) == NpyDatetimeUnit.NPY_FR_s.value + assert py_get_unit_from_dtype(np.dtype("M8[ms]")) == NpyDatetimeUnit.NPY_FR_ms.value + assert py_get_unit_from_dtype(np.dtype("M8[us]")) == NpyDatetimeUnit.NPY_FR_us.value + assert py_get_unit_from_dtype(np.dtype("M8[ns]")) == NpyDatetimeUnit.NPY_FR_ns.value + assert py_get_unit_from_dtype(np.dtype("M8[ps]")) == NpyDatetimeUnit.NPY_FR_ps.value + assert py_get_unit_from_dtype(np.dtype("M8[fs]")) == NpyDatetimeUnit.NPY_FR_fs.value + assert py_get_unit_from_dtype(np.dtype("M8[as]")) == NpyDatetimeUnit.NPY_FR_as.value # timedelta64 - assert py_get_unit_from_dtype(np.dtype("m8[Y]")) == 0 - assert py_get_unit_from_dtype(np.dtype("m8[M]")) == 1 - assert py_get_unit_from_dtype(np.dtype("m8[W]")) == 2 + assert py_get_unit_from_dtype(np.dtype("m8[Y]")) == NpyDatetimeUnit.NPY_FR_Y.value + assert py_get_unit_from_dtype(np.dtype("m8[M]")) == NpyDatetimeUnit.NPY_FR_M.value + assert py_get_unit_from_dtype(np.dtype("m8[W]")) == NpyDatetimeUnit.NPY_FR_W.value # B has been deprecated and removed -> no 3 - assert py_get_unit_from_dtype(np.dtype("m8[D]")) == 4 - assert py_get_unit_from_dtype(np.dtype("m8[h]")) == 5 - assert py_get_unit_from_dtype(np.dtype("m8[m]")) == 6 - assert py_get_unit_from_dtype(np.dtype("m8[s]")) == 7 - assert py_get_unit_from_dtype(np.dtype("m8[ms]")) == 8 - assert py_get_unit_from_dtype(np.dtype("m8[us]")) == 9 - assert py_get_unit_from_dtype(np.dtype("m8[ns]")) == 10 - assert py_get_unit_from_dtype(np.dtype("m8[ps]")) == 11 - assert py_get_unit_from_dtype(np.dtype("m8[fs]")) == 12 - assert py_get_unit_from_dtype(np.dtype("m8[as]")) == 13 + assert py_get_unit_from_dtype(np.dtype("m8[D]")) == NpyDatetimeUnit.NPY_FR_D.value + assert py_get_unit_from_dtype(np.dtype("m8[h]")) == NpyDatetimeUnit.NPY_FR_h.value + assert py_get_unit_from_dtype(np.dtype("m8[m]")) == NpyDatetimeUnit.NPY_FR_m.value + assert py_get_unit_from_dtype(np.dtype("m8[s]")) == NpyDatetimeUnit.NPY_FR_s.value + assert py_get_unit_from_dtype(np.dtype("m8[ms]")) == NpyDatetimeUnit.NPY_FR_ms.value + assert py_get_unit_from_dtype(np.dtype("m8[us]")) == NpyDatetimeUnit.NPY_FR_us.value + assert py_get_unit_from_dtype(np.dtype("m8[ns]")) == NpyDatetimeUnit.NPY_FR_ns.value + assert py_get_unit_from_dtype(np.dtype("m8[ps]")) == NpyDatetimeUnit.NPY_FR_ps.value + assert py_get_unit_from_dtype(np.dtype("m8[fs]")) == NpyDatetimeUnit.NPY_FR_fs.value + assert py_get_unit_from_dtype(np.dtype("m8[as]")) == NpyDatetimeUnit.NPY_FR_as.value def test_td64_to_tdstruct(): val = 12454636234 # arbitrary value - res1 = py_td64_to_tdstruct(val, 10) # ns + res1 = py_td64_to_tdstruct(val, NpyDatetimeUnit.NPY_FR_ns.value) exp1 = { "days": 0, "hrs": 0, @@ -87,7 +88,7 @@ def test_td64_to_tdstruct(): } assert res1 == exp1 - res2 = py_td64_to_tdstruct(val, 9) # us + res2 = py_td64_to_tdstruct(val, NpyDatetimeUnit.NPY_FR_us.value) exp2 = { "days": 0, "hrs": 3, @@ -102,7 +103,7 @@ def test_td64_to_tdstruct(): } assert res2 == exp2 - res3 = py_td64_to_tdstruct(val, 8) # ms + res3 = py_td64_to_tdstruct(val, NpyDatetimeUnit.NPY_FR_ms.value) exp3 = { "days": 144, "hrs": 3, @@ -118,7 +119,7 @@ def test_td64_to_tdstruct(): assert res3 == exp3 # Note this out of bounds for nanosecond Timedelta - res4 = py_td64_to_tdstruct(val, 7) # s + res4 = py_td64_to_tdstruct(val, NpyDatetimeUnit.NPY_FR_s.value) exp4 = { "days": 144150, "hrs": 21,