From f1e9c50ad0676e0ac97d4d66f781ed9cf2f02a5b Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 17 May 2022 15:12:21 -0700 Subject: [PATCH] ENH: DatetimeArray fields support non-nano --- pandas/_libs/tslibs/fields.pyi | 2 ++ pandas/_libs/tslibs/fields.pyx | 35 ++++++++++++++------------- pandas/core/arrays/datetimes.py | 6 ++--- pandas/tests/arrays/test_datetimes.py | 26 +++++++++++++++++++- 4 files changed, 48 insertions(+), 21 deletions(-) diff --git a/pandas/_libs/tslibs/fields.pyi b/pandas/_libs/tslibs/fields.pyi index e404eadf13657..b1d9e0342f81e 100644 --- a/pandas/_libs/tslibs/fields.pyi +++ b/pandas/_libs/tslibs/fields.pyi @@ -22,6 +22,7 @@ def get_start_end_field( def get_date_field( dtindex: npt.NDArray[np.int64], # const int64_t[:] field: str, + reso: int = ..., # NPY_DATETIMEUNIT ) -> npt.NDArray[np.int32]: ... def get_timedelta_field( tdindex: npt.NDArray[np.int64], # const int64_t[:] @@ -32,6 +33,7 @@ def isleapyear_arr( ) -> npt.NDArray[np.bool_]: ... def build_isocalendar_sarray( dtindex: npt.NDArray[np.int64], # const int64_t[:] + reso: int = ..., # NPY_DATETIMEUNIT ) -> np.ndarray: ... def _get_locale_names(name_type: str, locale: str | None = ...): ... diff --git a/pandas/_libs/tslibs/fields.pyx b/pandas/_libs/tslibs/fields.pyx index 57d4c27b3337d..5865b8c6877b0 100644 --- a/pandas/_libs/tslibs/fields.pyx +++ b/pandas/_libs/tslibs/fields.pyx @@ -329,7 +329,7 @@ def get_start_end_field( @cython.wraparound(False) @cython.boundscheck(False) -def get_date_field(const int64_t[:] dtindex, str field): +def get_date_field(const int64_t[:] dtindex, str field, NPY_DATETIMEUNIT reso=NPY_FR_ns): """ Given a int64-based datetime index, extract the year, month, etc., field and return an array of these values. @@ -348,7 +348,7 @@ def get_date_field(const int64_t[:] dtindex, str field): out[i] = -1 continue - dt64_to_dtstruct(dtindex[i], &dts) + pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts) out[i] = dts.year return out @@ -359,7 +359,7 @@ def get_date_field(const int64_t[:] dtindex, str field): out[i] = -1 continue - dt64_to_dtstruct(dtindex[i], &dts) + pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts) out[i] = dts.month return out @@ -370,7 +370,7 @@ def get_date_field(const int64_t[:] dtindex, str field): out[i] = -1 continue - dt64_to_dtstruct(dtindex[i], &dts) + pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts) out[i] = dts.day return out @@ -381,8 +381,9 @@ def get_date_field(const int64_t[:] dtindex, str field): out[i] = -1 continue - dt64_to_dtstruct(dtindex[i], &dts) + pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts) out[i] = dts.hour + # TODO: can we de-dup with period.pyx s? return out elif field == 'm': @@ -392,7 +393,7 @@ def get_date_field(const int64_t[:] dtindex, str field): out[i] = -1 continue - dt64_to_dtstruct(dtindex[i], &dts) + pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts) out[i] = dts.min return out @@ -403,7 +404,7 @@ def get_date_field(const int64_t[:] dtindex, str field): out[i] = -1 continue - dt64_to_dtstruct(dtindex[i], &dts) + pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts) out[i] = dts.sec return out @@ -414,7 +415,7 @@ def get_date_field(const int64_t[:] dtindex, str field): out[i] = -1 continue - dt64_to_dtstruct(dtindex[i], &dts) + pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts) out[i] = dts.us return out @@ -425,7 +426,7 @@ def get_date_field(const int64_t[:] dtindex, str field): out[i] = -1 continue - dt64_to_dtstruct(dtindex[i], &dts) + pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts) out[i] = dts.ps // 1000 return out elif field == 'doy': @@ -435,7 +436,7 @@ def get_date_field(const int64_t[:] dtindex, str field): out[i] = -1 continue - dt64_to_dtstruct(dtindex[i], &dts) + pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts) out[i] = get_day_of_year(dts.year, dts.month, dts.day) return out @@ -446,7 +447,7 @@ def get_date_field(const int64_t[:] dtindex, str field): out[i] = -1 continue - dt64_to_dtstruct(dtindex[i], &dts) + pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts) out[i] = dayofweek(dts.year, dts.month, dts.day) return out @@ -457,7 +458,7 @@ def get_date_field(const int64_t[:] dtindex, str field): out[i] = -1 continue - dt64_to_dtstruct(dtindex[i], &dts) + pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts) out[i] = get_week_of_year(dts.year, dts.month, dts.day) return out @@ -468,7 +469,7 @@ def get_date_field(const int64_t[:] dtindex, str field): out[i] = -1 continue - dt64_to_dtstruct(dtindex[i], &dts) + pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts) out[i] = dts.month out[i] = ((out[i] - 1) // 3) + 1 return out @@ -480,11 +481,11 @@ def get_date_field(const int64_t[:] dtindex, str field): out[i] = -1 continue - dt64_to_dtstruct(dtindex[i], &dts) + pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts) out[i] = get_days_in_month(dts.year, dts.month) return out elif field == 'is_leap_year': - return isleapyear_arr(get_date_field(dtindex, 'Y')) + return isleapyear_arr(get_date_field(dtindex, 'Y', reso=reso)) raise ValueError(f"Field {field} not supported") @@ -564,7 +565,7 @@ cpdef isleapyear_arr(ndarray years): @cython.wraparound(False) @cython.boundscheck(False) -def build_isocalendar_sarray(const int64_t[:] dtindex): +def build_isocalendar_sarray(const int64_t[:] dtindex, NPY_DATETIMEUNIT reso=NPY_FR_ns): """ Given a int64-based datetime array, return the ISO 8601 year, week, and day as a structured array. @@ -592,7 +593,7 @@ def build_isocalendar_sarray(const int64_t[:] dtindex): if dtindex[i] == NPY_NAT: ret_val = 0, 0, 0 else: - dt64_to_dtstruct(dtindex[i], &dts) + pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts) ret_val = get_iso_calendar(dts.year, dts.month, dts.day) iso_years[i] = ret_val[0] diff --git a/pandas/core/arrays/datetimes.py b/pandas/core/arrays/datetimes.py index 6f984727f4f6d..dadfad394b903 100644 --- a/pandas/core/arrays/datetimes.py +++ b/pandas/core/arrays/datetimes.py @@ -136,7 +136,7 @@ def f(self): values, field, self.freqstr, month_kw, reso=self._reso ) else: - result = fields.get_date_field(values, field) + result = fields.get_date_field(values, field, reso=self._reso) # these return a boolean by-definition return result @@ -146,7 +146,7 @@ def f(self): result = self._maybe_mask_results(result, fill_value=None) else: - result = fields.get_date_field(values, field) + result = fields.get_date_field(values, field, reso=self._reso) result = self._maybe_mask_results( result, fill_value=None, convert="float64" ) @@ -1403,7 +1403,7 @@ def isocalendar(self) -> DataFrame: from pandas import DataFrame values = self._local_timestamps() - sarray = fields.build_isocalendar_sarray(values) + sarray = fields.build_isocalendar_sarray(values, reso=self._reso) iso_calendar_df = DataFrame( sarray, columns=["year", "week", "day"], dtype="UInt32" ) diff --git a/pandas/tests/arrays/test_datetimes.py b/pandas/tests/arrays/test_datetimes.py index 8eb5cc2dd82f6..897528cf18122 100644 --- a/pandas/tests/arrays/test_datetimes.py +++ b/pandas/tests/arrays/test_datetimes.py @@ -12,7 +12,15 @@ class TestNonNano: - @pytest.mark.parametrize("unit,reso", [("s", 7), ("ms", 8), ("us", 9)]) + @pytest.fixture(params=["s", "ms", "us"]) + def unit(self, request): + return request.param + + @pytest.fixture + def reso(self, unit): + # TODO: avoid hard-coding + return {"s": 7, "ms": 8, "us": 9}[unit] + @pytest.mark.xfail(reason="_box_func is not yet patched to get reso right") def test_non_nano(self, unit, reso): arr = np.arange(5, dtype=np.int64).view(f"M8[{unit}]") @@ -21,6 +29,22 @@ def test_non_nano(self, unit, reso): assert dta.dtype == arr.dtype assert dta[0]._reso == reso + @pytest.mark.filterwarnings( + "ignore:weekofyear and week have been deprecated:FutureWarning" + ) + @pytest.mark.parametrize( + "field", DatetimeArray._field_ops + DatetimeArray._bool_ops + ) + def test_fields(self, unit, reso, field): + dti = pd.date_range("2016-01-01", periods=55, freq="D") + arr = np.asarray(dti).astype(f"M8[{unit}]") + + dta = DatetimeArray._simple_new(arr, dtype=arr.dtype) + + res = getattr(dta, field) + expected = getattr(dti._data, field) + tm.assert_numpy_array_equal(res, expected) + class TestDatetimeArrayComparisons: # TODO: merge this into tests/arithmetic/test_datetime64 once it is