Skip to content

REF: implement Index._get_indexer_strict #42485

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 15, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3456,7 +3456,7 @@ def __getitem__(self, key):
else:
if is_iterator(key):
key = list(key)
indexer = self.loc._get_listlike_indexer(key, axis=1)[1]
indexer = self.columns._get_indexer_strict(key, "columns")[1]

# take() does not accept boolean indexers
if getattr(indexer, "dtype", None) == bool:
Expand Down
83 changes: 83 additions & 0 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5390,6 +5390,89 @@ def get_indexer_for(self, target) -> np.ndarray:
indexer, _ = self.get_indexer_non_unique(target)
return indexer

def _get_indexer_strict(self, key, axis_name: str_t) -> tuple[Index, np.ndarray]:
"""
Analogue to get_indexer that raises if any elements are missing.
"""
keyarr = key
if not isinstance(keyarr, Index):
keyarr = com.asarray_tuplesafe(keyarr)

if self._index_as_unique:
indexer = self.get_indexer_for(keyarr)
keyarr = self.reindex(keyarr)[0]
else:
keyarr, indexer, new_indexer = self._reindex_non_unique(keyarr)

self._raise_if_missing(keyarr, indexer, axis_name)

if (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_extension_array_dtype?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also dt64 and td64. I'm pretty sure we should actually be doing this (.take on L5416) unconditionally, but that breaks a few tests which I think are wrong, so that's not ready yet.

needs_i8_conversion(self.dtype)
or is_categorical_dtype(self.dtype)
or is_interval_dtype(self.dtype)
):
# For CategoricalIndex take instead of reindex to preserve dtype.
# For IntervalIndex this is to map integers to the Intervals they match to.
keyarr = self.take(indexer)
if keyarr.dtype.kind in ["m", "M"]:
# DTI/TDI.take can infer a freq in some cases when we dont want one
if isinstance(key, list) or (
isinstance(key, type(self))
# "Index" has no attribute "freq"
and key.freq is None # type: ignore[attr-defined]
):
keyarr = keyarr._with_freq(None)

return keyarr, indexer

def _raise_if_missing(self, key, indexer, axis_name: str_t):
"""
Check that indexer can be used to return a result.

e.g. at least one element was found,
unless the list of keys was actually empty.

Parameters
----------
key : list-like
Targeted labels (only used to show correct error message).
indexer: array-like of booleans
Indices corresponding to the key,
(with -1 indicating not found).
axis_name : str

Raises
------
KeyError
If at least one key was requested but none was found.
"""
if len(key) == 0:
return

# Count missing values
missing_mask = indexer < 0
nmissing = missing_mask.sum()

if nmissing:

# TODO: remove special-case; this is just to keep exception
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does this mean?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When we removed IntervalIndex._convert_listlike_indexer, we put special-casing here to keep the exception message unchanged. This TODO comment is a reminder to remove that special-casing.

# message tests from raising while debugging
use_interval_msg = is_interval_dtype(self.dtype) or (
is_categorical_dtype(self.dtype)
# "Index" has no attribute "categories" [attr-defined]
and is_interval_dtype(
self.categories.dtype # type: ignore[attr-defined]
)
)

if nmissing == len(indexer):
if use_interval_msg:
key = list(key)
raise KeyError(f"None of [{key}] are in the [{axis_name}]")

not_found = list(ensure_index(key)[missing_mask.nonzero()[0]].unique())
raise KeyError(f"{not_found} not in index")

@overload
def _get_indexer_non_comparable(
self, target: Index, method, unique: Literal[True] = ...
Expand Down
39 changes: 21 additions & 18 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2542,29 +2542,32 @@ def _get_values_for_loc(self, series: Series, loc, key):
new_ser = series._constructor(new_values, index=new_index, name=series.name)
return new_ser.__finalize__(series)

def _convert_listlike_indexer(self, keyarr) -> np.ndarray | None:
"""
Analogous to get_indexer when we are partial-indexing on our first level.

Parameters
----------
keyarr : Index, np.ndarray, or ExtensionArray
Indexer to convert.
def _get_indexer_strict(self, key, axis_name: str) -> tuple[Index, np.ndarray]:

Returns
-------
np.ndarray[intp] or None
"""
indexer = None
keyarr = key
if not isinstance(keyarr, Index):
keyarr = com.asarray_tuplesafe(keyarr)

# are we indexing a specific level
if len(keyarr) and not isinstance(keyarr[0], tuple):
_, indexer = self.reindex(keyarr, level=0)

# take all
if indexer is None:
# exact match
indexer = np.arange(len(self), dtype=np.intp)
return indexer

else:
self._raise_if_missing(key, indexer, axis_name)
return self[indexer], indexer

return super()._get_indexer_strict(key, axis_name)

def _raise_if_missing(self, key, indexer, axis_name: str):
keyarr = key
if not isinstance(key, Index):
keyarr = com.asarray_tuplesafe(key)

if len(keyarr) and not isinstance(keyarr[0], tuple):
# i.e. same condition for special case in MultiIndex._get_indexer_strict

check = self.levels[0].get_indexer(keyarr)
mask = check == -1
Expand All @@ -2574,8 +2577,8 @@ def _convert_listlike_indexer(self, keyarr) -> np.ndarray | None:
# We get here when levels still contain values which are not
# actually in Index anymore
raise KeyError(f"{keyarr} not in index")

return indexer
else:
return super()._raise_if_missing(key, indexer, axis_name)

def _get_partial_string_timestamp_match_key(self, key):
"""
Expand Down
90 changes: 2 additions & 88 deletions pandas/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
is_object_dtype,
is_scalar,
is_sequence,
needs_i8_conversion,
)
from pandas.core.dtypes.concat import concat_compat
from pandas.core.dtypes.generic import (
Expand All @@ -56,11 +55,8 @@
length_of_indexer,
)
from pandas.core.indexes.api import (
CategoricalIndex,
Index,
IntervalIndex,
MultiIndex,
ensure_index,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -1300,94 +1296,12 @@ def _get_listlike_indexer(self, key, axis: int):
Indexer for the return object, -1 denotes keys not found.
"""
ax = self.obj._get_axis(axis)
axis_name = self.obj._get_axis_name(axis)

keyarr = key
if not isinstance(keyarr, Index):
keyarr = com.asarray_tuplesafe(keyarr)

if isinstance(ax, MultiIndex):
# get_indexer expects a MultiIndex or sequence of tuples, but
# we may be doing partial-indexing, so need an extra check

# Have the index compute an indexer or return None
# if it cannot handle:
indexer = ax._convert_listlike_indexer(keyarr)
# We only act on all found values:
if indexer is not None and (indexer != -1).all():
# _validate_read_indexer is a no-op if no -1s, so skip
return ax[indexer], indexer

if ax._index_as_unique:
indexer = ax.get_indexer_for(keyarr)
keyarr = ax.reindex(keyarr)[0]
else:
keyarr, indexer, new_indexer = ax._reindex_non_unique(keyarr)

self._validate_read_indexer(keyarr, indexer, axis)

if needs_i8_conversion(ax.dtype) or isinstance(
ax, (IntervalIndex, CategoricalIndex)
):
# For CategoricalIndex take instead of reindex to preserve dtype.
# For IntervalIndex this is to map integers to the Intervals they match to.
keyarr = ax.take(indexer)
if keyarr.dtype.kind in ["m", "M"]:
# DTI/TDI.take can infer a freq in some cases when we dont want one
if isinstance(key, list) or (
isinstance(key, type(ax)) and key.freq is None
):
keyarr = keyarr._with_freq(None)
keyarr, indexer = ax._get_indexer_strict(key, axis_name)

return keyarr, indexer

def _validate_read_indexer(self, key, indexer, axis: int):
"""
Check that indexer can be used to return a result.

e.g. at least one element was found,
unless the list of keys was actually empty.

Parameters
----------
key : list-like
Targeted labels (only used to show correct error message).
indexer: array-like of booleans
Indices corresponding to the key,
(with -1 indicating not found).
axis : int
Dimension on which the indexing is being made.

Raises
------
KeyError
If at least one key was requested but none was found.
"""
if len(key) == 0:
return

# Count missing values:
missing_mask = indexer < 0
missing = (missing_mask).sum()

if missing:
ax = self.obj._get_axis(axis)

# TODO: remove special-case; this is just to keep exception
# message tests from raising while debugging
use_interval_msg = isinstance(ax, IntervalIndex) or (
isinstance(ax, CategoricalIndex)
and isinstance(ax.categories, IntervalIndex)
)

if missing == len(indexer):
axis_name = self.obj._get_axis_name(axis)
if use_interval_msg:
key = list(key)
raise KeyError(f"None of [{key}] are in the [{axis_name}]")

not_found = list(ensure_index(key)[missing_mask.nonzero()[0]].unique())
raise KeyError(f"{not_found} not in index")


@doc(IndexingMixin.iloc)
class _iLocIndexer(_LocationIndexer):
Expand Down