Skip to content

Commit b42505e

Browse files
authored
REF: simplify PeriodIndex.get_loc (#31598)
1 parent 980ab6b commit b42505e

File tree

2 files changed

+31
-17
lines changed

2 files changed

+31
-17
lines changed

pandas/_libs/index.pyx

+23
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ cnp.import_array()
1212

1313
cimport pandas._libs.util as util
1414

15+
from pandas._libs.tslibs import Period
1516
from pandas._libs.tslibs.nattype cimport c_NaT as NaT
1617
from pandas._libs.tslibs.c_timestamp cimport _Timestamp
1718

@@ -466,6 +467,28 @@ cdef class TimedeltaEngine(DatetimeEngine):
466467

467468
cdef class PeriodEngine(Int64Engine):
468469

470+
cdef int64_t _unbox_scalar(self, scalar) except? -1:
471+
if scalar is NaT:
472+
return scalar.value
473+
if isinstance(scalar, Period):
474+
# NB: we assume that we have the correct freq here.
475+
# TODO: potential optimize by checking for _Period?
476+
return scalar.ordinal
477+
raise TypeError(scalar)
478+
479+
cpdef get_loc(self, object val):
480+
# NB: the caller is responsible for ensuring that we are called
481+
# with either a Period or NaT
482+
cdef:
483+
int64_t conv
484+
485+
try:
486+
conv = self._unbox_scalar(val)
487+
except TypeError:
488+
raise KeyError(val)
489+
490+
return Int64Engine.get_loc(self, conv)
491+
469492
cdef _get_index_values(self):
470493
return super(PeriodEngine, self).vgetter().view("i8")
471494

pandas/core/indexes/period.py

+8-17
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,10 @@ def get_indexer(self, target, method=None, limit=None, tolerance=None):
468468

469469
if tolerance is not None:
470470
tolerance = self._convert_tolerance(tolerance, target)
471+
if self_index is not self:
472+
# convert tolerance to i8
473+
tolerance = self._maybe_convert_timedelta(tolerance)
474+
471475
return Index.get_indexer(self_index, target, method, limit, tolerance)
472476

473477
@Appender(_index_shared_docs["get_indexer_non_unique"] % _index_doc_kwargs)
@@ -504,6 +508,7 @@ def get_loc(self, key, method=None, tolerance=None):
504508
TypeError
505509
If key is listlike or otherwise not hashable.
506510
"""
511+
orig_key = key
507512

508513
if not is_scalar(key):
509514
raise InvalidIndexError(key)
@@ -545,20 +550,12 @@ def get_loc(self, key, method=None, tolerance=None):
545550
key = Period(key, freq=self.freq)
546551
except ValueError:
547552
# we cannot construct the Period
548-
raise KeyError(key)
553+
raise KeyError(orig_key)
549554

550-
ordinal = self._data._unbox_scalar(key)
551555
try:
552-
return self._engine.get_loc(ordinal)
556+
return Index.get_loc(self, key, method, tolerance)
553557
except KeyError:
554-
555-
try:
556-
if tolerance is not None:
557-
tolerance = self._convert_tolerance(tolerance, np.asarray(key))
558-
return self._int64index.get_loc(ordinal, method, tolerance)
559-
560-
except KeyError:
561-
raise KeyError(key)
558+
raise KeyError(orig_key)
562559

563560
def _maybe_cast_slice_bound(self, label, side: str, kind: str):
564561
"""
@@ -625,12 +622,6 @@ def _get_string_slice(self, key: str, use_lhs: bool = True, use_rhs: bool = True
625622
except KeyError:
626623
raise KeyError(key)
627624

628-
def _convert_tolerance(self, tolerance, target):
629-
tolerance = DatetimeIndexOpsMixin._convert_tolerance(self, tolerance, target)
630-
if target.size != tolerance.size and tolerance.size > 1:
631-
raise ValueError("list-like tolerance size must match target index size")
632-
return self._maybe_convert_timedelta(tolerance)
633-
634625
def insert(self, loc, item):
635626
if not isinstance(item, Period) or self.freq != item.freq:
636627
return self.astype(object).insert(loc, item)

0 commit comments

Comments
 (0)