Skip to content

Commit 6a683a2

Browse files
authored
PERF: MultiIndex.get_indexer (#43370)
1 parent 37bd4dc commit 6a683a2

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

pandas/_libs/index.pyx

+8-9
Original file line numberDiff line numberDiff line change
@@ -603,35 +603,34 @@ cdef class BaseMultiIndexCodesEngine:
603603
def _codes_to_ints(self, ndarray[uint64_t] codes) -> np.ndarray:
604604
raise NotImplementedError("Implemented by subclass")
605605

606-
def _extract_level_codes(self, ndarray[object] target) -> np.ndarray:
606+
def _extract_level_codes(self, target) -> np.ndarray:
607607
"""
608608
Map the requested list of (tuple) keys to their integer representations
609609
for searching in the underlying integer index.
610610

611611
Parameters
612612
----------
613-
target : ndarray[object]
614-
Each key is a tuple, with a label for each level of the index.
613+
target : MultiIndex
615614

616615
Returns
617616
------
618617
int_keys : 1-dimensional array of dtype uint64 or object
619618
Integers representing one combination each
620619
"""
620+
zt = [target._get_level_values(i) for i in range(target.nlevels)]
621621
level_codes = [lev.get_indexer(codes) + 1 for lev, codes
622-
in zip(self.levels, zip(*target))]
622+
in zip(self.levels, zt)]
623623
return self._codes_to_ints(np.array(level_codes, dtype='uint64').T)
624624

625-
def get_indexer(self, ndarray[object] target) -> np.ndarray:
625+
def get_indexer(self, target) -> np.ndarray:
626626
"""
627627
Returns an array giving the positions of each value of `target` in
628628
`self.values`, where -1 represents a value in `target` which does not
629629
appear in `self.values`
630630

631631
Parameters
632632
----------
633-
target : ndarray[object]
634-
Each key is a tuple, with a label for each level of the index
633+
target : MultiIndex
635634

636635
Returns
637636
-------
@@ -742,8 +741,8 @@ cdef class BaseMultiIndexCodesEngine:
742741

743742
return self._base.get_loc(self, lab_int)
744743

745-
def get_indexer_non_unique(self, ndarray[object] target):
746-
744+
def get_indexer_non_unique(self, target):
745+
# target: MultiIndex
747746
lab_ints = self._extract_level_codes(target)
748747
indexer = self._base.get_indexer_non_unique(self, lab_ints)
749748

pandas/core/indexes/base.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -3618,7 +3618,12 @@ def _get_indexer(
36183618
elif method == "nearest":
36193619
indexer = self._get_nearest_indexer(target, limit, tolerance)
36203620
else:
3621-
indexer = self._engine.get_indexer(target._get_engine_target())
3621+
tgt_values = target._get_engine_target()
3622+
if target._is_multi and self._is_multi:
3623+
# error: Incompatible types in assignment (expression has type
3624+
# "Index", variable has type "ndarray[Any, Any]")
3625+
tgt_values = target # type: ignore[assignment]
3626+
indexer = self._engine.get_indexer(tgt_values)
36223627

36233628
return ensure_platform_int(indexer)
36243629

@@ -5459,6 +5464,8 @@ def get_indexer_non_unique(
54595464
# Note: _maybe_promote ensures we never get here with MultiIndex
54605465
# self and non-Multi target
54615466
tgt_values = target._get_engine_target()
5467+
if self._is_multi and target._is_multi:
5468+
tgt_values = target
54625469

54635470
indexer, missing = self._engine.get_indexer_non_unique(tgt_values)
54645471
return ensure_platform_int(indexer), ensure_platform_int(missing)

0 commit comments

Comments
 (0)