@@ -648,12 +648,11 @@ def _get_indexer(
648
648
) -> npt .NDArray [np .intp ]:
649
649
650
650
if isinstance (target , IntervalIndex ):
651
- # non-overlapping -> at most one match per interval in target
651
+ # We only get here with not self.is_overlapping
652
+ # -> at most one match per interval in target
652
653
# want exact matches -> need both left/right to match, so defer to
653
654
# left/right get_indexer, compare elementwise, equality -> match
654
- left_indexer = self .left .get_indexer (target .left )
655
- right_indexer = self .right .get_indexer (target .right )
656
- indexer = np .where (left_indexer == right_indexer , left_indexer , - 1 )
655
+ indexer = self ._get_indexer_unique_sides (target )
657
656
658
657
elif not is_object_dtype (target .dtype ):
659
658
# homogeneous scalar index: use IntervalTree
@@ -678,6 +677,14 @@ def get_indexer_non_unique(
678
677
# -> no matches
679
678
return self ._get_indexer_non_comparable (target , None , unique = False )
680
679
680
+ elif isinstance (target , IntervalIndex ):
681
+ if self .left .is_unique and self .right .is_unique :
682
+ # fastpath available even if we don't have self._index_as_unique
683
+ indexer = self ._get_indexer_unique_sides (target )
684
+ missing = (indexer == - 1 ).nonzero ()[0 ]
685
+ else :
686
+ return self ._get_indexer_pointwise (target )
687
+
681
688
elif is_object_dtype (target .dtype ) or not self ._should_partial_index (target ):
682
689
# target might contain intervals: defer elementwise to get_loc
683
690
return self ._get_indexer_pointwise (target )
@@ -690,6 +697,18 @@ def get_indexer_non_unique(
690
697
691
698
return ensure_platform_int (indexer ), ensure_platform_int (missing )
692
699
700
+ def _get_indexer_unique_sides (self , target : IntervalIndex ) -> npt .NDArray [np .intp ]:
701
+ """
702
+ _get_indexer specialized to the case where both of our sides are unique.
703
+ """
704
+ # Caller is responsible for checking
705
+ # `self.left.is_unique and self.right.is_unique`
706
+
707
+ left_indexer = self .left .get_indexer (target .left )
708
+ right_indexer = self .right .get_indexer (target .right )
709
+ indexer = np .where (left_indexer == right_indexer , left_indexer , - 1 )
710
+ return indexer
711
+
693
712
def _get_indexer_pointwise (
694
713
self , target : Index
695
714
) -> tuple [npt .NDArray [np .intp ], npt .NDArray [np .intp ]]:
0 commit comments