Skip to content

TYP: stronger typing in libindex #40465

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 14 additions & 17 deletions pandas/_libs/index.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -259,11 +259,11 @@ cdef class IndexEngine:
self.monotonic_inc = 0
self.monotonic_dec = 0

def get_indexer(self, values):
def get_indexer(self, ndarray values):
self._ensure_mapping_populated()
return self.mapping.lookup(values)

def get_indexer_non_unique(self, targets):
def get_indexer_non_unique(self, ndarray targets):
"""
Return an indexer suitable for taking from a non unique index
return the labels in the same order as the target
Expand Down Expand Up @@ -451,11 +451,11 @@ cdef class DatetimeEngine(Int64Engine):
except KeyError:
raise KeyError(val)

def get_indexer_non_unique(self, targets):
def get_indexer_non_unique(self, ndarray targets):
# we may get datetime64[ns] or timedelta64[ns], cast these to int64
return super().get_indexer_non_unique(targets.view("i8"))

def get_indexer(self, values):
def get_indexer(self, ndarray values):
self._ensure_mapping_populated()
if values.dtype != self._get_box_dtype():
return np.repeat(-1, len(values)).astype('i4')
Expand Down Expand Up @@ -594,15 +594,15 @@ cdef class BaseMultiIndexCodesEngine:
in zip(self.levels, zip(*target))]
return self._codes_to_ints(np.array(level_codes, dtype='uint64').T)

def get_indexer_no_fill(self, object target) -> np.ndarray:
def get_indexer(self, ndarray[object] target) -> np.ndarray:
"""
Returns an array giving the positions of each value of `target` in
`self.values`, where -1 represents a value in `target` which does not
appear in `self.values`

Parameters
----------
target : list-like of keys
target : ndarray[object]
Each key is a tuple, with a label for each level of the index

Returns
Expand All @@ -613,8 +613,8 @@ cdef class BaseMultiIndexCodesEngine:
lab_ints = self._extract_level_codes(target)
return self._base.get_indexer(self, lab_ints)

def get_indexer(self, object target, object values = None,
object method = None, object limit = None) -> np.ndarray:
def get_indexer_with_fill(self, ndarray target, ndarray values,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstring update for target

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

str method, object limit) -> np.ndarray:
"""
Returns an array giving the positions of each value of `target` in
`values`, where -1 represents a value in `target` which does not
Expand All @@ -630,25 +630,22 @@ cdef class BaseMultiIndexCodesEngine:

Parameters
----------
target: list-like of tuples
target: ndarray[object] of tuples
need not be sorted, but all must have the same length, which must be
the same as the length of all tuples in `values`
values : list-like of tuples
values : ndarray[object] of tuples
must be sorted and all have the same length. Should be the set of
the MultiIndex's values. Needed only if `method` is not None
method: string
"backfill" or "pad"
limit: int, optional
limit: int or None
if provided, limit the number of fills to this value

Returns
-------
np.ndarray[int64_t, ndim=1] of the indexer of `target` into `values`,
filled with the `method` (and optionally `limit`) specified
"""
if method is None:
return self.get_indexer_no_fill(target)

assert method in ("backfill", "pad")
cdef:
int64_t i, j, next_code
Expand All @@ -658,8 +655,8 @@ cdef class BaseMultiIndexCodesEngine:
ndarray[int64_t, ndim=1] new_codes, new_target_codes
ndarray[int64_t, ndim=1] sorted_indexer

target_order = np.argsort(target.values).astype('int64')
target_values = target.values[target_order]
target_order = np.argsort(target).astype('int64')
target_values = target[target_order]
num_values, num_target_values = len(values), len(target_values)
new_codes, new_target_codes = (
np.empty((num_values,)).astype('int64'),
Expand Down Expand Up @@ -718,7 +715,7 @@ cdef class BaseMultiIndexCodesEngine:

return self._base.get_loc(self, lab_int)

def get_indexer_non_unique(self, object target):
def get_indexer_non_unique(self, ndarray target):
# This needs to be overridden just because the default one works on
# target._values, and target can be itself a MultiIndex.

Expand Down
20 changes: 15 additions & 5 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3378,7 +3378,11 @@ def get_loc(self, key, method=None, tolerance=None):
@Appender(_index_shared_docs["get_indexer"] % _index_doc_kwargs)
@final
def get_indexer(
self, target, method=None, limit=None, tolerance=None
self,
target,
method: Optional[str_t] = None,
limit: Optional[int] = None,
tolerance=None,
) -> np.ndarray:

method = missing.clean_reindex_fill_method(method)
Expand All @@ -3403,7 +3407,11 @@ def get_indexer(
return self._get_indexer(target, method, limit, tolerance)

def _get_indexer(
self, target: Index, method=None, limit=None, tolerance=None
self,
target: Index,
method: Optional[str_t] = None,
limit: Optional[int] = None,
tolerance=None,
) -> np.ndarray:
if tolerance is not None:
tolerance = self._convert_tolerance(tolerance, target)
Expand Down Expand Up @@ -3467,7 +3475,7 @@ def _convert_tolerance(

@final
def _get_fill_indexer(
self, target: Index, method: str_t, limit=None, tolerance=None
self, target: Index, method: str_t, limit: Optional[int] = None, tolerance=None
) -> np.ndarray:

target_values = target._get_engine_target()
Expand All @@ -3487,7 +3495,7 @@ def _get_fill_indexer(

@final
def _get_fill_indexer_searchsorted(
self, target: Index, method: str_t, limit=None
self, target: Index, method: str_t, limit: Optional[int] = None
) -> np.ndarray:
"""
Fallback pad/backfill get_indexer that works for monotonic decreasing
Expand Down Expand Up @@ -3520,7 +3528,9 @@ def _get_fill_indexer_searchsorted(
return indexer

@final
def _get_nearest_indexer(self, target: Index, limit, tolerance) -> np.ndarray:
def _get_nearest_indexer(
self, target: Index, limit: Optional[int], tolerance
) -> np.ndarray:
"""
Get the indexer for the nearest index labels; requires an index with
values that can be subtracted from each other (e.g., not strings or
Expand Down
6 changes: 5 additions & 1 deletion pandas/core/indexes/category.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,11 @@ def _maybe_cast_indexer(self, key) -> int:
return self._data._unbox_scalar(key)

def _get_indexer(
self, target: Index, method=None, limit=None, tolerance=None
self,
target: Index,
method: Optional[str] = None,
limit: Optional[int] = None,
tolerance=None,
) -> np.ndarray:

if self.equals(target):
Expand Down
12 changes: 8 additions & 4 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2676,7 +2676,11 @@ def _get_partial_string_timestamp_match_key(self, key):
return key

def _get_indexer(
self, target: Index, method=None, limit=None, tolerance=None
self,
target: Index,
method: Optional[str] = None,
limit: Optional[int] = None,
tolerance=None,
) -> np.ndarray:

# empty indexer
Expand All @@ -2699,16 +2703,16 @@ def _get_indexer(
raise NotImplementedError(
"tolerance not implemented yet for MultiIndex"
)
indexer = self._engine.get_indexer(
values=self._values, target=target, method=method, limit=limit
indexer = self._engine.get_indexer_with_fill(
target=target._values, values=self._values, method=method, limit=limit
)
elif method == "nearest":
raise NotImplementedError(
"method='nearest' not implemented yet "
"for MultiIndex; see GitHub issue 9365"
)
else:
indexer = self._engine.get_indexer(target)
indexer = self._engine.get_indexer(target._values)

return ensure_platform_int(indexer)

Expand Down
8 changes: 7 additions & 1 deletion pandas/core/indexes/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,13 @@ def get_loc(self, key, method=None, tolerance=None):
raise KeyError(key)
return super().get_loc(key, method=method, tolerance=tolerance)

def _get_indexer(self, target: Index, method=None, limit=None, tolerance=None):
def _get_indexer(
self,
target: Index,
method: Optional[str] = None,
limit: Optional[int] = None,
tolerance=None,
):
if com.any_not_none(method, tolerance, limit):
return super()._get_indexer(
target, method=method, tolerance=tolerance, limit=limit
Expand Down