Skip to content

Commit d99bce5

Browse files
jbrockmendelyehoshuadimarsky
authored andcommitted
ENH: DatetimeArray fields support non-nano (pandas-dev#47044)
1 parent ecfd216 commit d99bce5

File tree

4 files changed

+48
-21
lines changed

4 files changed

+48
-21
lines changed

pandas/_libs/tslibs/fields.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def get_start_end_field(
2222
def get_date_field(
2323
dtindex: npt.NDArray[np.int64], # const int64_t[:]
2424
field: str,
25+
reso: int = ..., # NPY_DATETIMEUNIT
2526
) -> npt.NDArray[np.int32]: ...
2627
def get_timedelta_field(
2728
tdindex: npt.NDArray[np.int64], # const int64_t[:]
@@ -32,6 +33,7 @@ def isleapyear_arr(
3233
) -> npt.NDArray[np.bool_]: ...
3334
def build_isocalendar_sarray(
3435
dtindex: npt.NDArray[np.int64], # const int64_t[:]
36+
reso: int = ..., # NPY_DATETIMEUNIT
3537
) -> np.ndarray: ...
3638
def _get_locale_names(name_type: str, locale: str | None = ...): ...
3739

pandas/_libs/tslibs/fields.pyx

+18-17
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def get_start_end_field(
329329

330330
@cython.wraparound(False)
331331
@cython.boundscheck(False)
332-
def get_date_field(const int64_t[:] dtindex, str field):
332+
def get_date_field(const int64_t[:] dtindex, str field, NPY_DATETIMEUNIT reso=NPY_FR_ns):
333333
"""
334334
Given a int64-based datetime index, extract the year, month, etc.,
335335
field and return an array of these values.
@@ -348,7 +348,7 @@ def get_date_field(const int64_t[:] dtindex, str field):
348348
out[i] = -1
349349
continue
350350

351-
dt64_to_dtstruct(dtindex[i], &dts)
351+
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
352352
out[i] = dts.year
353353
return out
354354

@@ -359,7 +359,7 @@ def get_date_field(const int64_t[:] dtindex, str field):
359359
out[i] = -1
360360
continue
361361

362-
dt64_to_dtstruct(dtindex[i], &dts)
362+
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
363363
out[i] = dts.month
364364
return out
365365

@@ -370,7 +370,7 @@ def get_date_field(const int64_t[:] dtindex, str field):
370370
out[i] = -1
371371
continue
372372

373-
dt64_to_dtstruct(dtindex[i], &dts)
373+
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
374374
out[i] = dts.day
375375
return out
376376

@@ -381,8 +381,9 @@ def get_date_field(const int64_t[:] dtindex, str field):
381381
out[i] = -1
382382
continue
383383

384-
dt64_to_dtstruct(dtindex[i], &dts)
384+
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
385385
out[i] = dts.hour
386+
# TODO: can we de-dup with period.pyx <accessor>s?
386387
return out
387388

388389
elif field == 'm':
@@ -392,7 +393,7 @@ def get_date_field(const int64_t[:] dtindex, str field):
392393
out[i] = -1
393394
continue
394395

395-
dt64_to_dtstruct(dtindex[i], &dts)
396+
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
396397
out[i] = dts.min
397398
return out
398399

@@ -403,7 +404,7 @@ def get_date_field(const int64_t[:] dtindex, str field):
403404
out[i] = -1
404405
continue
405406

406-
dt64_to_dtstruct(dtindex[i], &dts)
407+
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
407408
out[i] = dts.sec
408409
return out
409410

@@ -414,7 +415,7 @@ def get_date_field(const int64_t[:] dtindex, str field):
414415
out[i] = -1
415416
continue
416417

417-
dt64_to_dtstruct(dtindex[i], &dts)
418+
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
418419
out[i] = dts.us
419420
return out
420421

@@ -425,7 +426,7 @@ def get_date_field(const int64_t[:] dtindex, str field):
425426
out[i] = -1
426427
continue
427428

428-
dt64_to_dtstruct(dtindex[i], &dts)
429+
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
429430
out[i] = dts.ps // 1000
430431
return out
431432
elif field == 'doy':
@@ -435,7 +436,7 @@ def get_date_field(const int64_t[:] dtindex, str field):
435436
out[i] = -1
436437
continue
437438

438-
dt64_to_dtstruct(dtindex[i], &dts)
439+
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
439440
out[i] = get_day_of_year(dts.year, dts.month, dts.day)
440441
return out
441442

@@ -446,7 +447,7 @@ def get_date_field(const int64_t[:] dtindex, str field):
446447
out[i] = -1
447448
continue
448449

449-
dt64_to_dtstruct(dtindex[i], &dts)
450+
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
450451
out[i] = dayofweek(dts.year, dts.month, dts.day)
451452
return out
452453

@@ -457,7 +458,7 @@ def get_date_field(const int64_t[:] dtindex, str field):
457458
out[i] = -1
458459
continue
459460

460-
dt64_to_dtstruct(dtindex[i], &dts)
461+
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
461462
out[i] = get_week_of_year(dts.year, dts.month, dts.day)
462463
return out
463464

@@ -468,7 +469,7 @@ def get_date_field(const int64_t[:] dtindex, str field):
468469
out[i] = -1
469470
continue
470471

471-
dt64_to_dtstruct(dtindex[i], &dts)
472+
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
472473
out[i] = dts.month
473474
out[i] = ((out[i] - 1) // 3) + 1
474475
return out
@@ -480,11 +481,11 @@ def get_date_field(const int64_t[:] dtindex, str field):
480481
out[i] = -1
481482
continue
482483

483-
dt64_to_dtstruct(dtindex[i], &dts)
484+
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
484485
out[i] = get_days_in_month(dts.year, dts.month)
485486
return out
486487
elif field == 'is_leap_year':
487-
return isleapyear_arr(get_date_field(dtindex, 'Y'))
488+
return isleapyear_arr(get_date_field(dtindex, 'Y', reso=reso))
488489

489490
raise ValueError(f"Field {field} not supported")
490491

@@ -564,7 +565,7 @@ cpdef isleapyear_arr(ndarray years):
564565

565566
@cython.wraparound(False)
566567
@cython.boundscheck(False)
567-
def build_isocalendar_sarray(const int64_t[:] dtindex):
568+
def build_isocalendar_sarray(const int64_t[:] dtindex, NPY_DATETIMEUNIT reso=NPY_FR_ns):
568569
"""
569570
Given a int64-based datetime array, return the ISO 8601 year, week, and day
570571
as a structured array.
@@ -592,7 +593,7 @@ def build_isocalendar_sarray(const int64_t[:] dtindex):
592593
if dtindex[i] == NPY_NAT:
593594
ret_val = 0, 0, 0
594595
else:
595-
dt64_to_dtstruct(dtindex[i], &dts)
596+
pandas_datetime_to_datetimestruct(dtindex[i], reso, &dts)
596597
ret_val = get_iso_calendar(dts.year, dts.month, dts.day)
597598

598599
iso_years[i] = ret_val[0]

pandas/core/arrays/datetimes.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def f(self):
136136
values, field, self.freqstr, month_kw, reso=self._reso
137137
)
138138
else:
139-
result = fields.get_date_field(values, field)
139+
result = fields.get_date_field(values, field, reso=self._reso)
140140

141141
# these return a boolean by-definition
142142
return result
@@ -146,7 +146,7 @@ def f(self):
146146
result = self._maybe_mask_results(result, fill_value=None)
147147

148148
else:
149-
result = fields.get_date_field(values, field)
149+
result = fields.get_date_field(values, field, reso=self._reso)
150150
result = self._maybe_mask_results(
151151
result, fill_value=None, convert="float64"
152152
)
@@ -1403,7 +1403,7 @@ def isocalendar(self) -> DataFrame:
14031403
from pandas import DataFrame
14041404

14051405
values = self._local_timestamps()
1406-
sarray = fields.build_isocalendar_sarray(values)
1406+
sarray = fields.build_isocalendar_sarray(values, reso=self._reso)
14071407
iso_calendar_df = DataFrame(
14081408
sarray, columns=["year", "week", "day"], dtype="UInt32"
14091409
)

pandas/tests/arrays/test_datetimes.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,15 @@
1212

1313

1414
class TestNonNano:
15-
@pytest.mark.parametrize("unit,reso", [("s", 7), ("ms", 8), ("us", 9)])
15+
@pytest.fixture(params=["s", "ms", "us"])
16+
def unit(self, request):
17+
return request.param
18+
19+
@pytest.fixture
20+
def reso(self, unit):
21+
# TODO: avoid hard-coding
22+
return {"s": 7, "ms": 8, "us": 9}[unit]
23+
1624
@pytest.mark.xfail(reason="_box_func is not yet patched to get reso right")
1725
def test_non_nano(self, unit, reso):
1826
arr = np.arange(5, dtype=np.int64).view(f"M8[{unit}]")
@@ -21,6 +29,22 @@ def test_non_nano(self, unit, reso):
2129
assert dta.dtype == arr.dtype
2230
assert dta[0]._reso == reso
2331

32+
@pytest.mark.filterwarnings(
33+
"ignore:weekofyear and week have been deprecated:FutureWarning"
34+
)
35+
@pytest.mark.parametrize(
36+
"field", DatetimeArray._field_ops + DatetimeArray._bool_ops
37+
)
38+
def test_fields(self, unit, reso, field):
39+
dti = pd.date_range("2016-01-01", periods=55, freq="D")
40+
arr = np.asarray(dti).astype(f"M8[{unit}]")
41+
42+
dta = DatetimeArray._simple_new(arr, dtype=arr.dtype)
43+
44+
res = getattr(dta, field)
45+
expected = getattr(dti._data, field)
46+
tm.assert_numpy_array_equal(res, expected)
47+
2448

2549
class TestDatetimeArrayComparisons:
2650
# TODO: merge this into tests/arithmetic/test_datetime64 once it is

0 commit comments

Comments
 (0)