Skip to content

Commit 2276a50

Browse files
authored
REF: implement IntervalIndex._get_indexer_unique_sides (#43182)
1 parent 587f1db commit 2276a50

File tree

1 file changed

+23
-4
lines changed

1 file changed

+23
-4
lines changed

pandas/core/indexes/interval.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -648,12 +648,11 @@ def _get_indexer(
648648
) -> npt.NDArray[np.intp]:
649649

650650
if isinstance(target, IntervalIndex):
651-
# non-overlapping -> at most one match per interval in target
651+
# We only get here with not self.is_overlapping
652+
# -> at most one match per interval in target
652653
# want exact matches -> need both left/right to match, so defer to
653654
# left/right get_indexer, compare elementwise, equality -> match
654-
left_indexer = self.left.get_indexer(target.left)
655-
right_indexer = self.right.get_indexer(target.right)
656-
indexer = np.where(left_indexer == right_indexer, left_indexer, -1)
655+
indexer = self._get_indexer_unique_sides(target)
657656

658657
elif not is_object_dtype(target.dtype):
659658
# homogeneous scalar index: use IntervalTree
@@ -678,6 +677,14 @@ def get_indexer_non_unique(
678677
# -> no matches
679678
return self._get_indexer_non_comparable(target, None, unique=False)
680679

680+
elif isinstance(target, IntervalIndex):
681+
if self.left.is_unique and self.right.is_unique:
682+
# fastpath available even if we don't have self._index_as_unique
683+
indexer = self._get_indexer_unique_sides(target)
684+
missing = (indexer == -1).nonzero()[0]
685+
else:
686+
return self._get_indexer_pointwise(target)
687+
681688
elif is_object_dtype(target.dtype) or not self._should_partial_index(target):
682689
# target might contain intervals: defer elementwise to get_loc
683690
return self._get_indexer_pointwise(target)
@@ -690,6 +697,18 @@ def get_indexer_non_unique(
690697

691698
return ensure_platform_int(indexer), ensure_platform_int(missing)
692699

700+
def _get_indexer_unique_sides(self, target: IntervalIndex) -> npt.NDArray[np.intp]:
701+
"""
702+
_get_indexer specialized to the case where both of our sides are unique.
703+
"""
704+
# Caller is responsible for checking
705+
# `self.left.is_unique and self.right.is_unique`
706+
707+
left_indexer = self.left.get_indexer(target.left)
708+
right_indexer = self.right.get_indexer(target.right)
709+
indexer = np.where(left_indexer == right_indexer, left_indexer, -1)
710+
return indexer
711+
693712
def _get_indexer_pointwise(
694713
self, target: Index
695714
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]:

0 commit comments

Comments
 (0)