From 2c8af24f9815f83030b80d18e0fcaf44c1bf69cf Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Tue, 15 Jun 2021 17:03:29 -0700 Subject: [PATCH] Backport PR #41846: BUG: DataFrame.at with CategoricalIndex --- doc/source/whatsnew/v1.3.0.rst | 1 + pandas/core/frame.py | 29 ++++++++++++++++++----------- pandas/core/indexing.py | 5 ++--- pandas/tests/indexing/test_at.py | 14 ++++++++++++++ 4 files changed, 35 insertions(+), 14 deletions(-) diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index 414794dd6a56e..a202c6fe56642 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -1015,6 +1015,7 @@ Indexing - Bug in :meth:`DataFrame.loc.__getitem__` with :class:`MultiIndex` casting to float when at least one index column has float dtype and we retrieve a scalar (:issue:`41369`) - Bug in :meth:`DataFrame.loc` incorrectly matching non-Boolean index elements (:issue:`20432`) - Bug in :meth:`Series.__delitem__` with ``ExtensionDtype`` incorrectly casting to ``ndarray`` (:issue:`40386`) +- Bug in :meth:`DataFrame.at` with a :class:`CategoricalIndex` returning incorrect results when passed integer keys (:issue:`41846`) - Bug in :meth:`DataFrame.loc` returning a :class:`MultiIndex` in the wrong order if an indexer has duplicates (:issue:`40978`) - Bug in :meth:`DataFrame.__setitem__` raising a ``TypeError`` when using a ``str`` subclass as the column name with a :class:`DatetimeIndex` (:issue:`37366`) - Bug in :meth:`PeriodIndex.get_loc` failing to raise a ``KeyError`` when given a :class:`Period` with a mismatched ``freq`` (:issue:`41670`) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 91b9bdd564676..2edad9f6626bb 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -158,6 +158,7 @@ from pandas.core.indexers import check_key_length from pandas.core.indexes import base as ibase from pandas.core.indexes.api import ( + CategoricalIndex, DatetimeIndex, Index, PeriodIndex, @@ -3553,6 +3554,11 @@ def 
_get_value(self, index, col, takeable: bool = False) -> Scalar: Returns ------- scalar + + Notes + ----- + Assumes that index and columns both have ax._index_as_unique; + caller is responsible for checking. """ if takeable: series = self._ixs(col, axis=1) @@ -3561,20 +3567,21 @@ def _get_value(self, index, col, takeable: bool = False) -> Scalar: series = self._get_item_cache(col) engine = self.index._engine + if isinstance(self.index, CategoricalIndex): + # Trying to use the engine fastpath may give incorrect results + # if our categories are integers that don't match our codes + col = self.columns.get_loc(col) + index = self.index.get_loc(index) + return self._get_value(index, col, takeable=True) + try: loc = engine.get_loc(index) return series._values[loc] - except KeyError: - # GH 20629 - if self.index.nlevels > 1: - # partial indexing forbidden - raise - - # we cannot handle direct indexing - # use positional - col = self.columns.get_loc(col) - index = self.index.get_loc(index) - return self._get_value(index, col, takeable=True) + except AttributeError: + # IntervalTree has no get_loc + col = self.columns.get_loc(col) + index = self.index.get_loc(index) + return self._get_value(index, col, takeable=True) def __setitem__(self, key, value): key = com.apply_if_callable(key, self) diff --git a/pandas/core/indexing.py b/pandas/core/indexing.py index 3707e141bc447..d62dcdba92bc7 100644 --- a/pandas/core/indexing.py +++ b/pandas/core/indexing.py @@ -916,8 +916,7 @@ def __getitem__(self, key): key = tuple(list(x) if is_iterator(x) else x for x in key) key = tuple(com.apply_if_callable(x, self.obj) for x in key) if self._is_scalar_access(key): - with suppress(KeyError, IndexError, AttributeError): - # AttributeError for IntervalTree get_value + with suppress(KeyError, IndexError): return self.obj._get_value(*key, takeable=self._takeable) return self._getitem_tuple(key) else: @@ -1004,7 +1003,7 @@ def _is_scalar_access(self, key: tuple) -> bool: # should not be considered
scalar return False - if not ax.is_unique: + if not ax._index_as_unique: return False return True diff --git a/pandas/tests/indexing/test_at.py b/pandas/tests/indexing/test_at.py index 77cfb94bf4629..23d2bee612243 100644 --- a/pandas/tests/indexing/test_at.py +++ b/pandas/tests/indexing/test_at.py @@ -8,6 +8,7 @@ from pandas import ( CategoricalDtype, + CategoricalIndex, DataFrame, Series, Timestamp, @@ -141,3 +142,16 @@ def test_at_getitem_mixed_index_no_fallback(self): ser.at[0] with pytest.raises(KeyError, match="^4$"): ser.at[4] + + def test_at_categorical_integers(self): + # CategoricalIndex with integer categories that don't happen to match + # the Categorical's codes + ci = CategoricalIndex([3, 4]) + + arr = np.arange(4).reshape(2, 2) + frame = DataFrame(arr, index=ci) + + for df in [frame, frame.T]: + for key in [0, 1]: + with pytest.raises(KeyError, match=str(key)): + df.at[key, key]