Skip to content

REF: do all convert_tolerance casting inside Index.get_loc #31425

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits on Feb 1, 2020
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ Indexing
- Bug in :meth:`PeriodIndex.get_loc` treating higher-resolution strings differently from :meth:`PeriodIndex.get_value` (:issue:`31172`)
- Bug in :meth:`Series.at` and :meth:`DataFrame.at` not matching ``.loc`` behavior when looking up an integer in a :class:`Float64Index` (:issue:`31329`)
- Bug in :meth:`PeriodIndex.is_monotonic` incorrectly returning ``True`` when containing leading ``NaT`` entries (:issue:`31437`)
- Bug in :meth:`DatetimeIndex.get_loc` raising a ``KeyError`` that reported the key's converted integer value instead of the key the user passed (:issue:`31425`)
-

Missing
Expand Down
4 changes: 4 additions & 0 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2884,6 +2884,10 @@ def get_loc(self, key, method=None, tolerance=None):
return self._engine.get_loc(key)
except KeyError:
return self._engine.get_loc(self._maybe_cast_indexer(key))

if tolerance is not None:
tolerance = self._convert_tolerance(tolerance, np.asarray(key))

indexer = self.get_indexer([key], method=method, tolerance=tolerance)
if indexer.ndim > 1 or indexer.size > 1:
raise TypeError("get_loc requires scalar valued input")
Expand Down
23 changes: 12 additions & 11 deletions pandas/core/indexes/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,18 +639,13 @@ def get_loc(self, key, method=None, tolerance=None):
if not is_scalar(key):
raise InvalidIndexError(key)

orig_key = key
if is_valid_nat_for_dtype(key, self.dtype):
key = NaT

if tolerance is not None:
# try converting tolerance now, so errors don't get swallowed by
# the try/except clauses below
tolerance = self._convert_tolerance(tolerance, np.asarray(key))

if isinstance(key, (datetime, np.datetime64)):
# needed to localize naive datetimes
key = self._maybe_cast_for_get_loc(key)
return Index.get_loc(self, key, method, tolerance)

elif isinstance(key, str):
try:
Expand All @@ -659,9 +654,8 @@ def get_loc(self, key, method=None, tolerance=None):
pass

try:
stamp = self._maybe_cast_for_get_loc(key)
return Index.get_loc(self, stamp, method, tolerance)
except (KeyError, ValueError):
key = self._maybe_cast_for_get_loc(key)
except ValueError:
raise KeyError(key)

elif isinstance(key, timedelta):
Expand All @@ -670,14 +664,21 @@ def get_loc(self, key, method=None, tolerance=None):
f"Cannot index {type(self).__name__} with {type(key).__name__}"
)

if isinstance(key, time):
elif isinstance(key, time):
if method is not None:
raise NotImplementedError(
"cannot yet lookup inexact labels when key is a time object"
)
return self.indexer_at_time(key)

return Index.get_loc(self, key, method, tolerance)
else:
# unrecognized type
raise KeyError(key)

try:
return Index.get_loc(self, key, method, tolerance)
except KeyError:
raise KeyError(orig_key)

def _maybe_cast_for_get_loc(self, key) -> Timestamp:
# needed to localize naive datetimes
Expand Down
7 changes: 0 additions & 7 deletions pandas/core/indexes/timedeltas.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
""" implement the TimedeltaIndex """

import numpy as np

from pandas._libs import NaT, Timedelta, index as libindex
from pandas.util._decorators import Appender

Expand Down Expand Up @@ -262,11 +260,6 @@ def get_loc(self, key, method=None, tolerance=None):
else:
raise KeyError(key)

if tolerance is not None:
# try converting tolerance now, so errors don't get swallowed by
# the try/except clauses below
tolerance = self._convert_tolerance(tolerance, np.asarray(key))

return Index.get_loc(self, key, method, tolerance)

def _maybe_cast_slice_bound(self, label, side: str, kind):
Expand Down
10 changes: 6 additions & 4 deletions pandas/tests/series/indexing/test_datetime.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime, timedelta
import re

import numpy as np
import pytest
Expand Down Expand Up @@ -380,7 +381,7 @@ def test_datetime_indexing():
s = Series(len(index), index=index)
stamp = Timestamp("1/8/2000")

with pytest.raises(KeyError, match=r"^947289600000000000$"):
with pytest.raises(KeyError, match=re.escape(repr(stamp))):
s[stamp]
s[stamp] = 0
assert s[stamp] == 0
Expand All @@ -389,7 +390,7 @@ def test_datetime_indexing():
s = Series(len(index), index=index)
s = s[::-1]

with pytest.raises(KeyError, match=r"^947289600000000000$"):
with pytest.raises(KeyError, match=re.escape(repr(stamp))):
s[stamp]
s[stamp] = 0
assert s[stamp] == 0
Expand Down Expand Up @@ -495,8 +496,9 @@ def test_duplicate_dates_indexing(dups):
expected = Series(np.where(mask, 0, ts), index=ts.index)
tm.assert_series_equal(cp, expected)

with pytest.raises(KeyError, match=r"^947116800000000000$"):
ts[datetime(2000, 1, 6)]
key = datetime(2000, 1, 6)
with pytest.raises(KeyError, match=re.escape(repr(key))):
ts[key]

# new index
ts[datetime(2000, 1, 6)] = 0
Expand Down