From 2e5c986427fa28e672884caeff846d6e4fdac28a Mon Sep 17 00:00:00 2001 From: Brock Date: Sat, 12 Jun 2021 21:48:51 -0700 Subject: [PATCH 1/2] REF: de-duplicate CategoricalIndex._get_indexer --- pandas/core/indexes/category.py | 43 +++++++++------------------------ 1 file changed, 11 insertions(+), 32 deletions(-) diff --git a/pandas/core/indexes/category.py b/pandas/core/indexes/category.py index 1bda05f3ce5df..b91fe2080218f 100644 --- a/pandas/core/indexes/category.py +++ b/pandas/core/indexes/category.py @@ -492,38 +492,7 @@ def _maybe_cast_indexer(self, key) -> int: return -1 raise - def _get_indexer( - self, - target: Index, - method: str | None = None, - limit: int | None = None, - tolerance=None, - ) -> np.ndarray: - # returned ndarray is np.intp - - if self.equals(target): - return np.arange(len(self), dtype="intp") - - return self._get_indexer_non_unique(target._values)[0] - - @Appender(_index_shared_docs["get_indexer_non_unique"] % _index_doc_kwargs) - def get_indexer_non_unique(self, target) -> tuple[np.ndarray, np.ndarray]: - # both returned ndarrays are np.intp - target = ibase.ensure_index(target) - return self._get_indexer_non_unique(target._values) - - def _get_indexer_non_unique( - self, values: ArrayLike - ) -> tuple[np.ndarray, np.ndarray]: - # both returned ndarrays are np.intp - """ - get_indexer_non_unique but after unrapping the target Index object. - """ - # Note: we use engine.get_indexer_non_unique for get_indexer in addition - # to get_indexer_non_unique because, even if `target` is unique, any - # non-category entries in it will be encoded as -1 so `codes` may - # not be unique. - + def _maybe_cast_listlike_indexer(self, values: ArrayLike) -> CategoricalIndex: if isinstance(values, Categorical): # Indexing on codes is more efficient if categories are the same, # so we can apply some optimizations based on the degree of @@ -532,6 +501,16 @@ def _get_indexer_non_unique( codes = cat._codes else: codes = self.categories.get_indexer(values) + codes = codes.astype(self.codes.dtype, copy=False) + cat = self._data._from_backing_data(codes) + return type(self)._simple_new(cat) + + @Appender(_index_shared_docs["get_indexer_non_unique"] % _index_doc_kwargs) + def get_indexer_non_unique(self, target) -> tuple[np.ndarray, np.ndarray]: + # both returned ndarrays are np.intp + target = ibase.ensure_index(target) + ci = self._maybe_cast_listlike_indexer(target._values) + codes = ci._get_engine_target() indexer, missing = self._engine.get_indexer_non_unique(codes) return ensure_platform_int(indexer), ensure_platform_int(missing) From 79140c21703992eb24562ac42b63254f4125c368 Mon Sep 17 00:00:00 2001 From: Brock Date: Sun, 13 Jun 2021 15:31:45 -0700 Subject: [PATCH 2/2] REF: de-duplicate get_indexer_non_unique --- pandas/core/indexes/base.py | 1 + pandas/core/indexes/category.py | 22 ++++------------------ pandas/core/indexes/datetimelike.py | 8 +++++++- 3 files changed, 12 insertions(+), 19 deletions(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 4de95079f6480..69d53b04cf9f6 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -5167,6 +5167,7 @@ def set_value(self, arr, key, value): def get_indexer_non_unique(self, target) -> tuple[np.ndarray, np.ndarray]: # both returned ndarrays are np.intp target = ensure_index(target) + target = self._maybe_cast_listlike_indexer(target) if not self._should_compare(target) and not is_interval_dtype(self.dtype): # IntervalIndex get special treatment bc numeric scalars can be diff --git a/pandas/core/indexes/category.py b/pandas/core/indexes/category.py index b91fe2080218f..2ebc39664ad60 100644 --- a/pandas/core/indexes/category.py +++ b/pandas/core/indexes/category.py @@ -12,17 +12,12 @@ from pandas._libs import index as libindex from pandas._typing import ( - ArrayLike, Dtype, DtypeObj, ) -from pandas.util._decorators import ( - Appender, - doc, -) +from pandas.util._decorators import doc from pandas.core.dtypes.common import ( - ensure_platform_int, is_categorical_dtype, is_scalar, ) @@ -41,7 +36,6 @@ import pandas.core.indexes.base as ibase from pandas.core.indexes.base import ( Index, - _index_shared_docs, maybe_extract_name, ) from pandas.core.indexes.extension import ( @@ -492,7 +486,9 @@ def _maybe_cast_indexer(self, key) -> int: return -1 raise - def _maybe_cast_listlike_indexer(self, values: ArrayLike) -> CategoricalIndex: + def _maybe_cast_listlike_indexer(self, values) -> CategoricalIndex: + if isinstance(values, CategoricalIndex): + values = values._data if isinstance(values, Categorical): # Indexing on codes is more efficient if categories are the same, # so we can apply some optimizations based on the degree of @@ -505,16 +501,6 @@ def _maybe_cast_listlike_indexer(self, values: ArrayLike) -> CategoricalIndex: cat = self._data._from_backing_data(codes) return type(self)._simple_new(cat) - @Appender(_index_shared_docs["get_indexer_non_unique"] % _index_doc_kwargs) - def get_indexer_non_unique(self, target) -> tuple[np.ndarray, np.ndarray]: - # both returned ndarrays are np.intp - target = ibase.ensure_index(target) - ci = self._maybe_cast_listlike_indexer(target._values) - codes = ci._get_engine_target() - - indexer, missing = self._engine.get_indexer_non_unique(codes) - return ensure_platform_int(indexer), ensure_platform_int(missing) - # -------------------------------------------------------------------- def _is_comparable_dtype(self, dtype: DtypeObj) -> bool: diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index df7fae0763c42..aeef37dfb9af6 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -45,6 +45,7 @@ from pandas.core.arrays import ( DatetimeArray, + ExtensionArray, PeriodArray, TimedeltaArray, ) @@ -595,7 +596,12 @@ def _maybe_cast_listlike_indexer(self, keyarr): try: res = self._data._validate_listlike(keyarr, allow_object=True) except (ValueError, TypeError): - res = com.asarray_tuplesafe(keyarr) + if not isinstance(keyarr, ExtensionArray): + # e.g. we don't want to cast DTA to ndarray[object] + res = com.asarray_tuplesafe(keyarr) + # TODO: com.asarray_tuplesafe shouldn't cast e.g. DatetimeArray + else: + res = keyarr return Index(res, dtype=res.dtype)