From afff8cef7a9eff32f2f1027f38e634df8054f5f8 Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 27 Sep 2022 10:13:13 -0700 Subject: [PATCH 1/2] BUG: DTI/TDI.get_loc with mismatched-reso scalars --- pandas/_libs/index.pyx | 47 +++++++++++++++---- pandas/core/indexes/base.py | 5 ++ .../tests/indexes/datetimes/test_indexing.py | 19 ++++++++ .../tests/indexes/timedeltas/test_indexing.py | 21 +++++++++ 4 files changed, 83 insertions(+), 9 deletions(-) diff --git a/pandas/_libs/index.pyx b/pandas/_libs/index.pyx index 58be977299e99..1d0eb256ca051 100644 --- a/pandas/_libs/index.pyx +++ b/pandas/_libs/index.pyx @@ -17,6 +17,10 @@ cnp.import_array() from pandas._libs cimport util from pandas._libs.hashtable cimport HashTable from pandas._libs.tslibs.nattype cimport c_NaT as NaT +from pandas._libs.tslibs.np_datetime cimport ( + NPY_DATETIMEUNIT, + get_unit_from_dtype, +) from pandas._libs.tslibs.period cimport is_period_object from pandas._libs.tslibs.timedeltas cimport _Timedelta from pandas._libs.tslibs.timestamps cimport _Timestamp @@ -485,17 +489,35 @@ cdef class ObjectEngine(IndexEngine): cdef class DatetimeEngine(Int64Engine): + cdef: + NPY_DATETIMEUNIT reso + + def __init__(self, ndarray values): + super().__init__(values.view("i8")) + self.reso = get_unit_from_dtype(values.dtype) + cdef int64_t _unbox_scalar(self, scalar) except? -1: # NB: caller is responsible for ensuring tzawareness compat # before we get here - if not (isinstance(scalar, _Timestamp) or scalar is NaT): - raise TypeError(scalar) - return scalar.value + if scalar is NaT: + return NaT.value + elif isinstance(scalar, _Timestamp): + if scalar._reso == self.reso: + return scalar.value + else: + # Note: caller is responsible for catching potential ValueError + # from _as_reso + return (<_Timestamp>scalar)._as_reso(self.reso, round_ok=False).value + raise TypeError(scalar) def __contains__(self, val: object) -> bool: # We assume before we get here: # - val is hashable - self._unbox_scalar(val) + try: + self._unbox_scalar(val) + except ValueError: + return False + try: self.get_loc(val) return True @@ -517,8 +539,8 @@ cdef class DatetimeEngine(Int64Engine): try: conv = self._unbox_scalar(val) - except TypeError: - raise KeyError(val) + except (TypeError, ValueError) as err: + raise KeyError(val) from err # Welcome to the spaghetti factory if self.over_size_threshold and self.is_monotonic_increasing: @@ -545,9 +567,16 @@ cdef class DatetimeEngine(Int64Engine): cdef class TimedeltaEngine(DatetimeEngine): cdef int64_t _unbox_scalar(self, scalar) except? -1: - if not (isinstance(scalar, _Timedelta) or scalar is NaT): - raise TypeError(scalar) - return scalar.value + if scalar is NaT: + return NaT.value + elif isinstance(scalar, _Timedelta): + if scalar._reso == self.reso: + return scalar.value + else: + # Note: caller is responsible for catching potential ValueError + # from _as_reso + return (<_Timedelta>scalar)._as_reso(self.reso, round_ok=False).value + raise TypeError(scalar) cdef class PeriodEngine(Int64Engine): diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index a06a54081f84d..bd8044f2e07fe 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -911,6 +911,11 @@ def _engine( return libindex.Complex64Engine(target_values) elif target_values.dtype == np.complex128: return libindex.Complex128Engine(target_values) + elif needs_i8_conversion(self.dtype): + # We need to keep M8/m8 dtype when initializing the Engine, + # but don't want to change _get_engine_target bc it is used + # elsewhere + target_values = self._data._ndarray # error: Argument 1 to "ExtensionEngine" has incompatible type # "ndarray[Any, Any]"; expected "ExtensionArray" diff --git a/pandas/tests/indexes/datetimes/test_indexing.py b/pandas/tests/indexes/datetimes/test_indexing.py index a203fee5b3a61..62fdff528bd84 100644 --- a/pandas/tests/indexes/datetimes/test_indexing.py +++ b/pandas/tests/indexes/datetimes/test_indexing.py @@ -388,6 +388,25 @@ def test_take_fill_value_with_timezone(self): class TestGetLoc: + def test_get_loc_key_unit_mismatch(self): + idx = date_range("2000-01-01", periods=3) + key = idx[1]._as_unit("ms") + loc = idx.get_loc(key) + assert loc == 1 + assert key in idx + + def test_get_loc_key_unit_mismatch_not_castable(self): + dta = date_range("2000-01-01", periods=3)._data.astype("M8[s]") + dti = DatetimeIndex(dta) + key = dta[0]._as_unit("ns") + pd.Timedelta(1) + + with pytest.raises( + KeyError, match=r"Timestamp\('2000-01-01 00:00:00.000000001'\)" + ): + dti.get_loc(key) + + assert key not in dti + @pytest.mark.parametrize("method", [None, "pad", "backfill", "nearest"]) @pytest.mark.filterwarnings("ignore:Passing method:FutureWarning") def test_get_loc_method_exact_match(self, method): diff --git a/pandas/tests/indexes/timedeltas/test_indexing.py b/pandas/tests/indexes/timedeltas/test_indexing.py index 154a6289dfc00..bdf299f6dbbdf 100644 --- a/pandas/tests/indexes/timedeltas/test_indexing.py +++ b/pandas/tests/indexes/timedeltas/test_indexing.py @@ -75,6 +75,27 @@ def test_timestamp_invalid_key(self, key): class TestGetLoc: + def test_get_loc_key_unit_mismatch(self): + idx = to_timedelta(["0 days", "1 days", "2 days"]) + key = idx[1]._as_unit("ms") + loc = idx.get_loc(key) + assert loc == 1 + + def test_get_loc_key_unit_mismatch_not_castable(self): + # TODO(2.0): once TDA.astype supports m8[s] directly, tdi + # can be constructed directly + tda = to_timedelta(["0 days", "1 days", "2 days"])._data + arr = np.array(tda).astype("m8[s]") + tda2 = type(tda)._simple_new(arr, dtype=arr.dtype) + tdi = TimedeltaIndex(tda2) + assert tdi.dtype == "m8[s]" + key = tda[0]._as_unit("ns") + Timedelta(1) + + with pytest.raises(KeyError, match=r"Timedelta\('0 days 00:00:00.000000001'\)"): + tdi.get_loc(key) + + assert key not in tdi + @pytest.mark.filterwarnings("ignore:Passing method:FutureWarning") def test_get_loc(self): idx = to_timedelta(["0 days", "1 days", "2 days"]) From 7e79171e4797fc6d1cf20c25e198d86db265cb33 Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 27 Sep 2022 11:09:15 -0700 Subject: [PATCH 2/2] mypy fixup --- pandas/core/indexes/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index bd8044f2e07fe..d2b4a4c7d130e 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -915,7 +915,9 @@ def _engine( # We need to keep M8/m8 dtype when initializing the Engine, # but don't want to change _get_engine_target bc it is used # elsewhere - target_values = self._data._ndarray + # error: Item "ExtensionArray" of "Union[ExtensionArray, + # ndarray[Any, Any]]" has no attribute "_ndarray" [union-attr] + target_values = self._data._ndarray # type: ignore[union-attr] # error: Argument 1 to "ExtensionEngine" has incompatible type # "ndarray[Any, Any]"; expected "ExtensionArray"