Skip to content

Commit d1b2c44

Browse files
authored
ENH/POC: infer resolution in array_strptime (#55778)
* ENH/POC: infer resolution in array_strptime * creso_changed->creso_ever_changed
1 parent 589d97d commit d1b2c44

File tree

3 files changed

+152
-15
lines changed

3 files changed

+152
-15
lines changed

pandas/_libs/tslibs/strptime.pxd

+3
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,8 @@ cdef class DatetimeParseState:
1414
cdef:
1515
bint found_tz
1616
bint found_naive
17+
bint creso_ever_changed
18+
NPY_DATETIMEUNIT creso
1719

1820
cdef tzinfo process_datetime(self, datetime dt, tzinfo tz, bint utc_convert)
21+
cdef bint update_creso(self, NPY_DATETIMEUNIT item_reso) noexcept

pandas/_libs/tslibs/strptime.pyx

+88-15
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ from numpy cimport (
4949

5050
from pandas._libs.missing cimport checknull_with_nat_and_na
5151
from pandas._libs.tslibs.conversion cimport get_datetime64_nanos
52+
from pandas._libs.tslibs.dtypes cimport (
53+
get_supported_reso,
54+
npy_unit_to_abbrev,
55+
)
5256
from pandas._libs.tslibs.nattype cimport (
5357
NPY_NAT,
5458
c_nat_strings as nat_strings,
@@ -57,6 +61,7 @@ from pandas._libs.tslibs.np_datetime cimport (
5761
NPY_DATETIMEUNIT,
5862
NPY_FR_ns,
5963
check_dts_bounds,
64+
get_datetime64_unit,
6065
import_pandas_datetime,
6166
npy_datetimestruct,
6267
npy_datetimestruct_to_datetime,
@@ -232,9 +237,21 @@ cdef _get_format_regex(str fmt):
232237

233238

234239
cdef class DatetimeParseState:
235-
def __cinit__(self):
240+
def __cinit__(self, NPY_DATETIMEUNIT creso=NPY_DATETIMEUNIT.NPY_FR_ns):
236241
self.found_tz = False
237242
self.found_naive = False
243+
self.creso = creso
244+
self.creso_ever_changed = False
245+
246+
cdef bint update_creso(self, NPY_DATETIMEUNIT item_reso) noexcept:
247+
# Return a bool indicating whether we bumped to a higher resolution
248+
if self.creso == NPY_DATETIMEUNIT.NPY_FR_GENERIC:
249+
self.creso = item_reso
250+
elif item_reso > self.creso:
251+
self.creso = item_reso
252+
self.creso_ever_changed = True
253+
return True
254+
return False
238255

239256
cdef tzinfo process_datetime(self, datetime dt, tzinfo tz, bint utc_convert):
240257
if dt.tzinfo is not None:
@@ -268,6 +285,7 @@ def array_strptime(
268285
bint exact=True,
269286
errors="raise",
270287
bint utc=False,
288+
NPY_DATETIMEUNIT creso=NPY_FR_ns,
271289
):
272290
"""
273291
Calculates the datetime structs represented by the passed array of strings
@@ -278,6 +296,8 @@ def array_strptime(
278296
fmt : string-like regex
279297
exact : matches must be exact if True, search if False
280298
errors : string specifying error handling, {'raise', 'ignore', 'coerce'}
299+
creso : NPY_DATETIMEUNIT, default NPY_FR_ns
300+
Set to NPY_FR_GENERIC to infer a resolution.
281301
"""
282302

283303
cdef:
@@ -291,17 +311,22 @@ def array_strptime(
291311
bint is_coerce = errors=="coerce"
292312
tzinfo tz_out = None
293313
bint iso_format = format_is_iso(fmt)
294-
NPY_DATETIMEUNIT out_bestunit
314+
NPY_DATETIMEUNIT out_bestunit, item_reso
295315
int out_local = 0, out_tzoffset = 0
296316
bint string_to_dts_succeeded = 0
297-
DatetimeParseState state = DatetimeParseState()
317+
bint infer_reso = creso == NPY_DATETIMEUNIT.NPY_FR_GENERIC
318+
DatetimeParseState state = DatetimeParseState(creso)
298319

299320
assert is_raise or is_ignore or is_coerce
300321

301322
_validate_fmt(fmt)
302323
format_regex, locale_time = _get_format_regex(fmt)
303324

304-
result = np.empty(n, dtype="M8[ns]")
325+
if infer_reso:
326+
abbrev = "ns"
327+
else:
328+
abbrev = npy_unit_to_abbrev(creso)
329+
result = np.empty(n, dtype=f"M8[{abbrev}]")
305330
iresult = result.view("i8")
306331
result_timezone = np.empty(n, dtype="object")
307332

@@ -318,20 +343,32 @@ def array_strptime(
318343
iresult[i] = NPY_NAT
319344
continue
320345
elif PyDateTime_Check(val):
346+
if isinstance(val, _Timestamp):
347+
item_reso = val._creso
348+
else:
349+
item_reso = NPY_DATETIMEUNIT.NPY_FR_us
350+
state.update_creso(item_reso)
321351
tz_out = state.process_datetime(val, tz_out, utc)
322352
if isinstance(val, _Timestamp):
323-
iresult[i] = val.tz_localize(None).as_unit("ns")._value
353+
val = (<_Timestamp>val)._as_creso(state.creso)
354+
iresult[i] = val.tz_localize(None)._value
324355
else:
325-
iresult[i] = pydatetime_to_dt64(val.replace(tzinfo=None), &dts)
326-
check_dts_bounds(&dts)
356+
iresult[i] = pydatetime_to_dt64(
357+
val.replace(tzinfo=None), &dts, reso=state.creso
358+
)
359+
check_dts_bounds(&dts, state.creso)
327360
result_timezone[i] = val.tzinfo
328361
continue
329362
elif PyDate_Check(val):
330-
iresult[i] = pydate_to_dt64(val, &dts)
331-
check_dts_bounds(&dts)
363+
item_reso = NPY_DATETIMEUNIT.NPY_FR_s
364+
state.update_creso(item_reso)
365+
iresult[i] = pydate_to_dt64(val, &dts, reso=state.creso)
366+
check_dts_bounds(&dts, state.creso)
332367
continue
333368
elif is_datetime64_object(val):
334-
iresult[i] = get_datetime64_nanos(val, NPY_FR_ns)
369+
item_reso = get_supported_reso(get_datetime64_unit(val))
370+
state.update_creso(item_reso)
371+
iresult[i] = get_datetime64_nanos(val, state.creso)
335372
continue
336373
elif (
337374
(is_integer_object(val) or is_float_object(val))
@@ -355,7 +392,9 @@ def array_strptime(
355392
if string_to_dts_succeeded:
356393
# No error reported by string_to_dts, pick back up
357394
# where we left off
358-
value = npy_datetimestruct_to_datetime(NPY_FR_ns, &dts)
395+
item_reso = get_supported_reso(out_bestunit)
396+
state.update_creso(item_reso)
397+
value = npy_datetimestruct_to_datetime(state.creso, &dts)
359398
if out_local == 1:
360399
# Store the out_tzoffset in seconds
361400
# since we store the total_seconds of
@@ -368,7 +407,9 @@ def array_strptime(
368407
check_dts_bounds(&dts)
369408
continue
370409

371-
if parse_today_now(val, &iresult[i], utc, NPY_FR_ns):
410+
if parse_today_now(val, &iresult[i], utc, state.creso):
411+
item_reso = NPY_DATETIMEUNIT.NPY_FR_us
412+
state.update_creso(item_reso)
372413
continue
373414

374415
# Some ISO formats can't be parsed by string_to_dts
@@ -380,9 +421,10 @@ def array_strptime(
380421
raise ValueError(f"Time data {val} is not ISO8601 format")
381422

382423
tz = _parse_with_format(
383-
val, fmt, exact, format_regex, locale_time, &dts
424+
val, fmt, exact, format_regex, locale_time, &dts, &item_reso
384425
)
385-
iresult[i] = npy_datetimestruct_to_datetime(NPY_FR_ns, &dts)
426+
state.update_creso(item_reso)
427+
iresult[i] = npy_datetimestruct_to_datetime(state.creso, &dts)
386428
check_dts_bounds(&dts)
387429
result_timezone[i] = tz
388430

@@ -403,11 +445,34 @@ def array_strptime(
403445
raise
404446
return values, []
405447

448+
if infer_reso:
449+
if state.creso_ever_changed:
450+
# We encountered mismatched resolutions, need to re-parse with
451+
# the correct one.
452+
return array_strptime(
453+
values,
454+
fmt=fmt,
455+
exact=exact,
456+
errors=errors,
457+
utc=utc,
458+
creso=state.creso,
459+
)
460+
461+
# Otherwise we can use the single reso that we encountered and avoid
462+
# a second pass.
463+
abbrev = npy_unit_to_abbrev(state.creso)
464+
result = iresult.base.view(f"M8[{abbrev}]")
406465
return result, result_timezone.base
407466

408467

409468
cdef tzinfo _parse_with_format(
410-
str val, str fmt, bint exact, format_regex, locale_time, npy_datetimestruct* dts
469+
str val,
470+
str fmt,
471+
bint exact,
472+
format_regex,
473+
locale_time,
474+
npy_datetimestruct* dts,
475+
NPY_DATETIMEUNIT* item_reso,
411476
):
412477
# Based on https://github.com/python/cpython/blob/main/Lib/_strptime.py#L293
413478
cdef:
@@ -441,6 +506,8 @@ cdef tzinfo _parse_with_format(
441506
f"time data \"{val}\" doesn't match format \"{fmt}\""
442507
)
443508

509+
item_reso[0] = NPY_DATETIMEUNIT.NPY_FR_s
510+
444511
iso_year = -1
445512
year = 1900
446513
month = day = 1
@@ -527,6 +594,12 @@ cdef tzinfo _parse_with_format(
527594
elif parse_code == 10:
528595
# e.g. val='10:10:10.100'; fmt='%H:%M:%S.%f'
529596
s = found_dict["f"]
597+
if len(s) <= 3:
598+
item_reso[0] = NPY_DATETIMEUNIT.NPY_FR_ms
599+
elif len(s) <= 6:
600+
item_reso[0] = NPY_DATETIMEUNIT.NPY_FR_us
601+
else:
602+
item_reso[0] = NPY_DATETIMEUNIT.NPY_FR_ns
530603
# Pad to always return nanoseconds
531604
s += "0" * (9 - len(s))
532605
us = long(s)

pandas/tests/tslibs/test_strptime.py

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from datetime import (
2+
datetime,
3+
timezone,
4+
)
5+
6+
import numpy as np
7+
import pytest
8+
9+
from pandas._libs.tslibs.dtypes import NpyDatetimeUnit
10+
from pandas._libs.tslibs.strptime import array_strptime
11+
12+
from pandas import Timestamp
13+
import pandas._testing as tm
14+
15+
creso_infer = NpyDatetimeUnit.NPY_FR_GENERIC.value
16+
17+
18+
class TestArrayStrptimeResolutionInference:
19+
@pytest.mark.parametrize("tz", [None, timezone.utc])
20+
def test_array_strptime_resolution_inference_homogeneous_strings(self, tz):
21+
dt = datetime(2016, 1, 2, 3, 4, 5, 678900, tzinfo=tz)
22+
23+
fmt = "%Y-%m-%d %H:%M:%S"
24+
dtstr = dt.strftime(fmt)
25+
arr = np.array([dtstr] * 3, dtype=object)
26+
expected = np.array([dt.replace(tzinfo=None)] * 3, dtype="M8[s]")
27+
28+
res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer)
29+
tm.assert_numpy_array_equal(res, expected)
30+
31+
fmt = "%Y-%m-%d %H:%M:%S.%f"
32+
dtstr = dt.strftime(fmt)
33+
arr = np.array([dtstr] * 3, dtype=object)
34+
expected = np.array([dt.replace(tzinfo=None)] * 3, dtype="M8[us]")
35+
36+
res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer)
37+
tm.assert_numpy_array_equal(res, expected)
38+
39+
fmt = "ISO8601"
40+
res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer)
41+
tm.assert_numpy_array_equal(res, expected)
42+
43+
@pytest.mark.parametrize("tz", [None, timezone.utc])
44+
def test_array_strptime_resolution_mixed(self, tz):
45+
dt = datetime(2016, 1, 2, 3, 4, 5, 678900, tzinfo=tz)
46+
47+
ts = Timestamp(dt).as_unit("ns")
48+
49+
arr = np.array([dt, ts], dtype=object)
50+
expected = np.array(
51+
[Timestamp(dt).as_unit("ns").asm8, ts.asm8],
52+
dtype="M8[ns]",
53+
)
54+
55+
fmt = "%Y-%m-%d %H:%M:%S"
56+
res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer)
57+
tm.assert_numpy_array_equal(res, expected)
58+
59+
fmt = "ISO8601"
60+
res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer)
61+
tm.assert_numpy_array_equal(res, expected)

0 commit comments

Comments
 (0)