Skip to content

BUG: DTI/TDI.get_loc with mismatched-reso scalars #48815

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 38 additions & 9 deletions pandas/_libs/index.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ cnp.import_array()
from pandas._libs cimport util
from pandas._libs.hashtable cimport HashTable
from pandas._libs.tslibs.nattype cimport c_NaT as NaT
from pandas._libs.tslibs.np_datetime cimport (
NPY_DATETIMEUNIT,
get_unit_from_dtype,
)
from pandas._libs.tslibs.period cimport is_period_object
from pandas._libs.tslibs.timedeltas cimport _Timedelta
from pandas._libs.tslibs.timestamps cimport _Timestamp
Expand Down Expand Up @@ -485,17 +489,35 @@ cdef class ObjectEngine(IndexEngine):

cdef class DatetimeEngine(Int64Engine):

cdef:
NPY_DATETIMEUNIT reso

def __init__(self, ndarray values):
super().__init__(values.view("i8"))
self.reso = get_unit_from_dtype(values.dtype)

cdef int64_t _unbox_scalar(self, scalar) except? -1:
# NB: caller is responsible for ensuring tzawareness compat
# before we get here
if not (isinstance(scalar, _Timestamp) or scalar is NaT):
raise TypeError(scalar)
return scalar.value
if scalar is NaT:
return NaT.value
elif isinstance(scalar, _Timestamp):
if scalar._reso == self.reso:
return scalar.value
else:
# Note: caller is responsible for catching potential ValueError
# from _as_reso
return (<_Timestamp>scalar)._as_reso(self.reso, round_ok=False).value
raise TypeError(scalar)

def __contains__(self, val: object) -> bool:
# We assume before we get here:
# - val is hashable
self._unbox_scalar(val)
try:
self._unbox_scalar(val)
except ValueError:
return False

try:
self.get_loc(val)
return True
Expand All @@ -517,8 +539,8 @@ cdef class DatetimeEngine(Int64Engine):

try:
conv = self._unbox_scalar(val)
except TypeError:
raise KeyError(val)
except (TypeError, ValueError) as err:
raise KeyError(val) from err

# Welcome to the spaghetti factory
if self.over_size_threshold and self.is_monotonic_increasing:
Expand All @@ -545,9 +567,16 @@ cdef class DatetimeEngine(Int64Engine):
cdef class TimedeltaEngine(DatetimeEngine):

cdef int64_t _unbox_scalar(self, scalar) except? -1:
if not (isinstance(scalar, _Timedelta) or scalar is NaT):
raise TypeError(scalar)
return scalar.value
if scalar is NaT:
return NaT.value
elif isinstance(scalar, _Timedelta):
if scalar._reso == self.reso:
return scalar.value
else:
# Note: caller is responsible for catching potential ValueError
# from _as_reso
return (<_Timedelta>scalar)._as_reso(self.reso, round_ok=False).value
raise TypeError(scalar)


cdef class PeriodEngine(Int64Engine):
Expand Down
7 changes: 7 additions & 0 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,6 +911,13 @@ def _engine(
return libindex.Complex64Engine(target_values)
elif target_values.dtype == np.complex128:
return libindex.Complex128Engine(target_values)
elif needs_i8_conversion(self.dtype):
# We need to keep M8/m8 dtype when initializing the Engine,
# but don't want to change _get_engine_target bc it is used
# elsewhere
# error: Item "ExtensionArray" of "Union[ExtensionArray,
# ndarray[Any, Any]]" has no attribute "_ndarray" [union-attr]
target_values = self._data._ndarray # type: ignore[union-attr]

# error: Argument 1 to "ExtensionEngine" has incompatible type
# "ndarray[Any, Any]"; expected "ExtensionArray"
Expand Down
19 changes: 19 additions & 0 deletions pandas/tests/indexes/datetimes/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,25 @@ def test_take_fill_value_with_timezone(self):


class TestGetLoc:
def test_get_loc_key_unit_mismatch(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Side note: should this work when the key is tz-aware and idx is tz-naive (and vice versa)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the moment it works with mismatched tz-aware/tz-naive keys, but that behavior is deprecated.

idx = date_range("2000-01-01", periods=3)
key = idx[1]._as_unit("ms")
loc = idx.get_loc(key)
assert loc == 1
assert key in idx

def test_get_loc_key_unit_mismatch_not_castable(self):
# GH#48815: a key whose unit differs from the index's unit, and whose
# value has a component finer than the index's unit, must be reported
# as missing (KeyError from get_loc, False from `in`).
# Build an index backed by second-unit ("M8[s]") data.
dta = date_range("2000-01-01", periods=3)._data.astype("M8[s]")
dti = DatetimeIndex(dta)
# Take the first element in "ns" and add Timedelta(1); the expected repr
# in the match pattern below ends in .000000001.
key = dta[0]._as_unit("ns") + pd.Timedelta(1)

# The lookup must raise KeyError whose message contains the key's repr.
with pytest.raises(
KeyError, match=r"Timestamp\('2000-01-01 00:00:00.000000001'\)"
):
dti.get_loc(key)

# Membership testing must evaluate to False instead of raising.
assert key not in dti

@pytest.mark.parametrize("method", [None, "pad", "backfill", "nearest"])
@pytest.mark.filterwarnings("ignore:Passing method:FutureWarning")
def test_get_loc_method_exact_match(self, method):
Expand Down
21 changes: 21 additions & 0 deletions pandas/tests/indexes/timedeltas/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,27 @@ def test_timestamp_invalid_key(self, key):


class TestGetLoc:
def test_get_loc_key_unit_mismatch(self):
# GH#48815: a key taken from the index and re-expressed in "ms" via
# _as_unit must still be found at its original position.
idx = to_timedelta(["0 days", "1 days", "2 days"])
key = idx[1]._as_unit("ms")
loc = idx.get_loc(key)
# The key came from position 1, so that is where get_loc must find it.
assert loc == 1

def test_get_loc_key_unit_mismatch_not_castable(self):
# GH#48815: a key whose unit differs from the index's unit, and whose
# value has a component finer than the index's unit, must be reported
# as missing (KeyError from get_loc, False from `in`).
# TODO(2.0): once TDA.astype supports m8[s] directly, tdi
# can be constructed directly
tda = to_timedelta(["0 days", "1 days", "2 days"])._data
# Round-trip through a NumPy array to obtain "m8[s]" values, then wrap
# them back into the same array type as `tda` via _simple_new.
arr = np.array(tda).astype("m8[s]")
tda2 = type(tda)._simple_new(arr, dtype=arr.dtype)
tdi = TimedeltaIndex(tda2)
# Sanity check that the index really carries the second-unit dtype.
assert tdi.dtype == "m8[s]"
# Take the first element in "ns" and add Timedelta(1); the expected repr
# in the match pattern below ends in .000000001.
key = tda[0]._as_unit("ns") + Timedelta(1)

# The lookup must raise KeyError whose message contains the key's repr.
with pytest.raises(KeyError, match=r"Timedelta\('0 days 00:00:00.000000001'\)"):
tdi.get_loc(key)

# Membership testing must evaluate to False instead of raising.
assert key not in tdi

@pytest.mark.filterwarnings("ignore:Passing method:FutureWarning")
def test_get_loc(self):
idx = to_timedelta(["0 days", "1 days", "2 days"])
Expand Down