Skip to content

ENH: TDA fields support non-nano #47278

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pandas/_libs/tslibs/fields.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def get_date_field(
def get_timedelta_field(
tdindex: npt.NDArray[np.int64], # const int64_t[:]
field: str,
reso: int = ..., # NPY_DATETIMEUNIT
) -> npt.NDArray[np.int32]: ...
def isleapyear_arr(
years: np.ndarray,
Expand Down
16 changes: 10 additions & 6 deletions pandas/_libs/tslibs/fields.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ from pandas._libs.tslibs.np_datetime cimport (
get_unit_from_dtype,
npy_datetimestruct,
pandas_datetime_to_datetimestruct,
pandas_timedelta_to_timedeltastruct,
pandas_timedeltastruct,
td64_to_tdstruct,
)


Expand Down Expand Up @@ -491,7 +491,11 @@ def get_date_field(const int64_t[:] dtindex, str field, NPY_DATETIMEUNIT reso=NP

@cython.wraparound(False)
@cython.boundscheck(False)
def get_timedelta_field(const int64_t[:] tdindex, str field):
def get_timedelta_field(
const int64_t[:] tdindex,
str field,
NPY_DATETIMEUNIT reso=NPY_FR_ns,
):
"""
Given a int64-based timedelta index, extract the days, hrs, sec.,
field and return an array of these values.
Expand All @@ -510,7 +514,7 @@ def get_timedelta_field(const int64_t[:] tdindex, str field):
out[i] = -1
continue

td64_to_tdstruct(tdindex[i], &tds)
pandas_timedelta_to_timedeltastruct(tdindex[i], reso, &tds)
out[i] = tds.days
return out

Expand All @@ -521,7 +525,7 @@ def get_timedelta_field(const int64_t[:] tdindex, str field):
out[i] = -1
continue

td64_to_tdstruct(tdindex[i], &tds)
pandas_timedelta_to_timedeltastruct(tdindex[i], reso, &tds)
out[i] = tds.seconds
return out

Expand All @@ -532,7 +536,7 @@ def get_timedelta_field(const int64_t[:] tdindex, str field):
out[i] = -1
continue

td64_to_tdstruct(tdindex[i], &tds)
pandas_timedelta_to_timedeltastruct(tdindex[i], reso, &tds)
out[i] = tds.microseconds
return out

Expand All @@ -543,7 +547,7 @@ def get_timedelta_field(const int64_t[:] tdindex, str field):
out[i] = -1
continue

td64_to_tdstruct(tdindex[i], &tds)
pandas_timedelta_to_timedeltastruct(tdindex[i], reso, &tds)
out[i] = tds.nanoseconds
return out

Expand Down
1 change: 0 additions & 1 deletion pandas/_libs/tslibs/np_datetime.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ cdef check_dts_bounds(npy_datetimestruct *dts, NPY_DATETIMEUNIT unit=?)

cdef int64_t dtstruct_to_dt64(npy_datetimestruct* dts) nogil
cdef void dt64_to_dtstruct(int64_t dt64, npy_datetimestruct* out) nogil
cdef void td64_to_tdstruct(int64_t td64, pandas_timedeltastruct* out) nogil

cdef int64_t pydatetime_to_dt64(datetime val, npy_datetimestruct *dts)
cdef int64_t pydate_to_dt64(date val, npy_datetimestruct *dts)
Expand Down
8 changes: 0 additions & 8 deletions pandas/_libs/tslibs/np_datetime.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,6 @@ cdef inline void dt64_to_dtstruct(int64_t dt64,
return


cdef inline void td64_to_tdstruct(int64_t td64,
pandas_timedeltastruct* out) nogil:
"""Convenience function to call pandas_timedelta_to_timedeltastruct
with the by-far-most-common frequency NPY_FR_ns"""
pandas_timedelta_to_timedeltastruct(td64, NPY_FR_ns, out)
return


# just exposed for testing at the moment
def py_td64_to_tdstruct(int64_t td64, NPY_DATETIMEUNIT unit):
cdef:
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
def _field_accessor(name: str, alias: str, docstring: str):
def f(self) -> np.ndarray:
values = self.asi8
result = get_timedelta_field(values, alias)
result = get_timedelta_field(values, alias, reso=self._reso)
if self._hasna:
result = self._maybe_mask_results(
result, fill_value=None, convert="float64"
Expand Down
30 changes: 29 additions & 1 deletion pandas/tests/arrays/test_timedeltas.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,49 @@
import numpy as np
import pytest

from pandas._libs.tslibs.dtypes import NpyDatetimeUnit

import pandas as pd
from pandas import Timedelta
import pandas._testing as tm
from pandas.core.arrays import TimedeltaArray


class TestNonNano:
@pytest.mark.parametrize("unit,reso", [("s", 7), ("ms", 8), ("us", 9)])
@pytest.fixture(params=["s", "ms", "us"])
def unit(self, request):
return request.param

@pytest.fixture
def reso(self, unit):
if unit == "s":
return NpyDatetimeUnit.NPY_FR_s.value
elif unit == "ms":
return NpyDatetimeUnit.NPY_FR_ms.value
elif unit == "us":
return NpyDatetimeUnit.NPY_FR_us.value
else:
raise NotImplementedError(unit)

def test_non_nano(self, unit, reso):
arr = np.arange(5, dtype=np.int64).view(f"m8[{unit}]")
tda = TimedeltaArray._simple_new(arr, dtype=arr.dtype)

assert tda.dtype == arr.dtype
assert tda[0]._reso == reso

@pytest.mark.parametrize("field", TimedeltaArray._field_ops)
def test_fields(self, unit, reso, field):
arr = np.arange(5, dtype=np.int64).view(f"m8[{unit}]")
tda = TimedeltaArray._simple_new(arr, dtype=arr.dtype)

as_nano = arr.astype("m8[ns]")
tda_nano = TimedeltaArray._simple_new(as_nano, dtype=as_nano.dtype)

result = getattr(tda, field)
expected = getattr(tda_nano, field)
tm.assert_numpy_array_equal(result, expected)


class TestTimedeltaArray:
@pytest.mark.parametrize("dtype", [int, np.int32, np.int64, "uint32", "uint64"])
Expand Down
61 changes: 31 additions & 30 deletions pandas/tests/tslibs/test_np_datetime.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pytest

from pandas._libs.tslibs.dtypes import NpyDatetimeUnit
from pandas._libs.tslibs.np_datetime import (
OutOfBoundsDatetime,
OutOfBoundsTimedelta,
Expand Down Expand Up @@ -37,42 +38,42 @@ def test_is_unitless():

def test_get_unit_from_dtype():
# datetime64
assert py_get_unit_from_dtype(np.dtype("M8[Y]")) == 0
assert py_get_unit_from_dtype(np.dtype("M8[M]")) == 1
assert py_get_unit_from_dtype(np.dtype("M8[W]")) == 2
assert py_get_unit_from_dtype(np.dtype("M8[Y]")) == NpyDatetimeUnit.NPY_FR_Y.value
assert py_get_unit_from_dtype(np.dtype("M8[M]")) == NpyDatetimeUnit.NPY_FR_M.value
assert py_get_unit_from_dtype(np.dtype("M8[W]")) == NpyDatetimeUnit.NPY_FR_W.value
# B has been deprecated and removed -> no 3
assert py_get_unit_from_dtype(np.dtype("M8[D]")) == 4
assert py_get_unit_from_dtype(np.dtype("M8[h]")) == 5
assert py_get_unit_from_dtype(np.dtype("M8[m]")) == 6
assert py_get_unit_from_dtype(np.dtype("M8[s]")) == 7
assert py_get_unit_from_dtype(np.dtype("M8[ms]")) == 8
assert py_get_unit_from_dtype(np.dtype("M8[us]")) == 9
assert py_get_unit_from_dtype(np.dtype("M8[ns]")) == 10
assert py_get_unit_from_dtype(np.dtype("M8[ps]")) == 11
assert py_get_unit_from_dtype(np.dtype("M8[fs]")) == 12
assert py_get_unit_from_dtype(np.dtype("M8[as]")) == 13
assert py_get_unit_from_dtype(np.dtype("M8[D]")) == NpyDatetimeUnit.NPY_FR_D.value
assert py_get_unit_from_dtype(np.dtype("M8[h]")) == NpyDatetimeUnit.NPY_FR_h.value
assert py_get_unit_from_dtype(np.dtype("M8[m]")) == NpyDatetimeUnit.NPY_FR_m.value
assert py_get_unit_from_dtype(np.dtype("M8[s]")) == NpyDatetimeUnit.NPY_FR_s.value
assert py_get_unit_from_dtype(np.dtype("M8[ms]")) == NpyDatetimeUnit.NPY_FR_ms.value
assert py_get_unit_from_dtype(np.dtype("M8[us]")) == NpyDatetimeUnit.NPY_FR_us.value
assert py_get_unit_from_dtype(np.dtype("M8[ns]")) == NpyDatetimeUnit.NPY_FR_ns.value
assert py_get_unit_from_dtype(np.dtype("M8[ps]")) == NpyDatetimeUnit.NPY_FR_ps.value
assert py_get_unit_from_dtype(np.dtype("M8[fs]")) == NpyDatetimeUnit.NPY_FR_fs.value
assert py_get_unit_from_dtype(np.dtype("M8[as]")) == NpyDatetimeUnit.NPY_FR_as.value

# timedelta64
assert py_get_unit_from_dtype(np.dtype("m8[Y]")) == 0
assert py_get_unit_from_dtype(np.dtype("m8[M]")) == 1
assert py_get_unit_from_dtype(np.dtype("m8[W]")) == 2
assert py_get_unit_from_dtype(np.dtype("m8[Y]")) == NpyDatetimeUnit.NPY_FR_Y.value
assert py_get_unit_from_dtype(np.dtype("m8[M]")) == NpyDatetimeUnit.NPY_FR_M.value
assert py_get_unit_from_dtype(np.dtype("m8[W]")) == NpyDatetimeUnit.NPY_FR_W.value
# B has been deprecated and removed -> no 3
assert py_get_unit_from_dtype(np.dtype("m8[D]")) == 4
assert py_get_unit_from_dtype(np.dtype("m8[h]")) == 5
assert py_get_unit_from_dtype(np.dtype("m8[m]")) == 6
assert py_get_unit_from_dtype(np.dtype("m8[s]")) == 7
assert py_get_unit_from_dtype(np.dtype("m8[ms]")) == 8
assert py_get_unit_from_dtype(np.dtype("m8[us]")) == 9
assert py_get_unit_from_dtype(np.dtype("m8[ns]")) == 10
assert py_get_unit_from_dtype(np.dtype("m8[ps]")) == 11
assert py_get_unit_from_dtype(np.dtype("m8[fs]")) == 12
assert py_get_unit_from_dtype(np.dtype("m8[as]")) == 13
assert py_get_unit_from_dtype(np.dtype("m8[D]")) == NpyDatetimeUnit.NPY_FR_D.value
assert py_get_unit_from_dtype(np.dtype("m8[h]")) == NpyDatetimeUnit.NPY_FR_h.value
assert py_get_unit_from_dtype(np.dtype("m8[m]")) == NpyDatetimeUnit.NPY_FR_m.value
assert py_get_unit_from_dtype(np.dtype("m8[s]")) == NpyDatetimeUnit.NPY_FR_s.value
assert py_get_unit_from_dtype(np.dtype("m8[ms]")) == NpyDatetimeUnit.NPY_FR_ms.value
assert py_get_unit_from_dtype(np.dtype("m8[us]")) == NpyDatetimeUnit.NPY_FR_us.value
assert py_get_unit_from_dtype(np.dtype("m8[ns]")) == NpyDatetimeUnit.NPY_FR_ns.value
assert py_get_unit_from_dtype(np.dtype("m8[ps]")) == NpyDatetimeUnit.NPY_FR_ps.value
assert py_get_unit_from_dtype(np.dtype("m8[fs]")) == NpyDatetimeUnit.NPY_FR_fs.value
assert py_get_unit_from_dtype(np.dtype("m8[as]")) == NpyDatetimeUnit.NPY_FR_as.value


def test_td64_to_tdstruct():
val = 12454636234 # arbitrary value

res1 = py_td64_to_tdstruct(val, 10) # ns
res1 = py_td64_to_tdstruct(val, NpyDatetimeUnit.NPY_FR_ns.value)
exp1 = {
"days": 0,
"hrs": 0,
Expand All @@ -87,7 +88,7 @@ def test_td64_to_tdstruct():
}
assert res1 == exp1

res2 = py_td64_to_tdstruct(val, 9) # us
res2 = py_td64_to_tdstruct(val, NpyDatetimeUnit.NPY_FR_us.value)
exp2 = {
"days": 0,
"hrs": 3,
Expand All @@ -102,7 +103,7 @@ def test_td64_to_tdstruct():
}
assert res2 == exp2

res3 = py_td64_to_tdstruct(val, 8) # ms
res3 = py_td64_to_tdstruct(val, NpyDatetimeUnit.NPY_FR_ms.value)
exp3 = {
"days": 144,
"hrs": 3,
Expand All @@ -118,7 +119,7 @@ def test_td64_to_tdstruct():
assert res3 == exp3

# Note this out of bounds for nanosecond Timedelta
res4 = py_td64_to_tdstruct(val, 7) # s
res4 = py_td64_to_tdstruct(val, NpyDatetimeUnit.NPY_FR_s.value)
exp4 = {
"days": 144150,
"hrs": 21,
Expand Down