Skip to content

Commit 070e1fe

Browse files
authored
BUG: DTI/TDI.get_loc with mismatched-reso scalars (#48815)
* BUG: DTI/TDI.get_loc with mismatched-reso scalars
* mypy fixup
1 parent 8b0ad71 commit 070e1fe

File tree

4 files changed

+85
-9
lines changed

4 files changed

+85
-9
lines changed

pandas/_libs/index.pyx

+38 -9
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,10 @@ cnp.import_array()
1717
from pandas._libs cimport util
1818
from pandas._libs.hashtable cimport HashTable
1919
from pandas._libs.tslibs.nattype cimport c_NaT as NaT
20+
from pandas._libs.tslibs.np_datetime cimport (
21+
NPY_DATETIMEUNIT,
22+
get_unit_from_dtype,
23+
)
2024
from pandas._libs.tslibs.period cimport is_period_object
2125
from pandas._libs.tslibs.timedeltas cimport _Timedelta
2226
from pandas._libs.tslibs.timestamps cimport _Timestamp
@@ -485,17 +489,35 @@ cdef class ObjectEngine(IndexEngine):
485489

486490
cdef class DatetimeEngine(Int64Engine):
487491

492+
cdef:
493+
NPY_DATETIMEUNIT reso
494+
495+
def __init__(self, ndarray values):
496+
super().__init__(values.view("i8"))
497+
self.reso = get_unit_from_dtype(values.dtype)
498+
488499
cdef int64_t _unbox_scalar(self, scalar) except? -1:
489500
# NB: caller is responsible for ensuring tzawareness compat
490501
# before we get here
491-
if not (isinstance(scalar, _Timestamp) or scalar is NaT):
492-
raise TypeError(scalar)
493-
return scalar.value
502+
if scalar is NaT:
503+
return NaT.value
504+
elif isinstance(scalar, _Timestamp):
505+
if scalar._reso == self.reso:
506+
return scalar.value
507+
else:
508+
# Note: caller is responsible for catching potential ValueError
509+
# from _as_reso
510+
return (<_Timestamp>scalar)._as_reso(self.reso, round_ok=False).value
511+
raise TypeError(scalar)
494512

495513
def __contains__(self, val: object) -> bool:
496514
# We assume before we get here:
497515
# - val is hashable
498-
self._unbox_scalar(val)
516+
try:
517+
self._unbox_scalar(val)
518+
except ValueError:
519+
return False
520+
499521
try:
500522
self.get_loc(val)
501523
return True
@@ -517,8 +539,8 @@ cdef class DatetimeEngine(Int64Engine):
517539

518540
try:
519541
conv = self._unbox_scalar(val)
520-
except TypeError:
521-
raise KeyError(val)
542+
except (TypeError, ValueError) as err:
543+
raise KeyError(val) from err
522544

523545
# Welcome to the spaghetti factory
524546
if self.over_size_threshold and self.is_monotonic_increasing:
@@ -545,9 +567,16 @@ cdef class DatetimeEngine(Int64Engine):
545567
cdef class TimedeltaEngine(DatetimeEngine):
546568

547569
cdef int64_t _unbox_scalar(self, scalar) except? -1:
548-
if not (isinstance(scalar, _Timedelta) or scalar is NaT):
549-
raise TypeError(scalar)
550-
return scalar.value
570+
if scalar is NaT:
571+
return NaT.value
572+
elif isinstance(scalar, _Timedelta):
573+
if scalar._reso == self.reso:
574+
return scalar.value
575+
else:
576+
# Note: caller is responsible for catching potential ValueError
577+
# from _as_reso
578+
return (<_Timedelta>scalar)._as_reso(self.reso, round_ok=False).value
579+
raise TypeError(scalar)
551580

552581

553582
cdef class PeriodEngine(Int64Engine):

pandas/core/indexes/base.py

+7
Original file line number | Diff line number | Diff line change
@@ -911,6 +911,13 @@ def _engine(
911911
return libindex.Complex64Engine(target_values)
912912
elif target_values.dtype == np.complex128:
913913
return libindex.Complex128Engine(target_values)
914+
elif needs_i8_conversion(self.dtype):
915+
# We need to keep M8/m8 dtype when initializing the Engine,
916+
# but don't want to change _get_engine_target bc it is used
917+
# elsewhere
918+
# error: Item "ExtensionArray" of "Union[ExtensionArray,
919+
# ndarray[Any, Any]]" has no attribute "_ndarray" [union-attr]
920+
target_values = self._data._ndarray # type: ignore[union-attr]
914921

915922
# error: Argument 1 to "ExtensionEngine" has incompatible type
916923
# "ndarray[Any, Any]"; expected "ExtensionArray"

pandas/tests/indexes/datetimes/test_indexing.py

+19
Original file line number | Diff line number | Diff line change
@@ -388,6 +388,25 @@ def test_take_fill_value_with_timezone(self):
388388

389389

390390
class TestGetLoc:
391+
def test_get_loc_key_unit_mismatch(self):
392+
idx = date_range("2000-01-01", periods=3)
393+
key = idx[1]._as_unit("ms")
394+
loc = idx.get_loc(key)
395+
assert loc == 1
396+
assert key in idx
397+
398+
def test_get_loc_key_unit_mismatch_not_castable(self):
399+
dta = date_range("2000-01-01", periods=3)._data.astype("M8[s]")
400+
dti = DatetimeIndex(dta)
401+
key = dta[0]._as_unit("ns") + pd.Timedelta(1)
402+
403+
with pytest.raises(
404+
KeyError, match=r"Timestamp\('2000-01-01 00:00:00.000000001'\)"
405+
):
406+
dti.get_loc(key)
407+
408+
assert key not in dti
409+
391410
@pytest.mark.parametrize("method", [None, "pad", "backfill", "nearest"])
392411
@pytest.mark.filterwarnings("ignore:Passing method:FutureWarning")
393412
def test_get_loc_method_exact_match(self, method):

pandas/tests/indexes/timedeltas/test_indexing.py

+21
Original file line number | Diff line number | Diff line change
@@ -75,6 +75,27 @@ def test_timestamp_invalid_key(self, key):
7575

7676

7777
class TestGetLoc:
78+
def test_get_loc_key_unit_mismatch(self):
79+
idx = to_timedelta(["0 days", "1 days", "2 days"])
80+
key = idx[1]._as_unit("ms")
81+
loc = idx.get_loc(key)
82+
assert loc == 1
83+
84+
def test_get_loc_key_unit_mismatch_not_castable(self):
85+
# TODO(2.0): once TDA.astype supports m8[s] directly, tdi
86+
# can be constructed directly
87+
tda = to_timedelta(["0 days", "1 days", "2 days"])._data
88+
arr = np.array(tda).astype("m8[s]")
89+
tda2 = type(tda)._simple_new(arr, dtype=arr.dtype)
90+
tdi = TimedeltaIndex(tda2)
91+
assert tdi.dtype == "m8[s]"
92+
key = tda[0]._as_unit("ns") + Timedelta(1)
93+
94+
with pytest.raises(KeyError, match=r"Timedelta\('0 days 00:00:00.000000001'\)"):
95+
tdi.get_loc(key)
96+
97+
assert key not in tdi
98+
7899
@pytest.mark.filterwarnings("ignore:Passing method:FutureWarning")
79100
def test_get_loc(self):
80101
idx = to_timedelta(["0 days", "1 days", "2 days"])

0 commit comments

Comments
 (0)