Skip to content

Commit 698920f

Browse files
jbrockmendeljreback
authored andcommitted
REF: do all casting _before_ call to DatetimeEngine.get_loc (#30948)
1 parent 9b0ef5d commit 698920f

File tree

4 files changed

+86
-50
lines changed

4 files changed

+86
-50
lines changed

pandas/_libs/index.pyx

+29-23
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ cnp.import_array()
1717

1818
cimport pandas._libs.util as util
1919

20-
from pandas._libs.tslibs.conversion cimport maybe_datetimelike_to_i8
2120
from pandas._libs.tslibs.nattype cimport c_NaT as NaT
21+
from pandas._libs.tslibs.c_timestamp cimport _Timestamp
2222

2323
from pandas._libs.hashtable cimport HashTable
2424

@@ -407,20 +407,27 @@ cdef class DatetimeEngine(Int64Engine):
407407
cdef _get_box_dtype(self):
408408
return 'M8[ns]'
409409

410+
cdef int64_t _unbox_scalar(self, scalar) except? -1:
411+
# NB: caller is responsible for ensuring tzawareness compat
412+
# before we get here
413+
if not (isinstance(scalar, _Timestamp) or scalar is NaT):
414+
raise TypeError(scalar)
415+
return scalar.value
416+
410417
def __contains__(self, object val):
411418
cdef:
412-
int64_t loc
419+
int64_t loc, conv
413420

421+
conv = self._unbox_scalar(val)
414422
if self.over_size_threshold and self.is_monotonic_increasing:
415423
if not self.is_unique:
416-
return self._get_loc_duplicates(val)
424+
return self._get_loc_duplicates(conv)
417425
values = self._get_index_values()
418-
conv = maybe_datetimelike_to_i8(val)
419426
loc = values.searchsorted(conv, side='left')
420427
return values[loc] == conv
421428

422429
self._ensure_mapping_populated()
423-
return maybe_datetimelike_to_i8(val) in self.mapping
430+
return conv in self.mapping
424431

425432
cdef _get_index_values(self):
426433
return self.vgetter().view('i8')
@@ -429,45 +436,39 @@ cdef class DatetimeEngine(Int64Engine):
429436
return algos.is_monotonic(values, timelike=True)
430437

431438
cpdef get_loc(self, object val):
439+
# NB: the caller is responsible for ensuring that we are called
440+
# with either a Timestamp or NaT (Timedelta or NaT for TimedeltaEngine)
441+
432442
cdef:
433443
int64_t loc
434444
if is_definitely_invalid_key(val):
435445
raise TypeError
436446

447+
try:
448+
conv = self._unbox_scalar(val)
449+
except TypeError:
450+
raise KeyError(val)
451+
437452
# Welcome to the spaghetti factory
438453
if self.over_size_threshold and self.is_monotonic_increasing:
439454
if not self.is_unique:
440-
val = maybe_datetimelike_to_i8(val)
441-
return self._get_loc_duplicates(val)
455+
return self._get_loc_duplicates(conv)
442456
values = self._get_index_values()
443457

444-
try:
445-
conv = maybe_datetimelike_to_i8(val)
446-
loc = values.searchsorted(conv, side='left')
447-
except TypeError:
448-
raise KeyError(val)
458+
loc = values.searchsorted(conv, side='left')
449459

450460
if loc == len(values) or values[loc] != conv:
451461
raise KeyError(val)
452462
return loc
453463

454464
self._ensure_mapping_populated()
455465
if not self.unique:
456-
val = maybe_datetimelike_to_i8(val)
457-
return self._get_loc_duplicates(val)
466+
return self._get_loc_duplicates(conv)
458467

459468
try:
460-
return self.mapping.get_item(val.value)
469+
return self.mapping.get_item(conv)
461470
except KeyError:
462471
raise KeyError(val)
463-
except AttributeError:
464-
pass
465-
466-
try:
467-
val = maybe_datetimelike_to_i8(val)
468-
return self.mapping.get_item(val)
469-
except (TypeError, ValueError):
470-
raise KeyError(val)
471472

472473
def get_indexer(self, values):
473474
self._ensure_mapping_populated()
@@ -494,6 +495,11 @@ cdef class TimedeltaEngine(DatetimeEngine):
494495
cdef _get_box_dtype(self):
495496
return 'm8[ns]'
496497

498+
cdef int64_t _unbox_scalar(self, scalar) except? -1:
499+
if not (isinstance(scalar, Timedelta) or scalar is NaT):
500+
raise TypeError(scalar)
501+
return scalar.value
502+
497503

498504
cdef class PeriodEngine(Int64Engine):
499505

pandas/_libs/tslibs/conversion.pxd

-2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,4 @@ cdef int64_t get_datetime64_nanos(object val) except? -1
2525

2626
cpdef int64_t pydt_to_i8(object pydt) except? -1
2727

28-
cdef maybe_datetimelike_to_i8(object val)
29-
3028
cpdef datetime localize_pydatetime(datetime dt, object tz)

pandas/_libs/tslibs/conversion.pyx

-25
Original file line numberDiff line numberDiff line change
@@ -207,31 +207,6 @@ def datetime_to_datetime64(object[:] values):
207207
return result, inferred_tz
208208

209209

210-
cdef inline maybe_datetimelike_to_i8(object val):
211-
"""
212-
Try to convert to a nanosecond timestamp. Fall back to returning the
213-
input value.
214-
215-
Parameters
216-
----------
217-
val : object
218-
219-
Returns
220-
-------
221-
val : int64 timestamp or original input
222-
"""
223-
cdef:
224-
npy_datetimestruct dts
225-
try:
226-
return val.value
227-
except AttributeError:
228-
if is_datetime64_object(val):
229-
return get_datetime64_value(val)
230-
elif PyDateTime_Check(val):
231-
return convert_datetime_to_tsobject(val, None).value
232-
return val
233-
234-
235210
# ----------------------------------------------------------------------
236211
# _TSObject Conversion
237212

pandas/tests/indexes/test_engines.py

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import re
2+
3+
import pytest
4+
5+
import pandas as pd
6+
7+
8+
class TestDatetimeEngine:
9+
@pytest.mark.parametrize(
10+
"scalar",
11+
[
12+
pd.Timedelta(pd.Timestamp("2016-01-01").asm8.view("m8[ns]")),
13+
pd.Timestamp("2016-01-01").value,
14+
pd.Timestamp("2016-01-01").to_pydatetime(),
15+
pd.Timestamp("2016-01-01").to_datetime64(),
16+
],
17+
)
18+
def test_not_contains_requires_timestamp(self, scalar):
19+
dti1 = pd.date_range("2016-01-01", periods=3)
20+
dti2 = dti1.insert(1, pd.NaT) # non-monotonic
21+
dti3 = dti1.insert(3, dti1[0]) # non-unique
22+
dti4 = pd.date_range("2016-01-01", freq="ns", periods=2_000_000)
23+
dti5 = dti4.insert(0, dti4[0]) # over size threshold, not unique
24+
25+
msg = "|".join([re.escape(str(scalar)), re.escape(repr(scalar))])
26+
for dti in [dti1, dti2, dti3, dti4, dti5]:
27+
with pytest.raises(TypeError, match=msg):
28+
scalar in dti._engine
29+
30+
with pytest.raises(KeyError, match=msg):
31+
dti._engine.get_loc(scalar)
32+
33+
34+
class TestTimedeltaEngine:
35+
@pytest.mark.parametrize(
36+
"scalar",
37+
[
38+
pd.Timestamp(pd.Timedelta(days=42).asm8.view("datetime64[ns]")),
39+
pd.Timedelta(days=42).value,
40+
pd.Timedelta(days=42).to_pytimedelta(),
41+
pd.Timedelta(days=42).to_timedelta64(),
42+
],
43+
)
44+
def test_not_contains_requires_timestamp(self, scalar):
45+
tdi1 = pd.timedelta_range("42 days", freq="9h", periods=1234)
46+
tdi2 = tdi1.insert(1, pd.NaT) # non-monotonic
47+
tdi3 = tdi1.insert(3, tdi1[0]) # non-unique
48+
tdi4 = pd.timedelta_range("42 days", freq="ns", periods=2_000_000)
49+
tdi5 = tdi4.insert(0, tdi4[0]) # over size threshold, not unique
50+
51+
msg = "|".join([re.escape(str(scalar)), re.escape(repr(scalar))])
52+
for tdi in [tdi1, tdi2, tdi3, tdi4, tdi5]:
53+
with pytest.raises(TypeError, match=msg):
54+
scalar in tdi._engine
55+
56+
with pytest.raises(KeyError, match=msg):
57+
tdi._engine.get_loc(scalar)

0 commit comments

Comments
 (0)