Skip to content

Commit 985c5fa

Browse files
jbrockmendel authored and yeshsurya committed
TYP: get_indexer (pandas-dev#40612)
* TYP: get_indexer
* update per discussion in pandas-dev#40612
* one more overload
* pre-commit fixup
1 parent 6882acf commit 985c5fa

File tree

2 files changed

+38
-6
lines changed

2 files changed

+38
-6
lines changed

pandas/core/indexes/base.py

+30-3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
Sequence,
1414
TypeVar,
1515
cast,
16+
overload,
1617
)
1718
import warnings
1819

@@ -159,6 +160,8 @@
159160
)
160161

161162
if TYPE_CHECKING:
163+
from typing import Literal
164+
162165
from pandas import (
163166
CategoricalIndex,
164167
DataFrame,
@@ -5212,7 +5215,8 @@ def set_value(self, arr, key, value):
52125215
"""
52135216

52145217
@Appender(_index_shared_docs["get_indexer_non_unique"] % _index_doc_kwargs)
5215-
def get_indexer_non_unique(self, target):
5218+
def get_indexer_non_unique(self, target) -> tuple[np.ndarray, np.ndarray]:
5219+
# both returned ndarrays are np.intp
52165220
target = ensure_index(target)
52175221

52185222
if not self._should_compare(target) and not is_interval_dtype(self.dtype):
@@ -5236,7 +5240,7 @@ def get_indexer_non_unique(self, target):
52365240
tgt_values = target._get_engine_target()
52375241

52385242
indexer, missing = self._engine.get_indexer_non_unique(tgt_values)
5239-
return ensure_platform_int(indexer), missing
5243+
return ensure_platform_int(indexer), ensure_platform_int(missing)
52405244

52415245
@final
52425246
def get_indexer_for(self, target, **kwargs) -> np.ndarray:
@@ -5256,8 +5260,31 @@ def get_indexer_for(self, target, **kwargs) -> np.ndarray:
52565260
indexer, _ = self.get_indexer_non_unique(target)
52575261
return indexer
52585262

5263+
@overload
5264+
def _get_indexer_non_comparable(
5265+
self, target: Index, method, unique: Literal[True] = ...
5266+
) -> np.ndarray:
5267+
# returned ndarray is np.intp
5268+
...
5269+
5270+
@overload
5271+
def _get_indexer_non_comparable(
5272+
self, target: Index, method, unique: Literal[False]
5273+
) -> tuple[np.ndarray, np.ndarray]:
5274+
# both returned ndarrays are np.intp
5275+
...
5276+
5277+
@overload
5278+
def _get_indexer_non_comparable(
5279+
self, target: Index, method, unique: bool = True
5280+
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
5281+
# any returned ndarrays are np.intp
5282+
...
5283+
52595284
@final
5260-
def _get_indexer_non_comparable(self, target: Index, method, unique: bool = True):
5285+
def _get_indexer_non_comparable(
5286+
self, target: Index, method, unique: bool = True
5287+
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
52615288
"""
52625289
Called from get_indexer or get_indexer_non_unique when the target
52635290
is of a non-comparable dtype.

pandas/core/indexes/category.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -491,18 +491,23 @@ def _get_indexer(
491491
limit: int | None = None,
492492
tolerance=None,
493493
) -> np.ndarray:
494+
# returned ndarray is np.intp
494495

495496
if self.equals(target):
496497
return np.arange(len(self), dtype="intp")
497498

498499
return self._get_indexer_non_unique(target._values)[0]
499500

500501
@Appender(_index_shared_docs["get_indexer_non_unique"] % _index_doc_kwargs)
501-
def get_indexer_non_unique(self, target):
502+
def get_indexer_non_unique(self, target) -> tuple[np.ndarray, np.ndarray]:
503+
# both returned ndarrays are np.intp
502504
target = ibase.ensure_index(target)
503505
return self._get_indexer_non_unique(target._values)
504506

505-
def _get_indexer_non_unique(self, values: ArrayLike):
507+
def _get_indexer_non_unique(
508+
self, values: ArrayLike
509+
) -> tuple[np.ndarray, np.ndarray]:
510+
# both returned ndarrays are np.intp
506511
"""
507512
get_indexer_non_unique but after unwrapping the target Index object.
508513
"""
@@ -521,7 +526,7 @@ def _get_indexer_non_unique(self, values: ArrayLike):
521526
codes = self.categories.get_indexer(values)
522527

523528
indexer, missing = self._engine.get_indexer_non_unique(codes)
524-
return ensure_platform_int(indexer), missing
529+
return ensure_platform_int(indexer), ensure_platform_int(missing)
525530

526531
@doc(Index._convert_list_indexer)
527532
def _convert_list_indexer(self, keyarr):

0 commit comments

Comments
 (0)