Skip to content

Commit acaff40

Browse files
committed
BUG: Fix CategoricalIndex.__contains__ with non-hashable, closes pandas-dev#21729
1 parent 65e123c commit acaff40

File tree

6 files changed

+30
-8
lines changed

6 files changed

+30
-8
lines changed

doc/source/whatsnew/v0.25.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -1049,6 +1049,7 @@ Indexing
10491049
- Bug which produced ``AttributeError`` on partial matching :class:`Timestamp` in a :class:`MultiIndex` (:issue:`26944`)
10501050
- Bug in :class:`Categorical` and :class:`CategoricalIndex` with :class:`Interval` values when using the ``in`` operator (``__contains__``) with objects that are not comparable to the values in the ``Interval`` (:issue:`23705`)
10511051
- Bug in :meth:`DataFrame.loc` and :meth:`DataFrame.iloc` on a :class:`DataFrame` with a single timezone-aware datetime64[ns] column incorrectly returning a scalar instead of a :class:`Series` (:issue:`27110`)
1052+
- Bug in :class:`CategoricalIndex` and :class:`Categorical` incorrectly raising ``ValueError`` instead of ``TypeError`` when a list is passed using the ``in`` operator (``__contains__``) (:issue:`21729`)
10521053
-
10531054
10541055
Missing

pandas/core/arrays/categorical.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2020,7 +2020,7 @@ def __contains__(self, key):
20202020
Returns True if `key` is in this Categorical.
20212021
"""
20222022
# if key is a NaN, check if any NaN is in self.
2023-
if isna(key):
2023+
if is_scalar(key) and isna(key):
20242024
return self.isna().any()
20252025

20262026
return contains(self, key, container=self._codes)

pandas/core/frame.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
ABCMultiIndex,
7878
ABCSeries,
7979
)
80+
from pandas.core.dtypes.inference import is_hashable
8081
from pandas.core.dtypes.missing import isna, notna
8182

8283
from pandas.core import algorithms, common as com, nanops, ops
@@ -2954,16 +2955,12 @@ def __getitem__(self, key):
29542955
key = lib.item_from_zerodim(key)
29552956
key = com.apply_if_callable(key, self)
29562957

2957-
# shortcut if the key is in columns
2958-
try:
2958+
if is_hashable(key):
2959+
# shortcut if the key is in columns
29592960
if self.columns.is_unique and key in self.columns:
29602961
if self.columns.nlevels > 1:
29612962
return self._getitem_multilevel(key)
29622963
return self._get_item_cache(key)
2963-
except (TypeError, ValueError):
2964-
# The TypeError correctly catches non hashable "key" (e.g. list)
2965-
# The ValueError can be removed once GH #21729 is fixed
2966-
pass
29672964

29682965
# Do we have a slicer (on rows)?
29692966
indexer = convert_to_index_sliceable(self, key)

pandas/core/indexes/category.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ def _reverse_indexer(self):
407407
@Appender(_index_shared_docs["contains"] % _index_doc_kwargs)
408408
def __contains__(self, key):
409409
# if key is a NaN, check if any NaN is in self.
410-
if isna(key):
410+
if is_scalar(key) and isna(key):
411411
return self.hasnans
412412

413413
return contains(self, key, container=self._engine)

pandas/tests/arrays/categorical/test_operators.py

+12
Original file line numberDiff line numberDiff line change
@@ -417,3 +417,15 @@ def test_contains_interval(self, item, expected):
417417
cat = Categorical(pd.IntervalIndex.from_breaks(range(3)))
418418
result = item in cat
419419
assert result is expected
420+
421+
def test_contains_list(self):
422+
# GH#21729
423+
cat = Categorical([1, 2, 3])
424+
425+
assert "a" not in cat
426+
427+
with pytest.raises(TypeError, match="unhashable type"):
428+
["a"] in cat
429+
430+
with pytest.raises(TypeError, match="unhashable type"):
431+
["a", "b"] in cat

pandas/tests/indexes/test_category.py

+12
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,18 @@ def test_contains_interval(self, item, expected):
276276
result = item in ci
277277
assert result is expected
278278

279+
def test_contains_list(self):
280+
# GH#21729
281+
idx = pd.CategoricalIndex([1, 2, 3])
282+
283+
assert "a" not in idx
284+
285+
with pytest.raises(TypeError, match="unhashable type"):
286+
["a"] in idx
287+
288+
with pytest.raises(TypeError, match="unhashable type"):
289+
["a", "b"] in idx
290+
279291
def test_map(self):
280292
ci = pd.CategoricalIndex(list("ABABC"), categories=list("CBA"), ordered=True)
281293
result = ci.map(lambda x: x.lower())

0 commit comments

Comments
 (0)