diff --git a/pandas/core/indexes/category.py b/pandas/core/indexes/category.py index 718b6afb70e06..2eda54ec8d4ed 100644 --- a/pandas/core/indexes/category.py +++ b/pandas/core/indexes/category.py @@ -439,44 +439,10 @@ def _to_safe_for_reshape(self): """ convert to object if we are a categorical """ return self.astype("object") - def get_loc(self, key, method=None): - """ - Get integer location, slice or boolean mask for requested label. - - Parameters - ---------- - key : label - method : {None} - * default: exact matches only. - - Returns - ------- - loc : int if unique index, slice if monotonic index, else mask - - Raises - ------ - KeyError : if the key is not in the index - - Examples - -------- - >>> unique_index = pd.CategoricalIndex(list('abc')) - >>> unique_index.get_loc('b') - 1 - - >>> monotonic_index = pd.CategoricalIndex(list('abbc')) - >>> monotonic_index.get_loc('b') - slice(1, 3, None) - - >>> non_monotonic_index = pd.CategoricalIndex(list('abcb')) - >>> non_monotonic_index.get_loc('b') - array([False, True, False, True], dtype=bool) - """ + def _maybe_cast_indexer(self, key): code = self.categories.get_loc(key) code = self.codes.dtype.type(code) - try: - return self._engine.get_loc(code) - except KeyError: - raise KeyError(key) + return code def get_value(self, series: "Series", key: Any): """ diff --git a/pandas/tests/indexes/categorical/test_indexing.py b/pandas/tests/indexes/categorical/test_indexing.py index 6fce6542d228e..507e38d9acac2 100644 --- a/pandas/tests/indexes/categorical/test_indexing.py +++ b/pandas/tests/indexes/categorical/test_indexing.py @@ -172,6 +172,23 @@ def test_get_loc(self): with pytest.raises(KeyError, match="'c'"): i.get_loc("c") + def test_get_loc_unique(self): + cidx = pd.CategoricalIndex(list("abc")) + result = cidx.get_loc("b") + assert result == 1 + + def test_get_loc_monotonic_nonunique(self): + cidx = pd.CategoricalIndex(list("abbc")) + result = cidx.get_loc("b") + expected = slice(1, 3, None) + assert result == expected + + def test_get_loc_nonmonotonic_nonunique(self): + cidx = pd.CategoricalIndex(list("abcb")) + result = cidx.get_loc("b") + expected = np.array([False, True, False, True], dtype=bool) + tm.assert_numpy_array_equal(result, expected) + class TestGetIndexer: def test_get_indexer_base(self):