BUG: Fix CategoricalIndex.__contains__ with non-hashable, closes #21729 #27284

Merged
merged 3 commits into from
Jul 9, 2019
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.25.0.rst
@@ -1050,6 +1050,7 @@ Indexing
 - Bug which produced ``AttributeError`` on partial matching :class:`Timestamp` in a :class:`MultiIndex` (:issue:`26944`)
 - Bug in :class:`Categorical` and :class:`CategoricalIndex` with :class:`Interval` values when using the ``in`` operator (``__contains``) with objects that are not comparable to the values in the ``Interval`` (:issue:`23705`)
 - Bug in :meth:`DataFrame.loc` and :meth:`DataFrame.iloc` on a :class:`DataFrame` with a single timezone-aware datetime64[ns] column incorrectly returning a scalar instead of a :class:`Series` (:issue:`27110`)
+- Bug in :class:`CategoricalIndex` and :class:`Categorical` incorrectly raising ``ValueError`` instead of ``TypeError`` when a list is passed using the ``in`` operator (``__contains__``) (:issue:`21729`)
 -

 Missing
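The whatsnew entry above can be checked from user code. The following is a minimal sketch, assuming a pandas version that includes this fix (0.25.0 or later); the helper `membership_error` is a hypothetical name introduced only for this illustration and is not part of pandas.

```python
import pandas as pd

idx = pd.CategoricalIndex([1, 2, 3])
cat = pd.Categorical([1, 2, 3])


def membership_error(key, container):
    """Hypothetical helper: return the exception type raised by
    `key in container`, or None if the check completes normally."""
    try:
        key in container
    except Exception as exc:  # broad on purpose: we only report the type
        return type(exc)
    return None


# A hashable key that is absent simply evaluates to False.
print("a" in idx, "a" in cat)

# Lists are unhashable, so with the fix both containers raise TypeError.
for key in (["a"], ["a", "b"]):
    print(key, membership_error(key, idx), membership_error(key, cat))
```

On a pandas version without the fix, the two-element list would be expected to surface as ``ValueError`` instead, which is the behaviour reported in the linked issue.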
2 changes: 1 addition & 1 deletion pandas/core/arrays/categorical.py
@@ -2020,7 +2020,7 @@ def __contains__(self, key):
         Returns True if `key` is in this Categorical.
         """
         # if key is a NaN, check if any NaN is in self.
-        if isna(key):
+        if is_scalar(key) and isna(key):
             return self.isna().any()

         return contains(self, key, container=self._codes)
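Why the unguarded `isna(key)` check produced a `ValueError` can be sketched with public pandas functions alone. This is an illustration of the mechanism, not the library's internal code path; it assumes `pd.isna` and `pandas.api.types.is_scalar` behave like the internal `isna` and `is_scalar` used in the diff.

```python
import pandas as pd
from pandas.api.types import is_scalar

key = ["a", "b"]

# isna on a list-like works element-wise, so it returns an array, not a bool.
mask = pd.isna(key)
print(type(mask).__name__, mask.shape)

# Using that array as an `if` condition asks for its truth value, which is
# ambiguous for more than one element and raises ValueError.
raised_value_error = False
try:
    if mask:
        pass
except ValueError as exc:
    raised_value_error = True
    print("ValueError:", exc)

# The guard short-circuits: is_scalar(key) is False, so isna(key) is never
# evaluated and the key goes on to the regular containment check.
guarded = is_scalar(key) and pd.isna(key)
print(guarded)
```

With the guard in place, a list key reaches the hash-based lookup, where being unhashable yields the `TypeError` the tests expect.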
9 changes: 3 additions & 6 deletions pandas/core/frame.py
@@ -60,6 +60,7 @@
     is_extension_array_dtype,
     is_extension_type,
     is_float_dtype,
+    is_hashable,
     is_integer,
     is_integer_dtype,
     is_iterator,
@@ -2954,16 +2955,12 @@ def __getitem__(self, key):
         key = lib.item_from_zerodim(key)
         key = com.apply_if_callable(key, self)

-        # shortcut if the key is in columns
-        try:
+        if is_hashable(key):
+            # shortcut if the key is in columns
             if self.columns.is_unique and key in self.columns:
                 if self.columns.nlevels > 1:
                     return self._getitem_multilevel(key)
                 return self._get_item_cache(key)
-        except (TypeError, ValueError):
-            # The TypeError correctly catches non hashable "key" (e.g. list)
-            # The ValueError can be removed once GH #21729 is fixed
-            pass

         # Do we have a slicer (on rows)?
         indexer = convert_to_index_sliceable(self, key)
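The change in `__getitem__` swaps a broad `try`/`except` for an explicit hashability check. A minimal, standard-library-only sketch of that pattern follows. Both functions are stand-ins written for this illustration: `is_hashable` mirrors the idea of the pandas helper of the same name, and `lookup_shortcut` with its plain-dict `columns` is hypothetical and much simpler than the real `DataFrame` logic.

```python
def is_hashable(obj):
    """Stand-in for the pandas helper: True if hash(obj) succeeds."""
    try:
        hash(obj)
    except TypeError:
        return False
    return True


def lookup_shortcut(columns, key):
    """Hypothetical fast path: return the column for a hashable key that is
    present, else None so the caller can fall through to other handling.

    `columns` is a plain dict standing in for a frame's unique columns.
    """
    if is_hashable(key) and key in columns:
        return columns[key]
    return None


columns = {"a": [1, 2], "b": [3, 4]}
print(lookup_shortcut(columns, "a"))         # hashable and present
print(lookup_shortcut(columns, "z"))         # hashable but absent
print(lookup_shortcut(columns, ["a", "b"]))  # unhashable: shortcut skipped
```

Checking hashability up front keeps the membership test from raising at all, so unrelated `TypeError` or `ValueError` exceptions inside the shortcut are no longer silently swallowed.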
2 changes: 1 addition & 1 deletion pandas/core/indexes/category.py
@@ -407,7 +407,7 @@ def _reverse_indexer(self):
     @Appender(_index_shared_docs["contains"] % _index_doc_kwargs)
     def __contains__(self, key):
         # if key is a NaN, check if any NaN is in self.
-        if isna(key):
+        if is_scalar(key) and isna(key):
             return self.hasnans

         return contains(self, key, container=self._engine)
12 changes: 12 additions & 0 deletions pandas/tests/arrays/categorical/test_operators.py
@@ -417,3 +417,15 @@ def test_contains_interval(self, item, expected):
         cat = Categorical(pd.IntervalIndex.from_breaks(range(3)))
         result = item in cat
         assert result is expected
+
+    def test_contains_list(self):
+        # GH#21729
+        cat = Categorical([1, 2, 3])
+
+        assert "a" not in cat
+
+        with pytest.raises(TypeError, match="unhashable type"):
+            ["a"] in cat
+
+        with pytest.raises(TypeError, match="unhashable type"):
+            ["a", "b"] in cat
12 changes: 12 additions & 0 deletions pandas/tests/indexes/test_category.py
@@ -276,6 +276,18 @@ def test_contains_interval(self, item, expected):
         result = item in ci
         assert result is expected

+    def test_contains_list(self):
+        # GH#21729
+        idx = pd.CategoricalIndex([1, 2, 3])
+
+        assert "a" not in idx
+
+        with pytest.raises(TypeError, match="unhashable type"):
+            ["a"] in idx
+
+        with pytest.raises(TypeError, match="unhashable type"):
+            ["a", "b"] in idx
+
     def test_map(self):
         ci = pd.CategoricalIndex(list("ABABC"), categories=list("CBA"), ordered=True)
         result = ci.map(lambda x: x.lower())