Skip to content

Commit 8aa2520

Browse files
authored
REF: Use _maybe_cast_indexer and dont override Categorical.get_loc (#31681)
1 parent e1ca66b commit 8aa2520

File tree

2 files changed

+19
-36
lines changed

2 files changed

+19
-36
lines changed

pandas/core/indexes/category.py

+2-36
Original file line numberDiff line numberDiff line change
@@ -439,44 +439,10 @@ def _to_safe_for_reshape(self):
439439
""" convert to object if we are a categorical """
440440
return self.astype("object")
441441

442-
def get_loc(self, key, method=None):
443-
"""
444-
Get integer location, slice or boolean mask for requested label.
445-
446-
Parameters
447-
----------
448-
key : label
449-
method : {None}
450-
* default: exact matches only.
451-
452-
Returns
453-
-------
454-
loc : int if unique index, slice if monotonic index, else mask
455-
456-
Raises
457-
------
458-
KeyError : if the key is not in the index
459-
460-
Examples
461-
--------
462-
>>> unique_index = pd.CategoricalIndex(list('abc'))
463-
>>> unique_index.get_loc('b')
464-
1
465-
466-
>>> monotonic_index = pd.CategoricalIndex(list('abbc'))
467-
>>> monotonic_index.get_loc('b')
468-
slice(1, 3, None)
469-
470-
>>> non_monotonic_index = pd.CategoricalIndex(list('abcb'))
471-
>>> non_monotonic_index.get_loc('b')
472-
array([False, True, False, True], dtype=bool)
473-
"""
442+
def _maybe_cast_indexer(self, key):
474443
code = self.categories.get_loc(key)
475444
code = self.codes.dtype.type(code)
476-
try:
477-
return self._engine.get_loc(code)
478-
except KeyError:
479-
raise KeyError(key)
445+
return code
480446

481447
def get_value(self, series: "Series", key: Any):
482448
"""

pandas/tests/indexes/categorical/test_indexing.py

+17
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,23 @@ def test_get_loc(self):
172172
with pytest.raises(KeyError, match="'c'"):
173173
i.get_loc("c")
174174

175+
def test_get_loc_unique(self):
176+
cidx = pd.CategoricalIndex(list("abc"))
177+
result = cidx.get_loc("b")
178+
assert result == 1
179+
180+
def test_get_loc_monotonic_nonunique(self):
181+
cidx = pd.CategoricalIndex(list("abbc"))
182+
result = cidx.get_loc("b")
183+
expected = slice(1, 3, None)
184+
assert result == expected
185+
186+
def test_get_loc_nonmonotonic_nonunique(self):
187+
cidx = pd.CategoricalIndex(list("abcb"))
188+
result = cidx.get_loc("b")
189+
expected = np.array([False, True, False, True], dtype=bool)
190+
tm.assert_numpy_array_equal(result, expected)
191+
175192

176193
class TestGetIndexer:
177194
def test_get_indexer_base(self):

0 commit comments

Comments
 (0)