Skip to content

Commit 0283df6

Browse files
authored
REF: implement Index._should_partial_index (#42227)
1 parent 55b4edd commit 0283df6

File tree

2 files changed

+18
-8
lines changed

2 files changed

+18
-8
lines changed

pandas/core/indexes/base.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -3440,7 +3440,7 @@ def get_indexer(
34403440
if len(target) == 0:
34413441
return np.array([], dtype=np.intp)
34423442

3443-
if not self._should_compare(target) and not is_interval_dtype(self.dtype):
3443+
if not self._should_compare(target) and not self._should_partial_index(target):
34443444
# IntervalIndex get special treatment bc numeric scalars can be
34453445
# matched to Interval scalars
34463446
return self._get_indexer_non_comparable(target, method=method, unique=True)
@@ -3519,6 +3519,16 @@ def _get_indexer(
35193519

35203520
return ensure_platform_int(indexer)
35213521

3522+
@final
3523+
def _should_partial_index(self, target: Index) -> bool:
3524+
"""
3525+
Should we attempt partial-matching indexing?
3526+
"""
3527+
if is_interval_dtype(self.dtype):
3528+
# "Index" has no attribute "left"
3529+
return self.left._should_compare(target) # type: ignore[attr-defined]
3530+
return False
3531+
35223532
@final
35233533
def _check_indexing_method(
35243534
self,

pandas/core/indexes/interval.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -648,22 +648,21 @@ def _get_indexer(
648648
# returned ndarray is np.intp
649649

650650
if isinstance(target, IntervalIndex):
651-
if not self._should_compare(target):
652-
return self._get_indexer_non_comparable(target, method, unique=True)
653-
654651
# non-overlapping -> at most one match per interval in target
655652
# want exact matches -> need both left/right to match, so defer to
656653
# left/right get_indexer, compare elementwise, equality -> match
657654
left_indexer = self.left.get_indexer(target.left)
658655
right_indexer = self.right.get_indexer(target.right)
659656
indexer = np.where(left_indexer == right_indexer, left_indexer, -1)
660657

661-
elif not is_object_dtype(target):
658+
elif not is_object_dtype(target.dtype):
662659
# homogeneous scalar index: use IntervalTree
660+
# we should always have self._should_partial_index(target) here
663661
target = self._maybe_convert_i8(target)
664662
indexer = self._engine.get_indexer(target.values)
665663
else:
666664
# heterogeneous scalar index: defer elementwise to get_loc
665+
# we should always have self._should_partial_index(target) here
667666
return self._get_indexer_pointwise(target)[0]
668667

669668
return ensure_platform_int(indexer)
@@ -673,11 +672,12 @@ def get_indexer_non_unique(self, target: Index) -> tuple[np.ndarray, np.ndarray]
673672
# both returned ndarrays are np.intp
674673
target = ensure_index(target)
675674

676-
if isinstance(target, IntervalIndex) and not self._should_compare(target):
677-
# different closed or incompatible subtype -> no matches
675+
if not self._should_compare(target) and not self._should_partial_index(target):
676+
# e.g. IntervalIndex with different closed or incompatible subtype
677+
# -> no matches
678678
return self._get_indexer_non_comparable(target, None, unique=False)
679679

680-
elif is_object_dtype(target.dtype) or isinstance(target, IntervalIndex):
680+
elif is_object_dtype(target.dtype) or not self._should_partial_index(target):
681681
# target might contain intervals: defer elementwise to get_loc
682682
return self._get_indexer_pointwise(target)
683683

0 commit comments

Comments
 (0)