Skip to content

REF: implement Index._should_partial_index #42227

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3410,7 +3410,7 @@ def get_indexer(
if len(target) == 0:
return np.array([], dtype=np.intp)

if not self._should_compare(target) and not is_interval_dtype(self.dtype):
if not self._should_compare(target) and not self._should_partial_index(target):
# IntervalIndex gets special treatment because numeric scalars can be
# matched to Interval scalars
return self._get_indexer_non_comparable(target, method=method, unique=True)
Expand Down Expand Up @@ -3478,6 +3478,16 @@ def _get_indexer(

return ensure_platform_int(indexer)

@final
def _should_partial_index(self, target: Index) -> bool:
    """
    Report whether partial-matching indexing should be attempted on ``target``.

    Returns False unless ``self.dtype`` is an interval dtype. For an interval
    dtype the decision is delegated to ``self.left._should_compare(target)``.
    """
    # Guard clause: only interval-dtype indexes are considered at all.
    if not is_interval_dtype(self.dtype):
        return False

    # "Index" has no attribute "left"
    left_side = self.left  # type: ignore[attr-defined]
    return left_side._should_compare(target)

@final
def _check_indexing_method(
self,
Expand Down
14 changes: 7 additions & 7 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,22 +646,21 @@ def _get_indexer(
# returned ndarray is np.intp

if isinstance(target, IntervalIndex):
if not self._should_compare(target):
return self._get_indexer_non_comparable(target, method, unique=True)

# non-overlapping -> at most one match per interval in target
# want exact matches -> need both left/right to match, so defer to
# left/right get_indexer, compare elementwise, equality -> match
left_indexer = self.left.get_indexer(target.left)
right_indexer = self.right.get_indexer(target.right)
indexer = np.where(left_indexer == right_indexer, left_indexer, -1)

elif not is_object_dtype(target):
elif not is_object_dtype(target.dtype):
# homogeneous scalar index: use IntervalTree
# we should always have self._should_partial_index(target) here
target = self._maybe_convert_i8(target)
indexer = self._engine.get_indexer(target.values)
else:
# heterogeneous scalar index: defer elementwise to get_loc
# we should always have self._should_partial_index(target) here
return self._get_indexer_pointwise(target)[0]

return ensure_platform_int(indexer)
Expand All @@ -671,11 +670,12 @@ def get_indexer_non_unique(self, target: Index) -> tuple[np.ndarray, np.ndarray]
# both returned ndarrays are np.intp
target = ensure_index(target)

if isinstance(target, IntervalIndex) and not self._should_compare(target):
# different closed or incompatible subtype -> no matches
if not self._should_compare(target) and not self._should_partial_index(target):
# e.g. IntervalIndex with different closed or incompatible subtype
# -> no matches
return self._get_indexer_non_comparable(target, None, unique=False)

elif is_object_dtype(target.dtype) or isinstance(target, IntervalIndex):
elif is_object_dtype(target.dtype) or not self._should_partial_index(target):
# target might contain intervals: defer elementwise to get_loc
return self._get_indexer_pointwise(target)

Expand Down