Skip to content

Commit 77d3d8f

Browse files
jbrockmendel, JulianWgs
authored and committed
REF: de-duplicate CategoricalIndex._get_indexer (pandas-dev#42042)
1 parent e22c42e commit 77d3d8f

File tree

3 files changed

+15
-43
lines changed

3 files changed

+15
-43
lines changed

pandas/core/indexes/base.py

+1
Original file line number | Diff line number | Diff line change
@@ -5167,6 +5167,7 @@ def set_value(self, arr, key, value):
51675167
def get_indexer_non_unique(self, target) -> tuple[np.ndarray, np.ndarray]:
51685168
# both returned ndarrays are np.intp
51695169
target = ensure_index(target)
5170+
target = self._maybe_cast_listlike_indexer(target)
51705171

51715172
if not self._should_compare(target) and not is_interval_dtype(self.dtype):
51725173
# IntervalIndex get special treatment bc numeric scalars can be

pandas/core/indexes/category.py

+7-42
Original file line number | Diff line number | Diff line change
@@ -12,17 +12,12 @@
1212

1313
from pandas._libs import index as libindex
1414
from pandas._typing import (
15-
ArrayLike,
1615
Dtype,
1716
DtypeObj,
1817
)
19-
from pandas.util._decorators import (
20-
Appender,
21-
doc,
22-
)
18+
from pandas.util._decorators import doc
2319

2420
from pandas.core.dtypes.common import (
25-
ensure_platform_int,
2621
is_categorical_dtype,
2722
is_scalar,
2823
)
@@ -41,7 +36,6 @@
4136
import pandas.core.indexes.base as ibase
4237
from pandas.core.indexes.base import (
4338
Index,
44-
_index_shared_docs,
4539
maybe_extract_name,
4640
)
4741
from pandas.core.indexes.extension import (
@@ -492,38 +486,9 @@ def _maybe_cast_indexer(self, key) -> int:
492486
return -1
493487
raise
494488

495-
def _get_indexer(
496-
self,
497-
target: Index,
498-
method: str | None = None,
499-
limit: int | None = None,
500-
tolerance=None,
501-
) -> np.ndarray:
502-
# returned ndarray is np.intp
503-
504-
if self.equals(target):
505-
return np.arange(len(self), dtype="intp")
506-
507-
return self._get_indexer_non_unique(target._values)[0]
508-
509-
@Appender(_index_shared_docs["get_indexer_non_unique"] % _index_doc_kwargs)
510-
def get_indexer_non_unique(self, target) -> tuple[np.ndarray, np.ndarray]:
511-
# both returned ndarrays are np.intp
512-
target = ibase.ensure_index(target)
513-
return self._get_indexer_non_unique(target._values)
514-
515-
def _get_indexer_non_unique(
516-
self, values: ArrayLike
517-
) -> tuple[np.ndarray, np.ndarray]:
518-
# both returned ndarrays are np.intp
519-
"""
520-
get_indexer_non_unique but after unrapping the target Index object.
521-
"""
522-
# Note: we use engine.get_indexer_non_unique for get_indexer in addition
523-
# to get_indexer_non_unique because, even if `target` is unique, any
524-
# non-category entries in it will be encoded as -1 so `codes` may
525-
# not be unique.
526-
489+
def _maybe_cast_listlike_indexer(self, values) -> CategoricalIndex:
490+
if isinstance(values, CategoricalIndex):
491+
values = values._data
527492
if isinstance(values, Categorical):
528493
# Indexing on codes is more efficient if categories are the same,
529494
# so we can apply some optimizations based on the degree of
@@ -532,9 +497,9 @@ def _get_indexer_non_unique(
532497
codes = cat._codes
533498
else:
534499
codes = self.categories.get_indexer(values)
535-
536-
indexer, missing = self._engine.get_indexer_non_unique(codes)
537-
return ensure_platform_int(indexer), ensure_platform_int(missing)
500+
codes = codes.astype(self.codes.dtype, copy=False)
501+
cat = self._data._from_backing_data(codes)
502+
return type(self)._simple_new(cat)
538503

539504
# --------------------------------------------------------------------
540505

pandas/core/indexes/datetimelike.py

+7-1
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@
4545

4646
from pandas.core.arrays import (
4747
DatetimeArray,
48+
ExtensionArray,
4849
PeriodArray,
4950
TimedeltaArray,
5051
)
@@ -595,7 +596,12 @@ def _maybe_cast_listlike_indexer(self, keyarr):
595596
try:
596597
res = self._data._validate_listlike(keyarr, allow_object=True)
597598
except (ValueError, TypeError):
598-
res = com.asarray_tuplesafe(keyarr)
599+
if not isinstance(keyarr, ExtensionArray):
600+
# e.g. we don't want to cast DTA to ndarray[object]
601+
res = com.asarray_tuplesafe(keyarr)
602+
# TODO: com.asarray_tuplesafe shouldn't cast e.g. DatetimeArray
603+
else:
604+
res = keyarr
599605
return Index(res, dtype=res.dtype)
600606

601607

0 commit comments

Comments
 (0)