Skip to content

Commit 14dc069

Browse files
authored
REF: simplify DTI/PI.get_loc (#49941)
* REF: simplify DTI.get_loc
* annotate
* update test message
1 parent 8b227f3 commit 14dc069

File tree

3 files changed

+32
-50
lines changed

3 files changed

+32
-50
lines changed

pandas/core/indexes/datetimes.py

+25-37
Original file line numberDiff line numberDiff line change
@@ -526,29 +526,35 @@ def _parsed_string_to_bounds(self, reso: Resolution, parsed: dt.datetime):
526526
"The index must be timezone aware when indexing "
527527
"with a date string with a UTC offset"
528528
)
529-
start = self._maybe_cast_for_get_loc(start)
530-
end = self._maybe_cast_for_get_loc(end)
529+
# The flipped case with parsed.tz is None and self.tz is not None
530+
# is ruled out bc parsed and reso are produced by _parse_with_reso,
531+
# which localizes parsed.
531532
return start, end
532533

533-
def _disallow_mismatched_indexing(self, key, one_way: bool = False) -> None:
534+
def _parse_with_reso(self, label: str):
535+
parsed, reso = super()._parse_with_reso(label)
536+
537+
parsed = Timestamp(parsed)
538+
539+
if self.tz is not None and parsed.tzinfo is None:
540+
# we special-case timezone-naive strings and timezone-aware
541+
# DatetimeIndex
542+
# https://github.com/pandas-dev/pandas/pull/36148#issuecomment-687883081
543+
parsed = parsed.tz_localize(self.tz)
544+
545+
return parsed, reso
546+
547+
def _disallow_mismatched_indexing(self, key) -> None:
534548
"""
535549
Check for mismatched-tzawareness indexing and re-raise as KeyError.
536550
"""
551+
# we get here with isinstance(key, self._data._recognized_scalars)
537552
try:
538-
self._deprecate_mismatched_indexing(key, one_way=one_way)
553+
# GH#36148
554+
self._data._assert_tzawareness_compat(key)
539555
except TypeError as err:
540556
raise KeyError(key) from err
541557

542-
def _deprecate_mismatched_indexing(self, key, one_way: bool = False) -> None:
543-
# GH#36148
544-
# we get here with isinstance(key, self._data._recognized_scalars)
545-
if self.tz is not None and one_way:
546-
# we special-case timezone-naive strings and timezone-aware
547-
# DatetimeIndex
548-
return
549-
550-
self._data._assert_tzawareness_compat(key)
551-
552558
def get_loc(self, key, method=None, tolerance=None):
553559
"""
554560
Get integer location for requested label
@@ -566,15 +572,15 @@ def get_loc(self, key, method=None, tolerance=None):
566572
if isinstance(key, self._data._recognized_scalars):
567573
# needed to localize naive datetimes
568574
self._disallow_mismatched_indexing(key)
569-
key = self._maybe_cast_for_get_loc(key)
575+
key = Timestamp(key)
570576

571577
elif isinstance(key, str):
572578

573579
try:
574580
parsed, reso = self._parse_with_reso(key)
575581
except ValueError as err:
576582
raise KeyError(key) from err
577-
self._disallow_mismatched_indexing(parsed, one_way=True)
583+
self._disallow_mismatched_indexing(parsed)
578584

579585
if self._can_partial_date_slice(reso):
580586
try:
@@ -583,7 +589,7 @@ def get_loc(self, key, method=None, tolerance=None):
583589
if method is None:
584590
raise KeyError(key) from err
585591

586-
key = self._maybe_cast_for_get_loc(key)
592+
key = parsed
587593

588594
elif isinstance(key, dt.timedelta):
589595
# GH#20464
@@ -607,24 +613,6 @@ def get_loc(self, key, method=None, tolerance=None):
607613
except KeyError as err:
608614
raise KeyError(orig_key) from err
609615

610-
def _maybe_cast_for_get_loc(self, key) -> Timestamp:
611-
# needed to localize naive datetimes or dates (GH 35690)
612-
try:
613-
key = Timestamp(key)
614-
except ValueError as err:
615-
# FIXME(dateutil#1180): we get here because parse_with_reso
616-
# doesn't raise on "t2m"
617-
if not isinstance(key, str):
618-
# Not expected to be reached, but check to be sure
619-
raise # pragma: no cover
620-
raise KeyError(key) from err
621-
622-
if key.tzinfo is None:
623-
key = key.tz_localize(self.tz)
624-
else:
625-
key = key.tz_convert(self.tz)
626-
return key
627-
628616
@doc(DatetimeTimedeltaMixin._maybe_cast_slice_bound)
629617
def _maybe_cast_slice_bound(self, label, side: str):
630618

@@ -635,8 +623,8 @@ def _maybe_cast_slice_bound(self, label, side: str):
635623
label = Timestamp(label).to_pydatetime()
636624

637625
label = super()._maybe_cast_slice_bound(label, side)
638-
self._deprecate_mismatched_indexing(label)
639-
return self._maybe_cast_for_get_loc(label)
626+
self._data._assert_tzawareness_compat(label)
627+
return Timestamp(label)
640628

641629
def slice_indexer(self, start=None, end=None, step=None):
642630
"""

pandas/core/indexes/period.py

+6-12
Original file line numberDiff line numberDiff line change
@@ -420,18 +420,14 @@ def get_loc(self, key, method=None, tolerance=None):
420420
if reso == self._resolution_obj:
421421
# the reso < self._resolution_obj case goes
422422
# through _get_string_slice
423-
key = self._cast_partial_indexing_scalar(key)
424-
loc = self.get_loc(key, method=method, tolerance=tolerance)
425-
# Recursing instead of falling through matters for the exception
426-
# message in test_get_loc3 (though not clear if that really matters)
427-
return loc
423+
key = self._cast_partial_indexing_scalar(parsed)
428424
elif method is None:
429425
raise KeyError(key)
430426
else:
431427
key = self._cast_partial_indexing_scalar(parsed)
432428

433429
elif isinstance(key, Period):
434-
key = self._maybe_cast_for_get_loc(key)
430+
self._disallow_mismatched_indexing(key)
435431

436432
elif isinstance(key, datetime):
437433
key = self._cast_partial_indexing_scalar(key)
@@ -445,8 +441,7 @@ def get_loc(self, key, method=None, tolerance=None):
445441
except KeyError as err:
446442
raise KeyError(orig_key) from err
447443

448-
def _maybe_cast_for_get_loc(self, key: Period) -> Period:
449-
# name is a misnomer, chosen for compat with DatetimeIndex
444+
def _disallow_mismatched_indexing(self, key: Period) -> None:
450445
sfreq = self.freq
451446
kfreq = key.freq
452447
if not (
@@ -460,15 +455,14 @@ def _maybe_cast_for_get_loc(self, key: Period) -> Period:
460455
# checking these two attributes is sufficient to check equality,
461456
# and much more performant than `self.freq == key.freq`
462457
raise KeyError(key)
463-
return key
464458

465-
def _cast_partial_indexing_scalar(self, label):
459+
def _cast_partial_indexing_scalar(self, label: datetime) -> Period:
466460
try:
467-
key = Period(label, freq=self.freq)
461+
period = Period(label, freq=self.freq)
468462
except ValueError as err:
469463
# we cannot construct the Period
470464
raise KeyError(label) from err
471-
return key
465+
return period
472466

473467
@doc(DatetimeIndexOpsMixin._maybe_cast_slice_bound)
474468
def _maybe_cast_slice_bound(self, label, side: str):

pandas/tests/indexes/period/test_indexing.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ def test_get_loc3(self):
373373
msg = "Input has different freq=None from PeriodArray\\(freq=D\\)"
374374
with pytest.raises(ValueError, match=msg):
375375
idx.get_loc("2000-01-10", method="nearest", tolerance="1 hour")
376-
with pytest.raises(KeyError, match=r"^Period\('2000-01-10', 'D'\)$"):
376+
with pytest.raises(KeyError, match=r"^'2000-01-10'$"):
377377
idx.get_loc("2000-01-10", method="nearest", tolerance="1 day")
378378
with pytest.raises(
379379
ValueError, match="list-like tolerance size must match target index size"

0 commit comments

Comments
 (0)