Skip to content

ENH: DatetimeArray fields support non-nano #47044

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pandas/_libs/tslibs/fields.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def get_start_end_field(
def get_date_field(
dtindex: npt.NDArray[np.int64], # const int64_t[:]
field: str,
reso: int = ..., # NPY_DATETIMEUNIT
) -> npt.NDArray[np.int32]: ...
def get_timedelta_field(
tdindex: npt.NDArray[np.int64], # const int64_t[:]
Expand All @@ -32,6 +33,7 @@ def isleapyear_arr(
) -> npt.NDArray[np.bool_]: ...
def build_isocalendar_sarray(
dtindex: npt.NDArray[np.int64], # const int64_t[:]
reso: int = ..., # NPY_DATETIMEUNIT
) -> np.ndarray: ...
def _get_locale_names(name_type: str, locale: str | None = ...): ...

Expand Down
35 changes: 18 additions & 17 deletions pandas/_libs/tslibs/fields.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def get_start_end_field(

@cython.wraparound(False)
@cython.boundscheck(False)
def get_date_field(const int64_t[:] dtindex, str field):
def get_date_field(const int64_t[:] dtindex, str field, NPY_DATETIMEUNIT reso=NPY_FR_ns):
"""
Given a int64-based datetime index, extract the year, month, etc.,
field and return an array of these values.
Expand All @@ -348,7 +348,7 @@ def get_date_field(const int64_t[:] dtindex, str field):
out[i] = -1
continue

dt64_to_dtstruct(dtindex[i], &dts)
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
out[i] = dts.year
return out

Expand All @@ -359,7 +359,7 @@ def get_date_field(const int64_t[:] dtindex, str field):
out[i] = -1
continue

dt64_to_dtstruct(dtindex[i], &dts)
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
out[i] = dts.month
return out

Expand All @@ -370,7 +370,7 @@ def get_date_field(const int64_t[:] dtindex, str field):
out[i] = -1
continue

dt64_to_dtstruct(dtindex[i], &dts)
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
out[i] = dts.day
return out

Expand All @@ -381,8 +381,9 @@ def get_date_field(const int64_t[:] dtindex, str field):
out[i] = -1
continue

dt64_to_dtstruct(dtindex[i], &dts)
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
out[i] = dts.hour
# TODO: can we de-dup with period.pyx <accessor>s?
return out

elif field == 'm':
Expand All @@ -392,7 +393,7 @@ def get_date_field(const int64_t[:] dtindex, str field):
out[i] = -1
continue

dt64_to_dtstruct(dtindex[i], &dts)
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
out[i] = dts.min
return out

Expand All @@ -403,7 +404,7 @@ def get_date_field(const int64_t[:] dtindex, str field):
out[i] = -1
continue

dt64_to_dtstruct(dtindex[i], &dts)
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
out[i] = dts.sec
return out

Expand All @@ -414,7 +415,7 @@ def get_date_field(const int64_t[:] dtindex, str field):
out[i] = -1
continue

dt64_to_dtstruct(dtindex[i], &dts)
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
out[i] = dts.us
return out

Expand All @@ -425,7 +426,7 @@ def get_date_field(const int64_t[:] dtindex, str field):
out[i] = -1
continue

dt64_to_dtstruct(dtindex[i], &dts)
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
out[i] = dts.ps // 1000
return out
elif field == 'doy':
Expand All @@ -435,7 +436,7 @@ def get_date_field(const int64_t[:] dtindex, str field):
out[i] = -1
continue

dt64_to_dtstruct(dtindex[i], &dts)
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
out[i] = get_day_of_year(dts.year, dts.month, dts.day)
return out

Expand All @@ -446,7 +447,7 @@ def get_date_field(const int64_t[:] dtindex, str field):
out[i] = -1
continue

dt64_to_dtstruct(dtindex[i], &dts)
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
out[i] = dayofweek(dts.year, dts.month, dts.day)
return out

Expand All @@ -457,7 +458,7 @@ def get_date_field(const int64_t[:] dtindex, str field):
out[i] = -1
continue

dt64_to_dtstruct(dtindex[i], &dts)
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
out[i] = get_week_of_year(dts.year, dts.month, dts.day)
return out

Expand All @@ -468,7 +469,7 @@ def get_date_field(const int64_t[:] dtindex, str field):
out[i] = -1
continue

dt64_to_dtstruct(dtindex[i], &dts)
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
out[i] = dts.month
out[i] = ((out[i] - 1) // 3) + 1
return out
Expand All @@ -480,11 +481,11 @@ def get_date_field(const int64_t[:] dtindex, str field):
out[i] = -1
continue

dt64_to_dtstruct(dtindex[i], &dts)
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
out[i] = get_days_in_month(dts.year, dts.month)
return out
elif field == 'is_leap_year':
return isleapyear_arr(get_date_field(dtindex, 'Y'))
return isleapyear_arr(get_date_field(dtindex, 'Y', reso=reso))

raise ValueError(f"Field {field} not supported")

Expand Down Expand Up @@ -564,7 +565,7 @@ cpdef isleapyear_arr(ndarray years):

@cython.wraparound(False)
@cython.boundscheck(False)
def build_isocalendar_sarray(const int64_t[:] dtindex):
def build_isocalendar_sarray(const int64_t[:] dtindex, NPY_DATETIMEUNIT reso=NPY_FR_ns):
"""
Given a int64-based datetime array, return the ISO 8601 year, week, and day
as a structured array.
Expand Down Expand Up @@ -592,7 +593,7 @@ def build_isocalendar_sarray(const int64_t[:] dtindex):
if dtindex[i] == NPY_NAT:
ret_val = 0, 0, 0
else:
dt64_to_dtstruct(dtindex[i], &dts)
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
ret_val = get_iso_calendar(dts.year, dts.month, dts.day)

iso_years[i] = ret_val[0]
Expand Down
6 changes: 3 additions & 3 deletions pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def f(self):
values, field, self.freqstr, month_kw, reso=self._reso
)
else:
result = fields.get_date_field(values, field)
result = fields.get_date_field(values, field, reso=self._reso)

# these return a boolean by-definition
return result
Expand All @@ -146,7 +146,7 @@ def f(self):
result = self._maybe_mask_results(result, fill_value=None)

else:
result = fields.get_date_field(values, field)
result = fields.get_date_field(values, field, reso=self._reso)
result = self._maybe_mask_results(
result, fill_value=None, convert="float64"
)
Expand Down Expand Up @@ -1403,7 +1403,7 @@ def isocalendar(self) -> DataFrame:
from pandas import DataFrame

values = self._local_timestamps()
sarray = fields.build_isocalendar_sarray(values)
sarray = fields.build_isocalendar_sarray(values, reso=self._reso)
iso_calendar_df = DataFrame(
sarray, columns=["year", "week", "day"], dtype="UInt32"
)
Expand Down
26 changes: 25 additions & 1 deletion pandas/tests/arrays/test_datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,15 @@


class TestNonNano:
@pytest.mark.parametrize("unit,reso", [("s", 7), ("ms", 8), ("us", 9)])
@pytest.fixture(params=["s", "ms", "us"])
def unit(self, request):
return request.param

@pytest.fixture
def reso(self, unit):
# TODO: avoid hard-coding
return {"s": 7, "ms": 8, "us": 9}[unit]

@pytest.mark.xfail(reason="_box_func is not yet patched to get reso right")
def test_non_nano(self, unit, reso):
arr = np.arange(5, dtype=np.int64).view(f"M8[{unit}]")
Expand All @@ -21,6 +29,22 @@ def test_non_nano(self, unit, reso):
assert dta.dtype == arr.dtype
assert dta[0]._reso == reso

@pytest.mark.filterwarnings(
"ignore:weekofyear and week have been deprecated:FutureWarning"
)
@pytest.mark.parametrize(
"field", DatetimeArray._field_ops + DatetimeArray._bool_ops
)
def test_fields(self, unit, reso, field):
dti = pd.date_range("2016-01-01", periods=55, freq="D")
arr = np.asarray(dti).astype(f"M8[{unit}]")

dta = DatetimeArray._simple_new(arr, dtype=arr.dtype)

res = getattr(dta, field)
expected = getattr(dti._data, field)
tm.assert_numpy_array_equal(res, expected)


class TestDatetimeArrayComparisons:
# TODO: merge this into tests/arithmetic/test_datetime64 once it is
Expand Down