Skip to content

Commit 2b32e41

Browse files
JustinZhengBCjreback
authored andcommitted
BUG-20629 allow .at accessor with CategoricalIndex (#26298)
1 parent 279753c commit 2b32e41

File tree

5 files changed

+48
-9
lines changed

5 files changed

+48
-9
lines changed

doc/source/whatsnew/v0.25.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ Bug Fixes
301301
Categorical
302302
^^^^^^^^^^^
303303

304-
-
304+
- Bug in :func:`DataFrame.at` and :func:`Series.at` that would raise exception if the index was a :class:`CategoricalIndex` (:issue:`20629`)
305305
-
306306
-
307307

pandas/_typing.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,14 @@
88
from pandas._libs.tslibs.timedeltas import Timedelta
99

1010
from pandas.core.dtypes.dtypes import ExtensionDtype
11-
from pandas.core.dtypes.generic import ABCExtensionArray
11+
from pandas.core.dtypes.generic import (
12+
ABCExtensionArray, ABCIndexClass, ABCSeries, ABCSparseSeries)
1213

14+
AnyArrayLike = Union[ABCExtensionArray,
15+
ABCIndexClass,
16+
ABCSeries,
17+
ABCSparseSeries,
18+
np.ndarray]
1319
ArrayLike = Union[ABCExtensionArray, np.ndarray]
1420
DatetimeLikeScalar = Type[Union[Period, Timestamp, Timedelta]]
1521
Dtype = Union[str, np.dtype, ExtensionDtype]

pandas/core/frame.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -2694,13 +2694,19 @@ def _get_value(self, index, col, takeable=False):
26942694

26952695
try:
26962696
return engine.get_value(series._values, index)
2697+
except KeyError:
2698+
# GH 20629
2699+
if self.index.nlevels > 1:
2700+
# partial indexing forbidden
2701+
raise
26972702
except (TypeError, ValueError):
2703+
pass
26982704

2699-
# we cannot handle direct indexing
2700-
# use positional
2701-
col = self.columns.get_loc(col)
2702-
index = self.index.get_loc(index)
2703-
return self._get_value(index, col, takeable=True)
2705+
# we cannot handle direct indexing
2706+
# use positional
2707+
col = self.columns.get_loc(col)
2708+
index = self.index.get_loc(index)
2709+
return self._get_value(index, col, takeable=True)
27042710
_get_value.__doc__ = get_value.__doc__
27052711

27062712
def set_value(self, index, col, value, takeable=False):

pandas/core/indexes/category.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import operator
2+
from typing import Any
23
import warnings
34

45
import numpy as np
@@ -17,6 +18,7 @@
1718
from pandas.core.dtypes.generic import ABCCategorical, ABCSeries
1819
from pandas.core.dtypes.missing import isna
1920

21+
from pandas._typing import AnyArrayLike
2022
from pandas.core import accessor
2123
from pandas.core.algorithms import take_1d
2224
from pandas.core.arrays.categorical import Categorical, contains
@@ -494,16 +496,31 @@ def get_loc(self, key, method=None):
494496
except KeyError:
495497
raise KeyError(key)
496498

497-
def get_value(self, series, key):
499+
def get_value(self,
500+
series: AnyArrayLike,
501+
key: Any):
498502
"""
499503
Fast lookup of value from 1-dimensional ndarray. Only use this if you
500504
know what you're doing
505+
506+
Parameters
507+
----------
508+
series : Series, ExtensionArray, Index, or ndarray
509+
1-dimensional array to take values from
510+
key: : scalar
511+
The value of this index at the position of the desired value,
512+
otherwise the positional index of the desired value
513+
514+
Returns
515+
-------
516+
Any
517+
The element of the series at the position indicated by the key
501518
"""
502519
try:
503520
k = com.values_from_object(key)
504521
k = self._convert_scalar_indexer(k, kind='getitem')
505522
indexer = self.get_loc(k)
506-
return series.iloc[indexer]
523+
return series.take([indexer])[0]
507524
except (KeyError, TypeError):
508525
pass
509526

pandas/tests/indexing/test_categorical.py

+10
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,16 @@ def test_loc_slice(self):
638638
# expected = df.iloc[[1,2,3,4]]
639639
# assert_frame_equal(result, expected)
640640

641+
def test_loc_and_at_with_categorical_index(self):
642+
# GH 20629
643+
s = Series([1, 2, 3], index=pd.CategoricalIndex(["A", "B", "C"]))
644+
assert s.loc['A'] == 1
645+
assert s.at['A'] == 1
646+
df = DataFrame([[1, 2], [3, 4], [5, 6]],
647+
index=pd.CategoricalIndex(["A", "B", "C"]))
648+
assert df.loc['B', 1] == 4
649+
assert df.at['B', 1] == 4
650+
641651
def test_boolean_selection(self):
642652

643653
df3 = self.df3

0 commit comments

Comments
 (0)