Skip to content

Commit b4c554f

Browse files
authored
TYP: stronger typing in libindex (#40465)
1 parent 3a9c94b commit b4c554f

File tree

5 files changed

+49
-28
lines changed

5 files changed

+49
-28
lines changed

pandas/_libs/index.pyx

+14-17
Original file line numberDiff line numberDiff line change
@@ -259,11 +259,11 @@ cdef class IndexEngine:
259259
self.monotonic_inc = 0
260260
self.monotonic_dec = 0
261261

262-
def get_indexer(self, values):
262+
def get_indexer(self, ndarray values):
263263
self._ensure_mapping_populated()
264264
return self.mapping.lookup(values)
265265

266-
def get_indexer_non_unique(self, targets):
266+
def get_indexer_non_unique(self, ndarray targets):
267267
"""
268268
Return an indexer suitable for taking from a non unique index
269269
return the labels in the same order as the target
@@ -451,11 +451,11 @@ cdef class DatetimeEngine(Int64Engine):
451451
except KeyError:
452452
raise KeyError(val)
453453

454-
def get_indexer_non_unique(self, targets):
454+
def get_indexer_non_unique(self, ndarray targets):
455455
# we may get datetime64[ns] or timedelta64[ns], cast these to int64
456456
return super().get_indexer_non_unique(targets.view("i8"))
457457

458-
def get_indexer(self, values):
458+
def get_indexer(self, ndarray values):
459459
self._ensure_mapping_populated()
460460
if values.dtype != self._get_box_dtype():
461461
return np.repeat(-1, len(values)).astype('i4')
@@ -594,15 +594,15 @@ cdef class BaseMultiIndexCodesEngine:
594594
in zip(self.levels, zip(*target))]
595595
return self._codes_to_ints(np.array(level_codes, dtype='uint64').T)
596596

597-
def get_indexer_no_fill(self, object target) -> np.ndarray:
597+
def get_indexer(self, ndarray[object] target) -> np.ndarray:
598598
"""
599599
Returns an array giving the positions of each value of `target` in
600600
`self.values`, where -1 represents a value in `target` which does not
601601
appear in `self.values`
602602

603603
Parameters
604604
----------
605-
target : list-like of keys
605+
target : ndarray[object]
606606
Each key is a tuple, with a label for each level of the index
607607

608608
Returns
@@ -613,8 +613,8 @@ cdef class BaseMultiIndexCodesEngine:
613613
lab_ints = self._extract_level_codes(target)
614614
return self._base.get_indexer(self, lab_ints)
615615

616-
def get_indexer(self, object target, object values = None,
617-
object method = None, object limit = None) -> np.ndarray:
616+
def get_indexer_with_fill(self, ndarray target, ndarray values,
617+
str method, object limit) -> np.ndarray:
618618
"""
619619
Returns an array giving the positions of each value of `target` in
620620
`values`, where -1 represents a value in `target` which does not
@@ -630,25 +630,22 @@ cdef class BaseMultiIndexCodesEngine:
630630

631631
Parameters
632632
----------
633-
target: list-like of tuples
633+
target: ndarray[object] of tuples
634634
need not be sorted, but all must have the same length, which must be
635635
the same as the length of all tuples in `values`
636-
values : list-like of tuples
636+
values : ndarray[object] of tuples
637637
must be sorted and all have the same length. Should be the set of
638638
the MultiIndex's values. Needed only if `method` is not None
639639
method: string
640640
"backfill" or "pad"
641-
limit: int, optional
641+
limit: int or None
642642
if provided, limit the number of fills to this value
643643

644644
Returns
645645
-------
646646
np.ndarray[int64_t, ndim=1] of the indexer of `target` into `values`,
647647
filled with the `method` (and optionally `limit`) specified
648648
"""
649-
if method is None:
650-
return self.get_indexer_no_fill(target)
651-
652649
assert method in ("backfill", "pad")
653650
cdef:
654651
int64_t i, j, next_code
@@ -658,8 +655,8 @@ cdef class BaseMultiIndexCodesEngine:
658655
ndarray[int64_t, ndim=1] new_codes, new_target_codes
659656
ndarray[int64_t, ndim=1] sorted_indexer
660657

661-
target_order = np.argsort(target.values).astype('int64')
662-
target_values = target.values[target_order]
658+
target_order = np.argsort(target).astype('int64')
659+
target_values = target[target_order]
663660
num_values, num_target_values = len(values), len(target_values)
664661
new_codes, new_target_codes = (
665662
np.empty((num_values,)).astype('int64'),
@@ -718,7 +715,7 @@ cdef class BaseMultiIndexCodesEngine:
718715

719716
return self._base.get_loc(self, lab_int)
720717

721-
def get_indexer_non_unique(self, object target):
718+
def get_indexer_non_unique(self, ndarray target):
722719
# This needs to be overridden just because the default one works on
723720
# target._values, and target can be itself a MultiIndex.
724721

pandas/core/indexes/base.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -3378,7 +3378,11 @@ def get_loc(self, key, method=None, tolerance=None):
33783378
@Appender(_index_shared_docs["get_indexer"] % _index_doc_kwargs)
33793379
@final
33803380
def get_indexer(
3381-
self, target, method=None, limit=None, tolerance=None
3381+
self,
3382+
target,
3383+
method: Optional[str_t] = None,
3384+
limit: Optional[int] = None,
3385+
tolerance=None,
33823386
) -> np.ndarray:
33833387

33843388
method = missing.clean_reindex_fill_method(method)
@@ -3403,7 +3407,11 @@ def get_indexer(
34033407
return self._get_indexer(target, method, limit, tolerance)
34043408

34053409
def _get_indexer(
3406-
self, target: Index, method=None, limit=None, tolerance=None
3410+
self,
3411+
target: Index,
3412+
method: Optional[str_t] = None,
3413+
limit: Optional[int] = None,
3414+
tolerance=None,
34073415
) -> np.ndarray:
34083416
if tolerance is not None:
34093417
tolerance = self._convert_tolerance(tolerance, target)
@@ -3467,7 +3475,7 @@ def _convert_tolerance(
34673475

34683476
@final
34693477
def _get_fill_indexer(
3470-
self, target: Index, method: str_t, limit=None, tolerance=None
3478+
self, target: Index, method: str_t, limit: Optional[int] = None, tolerance=None
34713479
) -> np.ndarray:
34723480

34733481
target_values = target._get_engine_target()
@@ -3487,7 +3495,7 @@ def _get_fill_indexer(
34873495

34883496
@final
34893497
def _get_fill_indexer_searchsorted(
3490-
self, target: Index, method: str_t, limit=None
3498+
self, target: Index, method: str_t, limit: Optional[int] = None
34913499
) -> np.ndarray:
34923500
"""
34933501
Fallback pad/backfill get_indexer that works for monotonic decreasing
@@ -3520,7 +3528,9 @@ def _get_fill_indexer_searchsorted(
35203528
return indexer
35213529

35223530
@final
3523-
def _get_nearest_indexer(self, target: Index, limit, tolerance) -> np.ndarray:
3531+
def _get_nearest_indexer(
3532+
self, target: Index, limit: Optional[int], tolerance
3533+
) -> np.ndarray:
35243534
"""
35253535
Get the indexer for the nearest index labels; requires an index with
35263536
values that can be subtracted from each other (e.g., not strings or

pandas/core/indexes/category.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,11 @@ def _maybe_cast_indexer(self, key) -> int:
492492
return self._data._unbox_scalar(key)
493493

494494
def _get_indexer(
495-
self, target: Index, method=None, limit=None, tolerance=None
495+
self,
496+
target: Index,
497+
method: Optional[str] = None,
498+
limit: Optional[int] = None,
499+
tolerance=None,
496500
) -> np.ndarray:
497501

498502
if self.equals(target):

pandas/core/indexes/multi.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -2676,7 +2676,11 @@ def _get_partial_string_timestamp_match_key(self, key):
26762676
return key
26772677

26782678
def _get_indexer(
2679-
self, target: Index, method=None, limit=None, tolerance=None
2679+
self,
2680+
target: Index,
2681+
method: Optional[str] = None,
2682+
limit: Optional[int] = None,
2683+
tolerance=None,
26802684
) -> np.ndarray:
26812685

26822686
# empty indexer
@@ -2699,16 +2703,16 @@ def _get_indexer(
26992703
raise NotImplementedError(
27002704
"tolerance not implemented yet for MultiIndex"
27012705
)
2702-
indexer = self._engine.get_indexer(
2703-
values=self._values, target=target, method=method, limit=limit
2706+
indexer = self._engine.get_indexer_with_fill(
2707+
target=target._values, values=self._values, method=method, limit=limit
27042708
)
27052709
elif method == "nearest":
27062710
raise NotImplementedError(
27072711
"method='nearest' not implemented yet "
27082712
"for MultiIndex; see GitHub issue 9365"
27092713
)
27102714
else:
2711-
indexer = self._engine.get_indexer(target)
2715+
indexer = self._engine.get_indexer(target._values)
27122716

27132717
return ensure_platform_int(indexer)
27142718

pandas/core/indexes/range.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,13 @@ def get_loc(self, key, method=None, tolerance=None):
396396
raise KeyError(key)
397397
return super().get_loc(key, method=method, tolerance=tolerance)
398398

399-
def _get_indexer(self, target: Index, method=None, limit=None, tolerance=None):
399+
def _get_indexer(
400+
self,
401+
target: Index,
402+
method: Optional[str] = None,
403+
limit: Optional[int] = None,
404+
tolerance=None,
405+
):
400406
if com.any_not_none(method, tolerance, limit):
401407
return super()._get_indexer(
402408
target, method=method, tolerance=tolerance, limit=limit

0 commit comments

Comments
 (0)