Skip to content

BUG: CategoricalIndex.where nulling out non-categories #37977

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v1.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,8 @@ Categorical
- :meth:`Categorical.fillna` will always return a copy, will validate a passed fill value regardless of whether there are any NAs to fill, and will disallow a ``NaT`` as a fill value for numeric categories (:issue:`36530`)
- Bug in :meth:`Categorical.__setitem__` that incorrectly raised when trying to set a tuple value (:issue:`20439`)
- Bug in :meth:`CategoricalIndex.equals` incorrectly casting non-category entries to ``np.nan`` (:issue:`37667`)
- Bug in :meth:`CatgoricalIndex.where` incorrectly setting non-category entries to ``np.nan`` instead of raising ``TypeError`` (:issue:`37977`)
-

Datetimelike
^^^^^^^^^^^^
Expand Down
19 changes: 19 additions & 0 deletions pandas/core/arrays/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,3 +321,22 @@ def putmask(self, mask, value):
value = self._validate_setitem_value(value)

np.putmask(self._ndarray, mask, value)

def where(self, mask, value):
"""
Analogue to np.where(mask, self, value)

Parameters
----------
mask : np.ndarray[bool]
value : scalar or listlike

Raises
------
TypeError
If value cannot be cast to self.dtype.
"""
value = self._validate_setitem_value(value)

res_values = np.where(mask, self._ndarray, value)
return self._from_backing_data(res_values)
12 changes: 0 additions & 12 deletions pandas/core/indexes/category.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,18 +403,6 @@ def _to_safe_for_reshape(self):
""" convert to object if we are a categorical """
return self.astype("object")

@doc(Index.where)
def where(self, cond, other=None):
# TODO: Investigate an alternative implementation with
# 1. copy the underlying Categorical
# 2. setitem with `cond` and `other`
# 3. Rebuild CategoricalIndex.
if other is None:
other = self._na_value
values = np.where(cond, self._values, other)
cat = Categorical(values, dtype=self.dtype)
return type(self)._simple_new(cat, name=self.name)

def reindex(self, target, method=None, level=None, limit=None, tolerance=None):
"""
Create index with target's values (move/add/delete values as necessary)
Expand Down
8 changes: 0 additions & 8 deletions pandas/core/indexes/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,14 +529,6 @@ def isin(self, values, level=None):

return algorithms.isin(self.asi8, values.asi8)

@Appender(Index.where.__doc__)
def where(self, cond, other=None):
other = self._data._validate_setitem_value(other)

result = np.where(cond, self._data._ndarray, other)
arr = self._data._from_backing_data(result)
return type(self)._simple_new(arr, name=self.name)

def shift(self, periods=1, freq=None):
"""
Shift index by desired number of time frequency increments.
Expand Down
5 changes: 5 additions & 0 deletions pandas/core/indexes/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,11 @@ def insert(self, loc: int, item):
new_arr = arr._from_backing_data(new_vals)
return type(self)._simple_new(new_arr, name=self.name)

@doc(Index.where)
def where(self, cond, other=None):
res_values = self._data.where(cond, other)
return type(self)._simple_new(res_values, name=self.name)

def putmask(self, mask, value):
res_values = self._data.copy()
try:
Expand Down
12 changes: 12 additions & 0 deletions pandas/tests/indexes/categorical/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,18 @@ def test_where(self, klass):
result = i.where(klass(cond))
tm.assert_index_equal(result, expected)

def test_where_non_categories(self):
ci = CategoricalIndex(["a", "b", "c", "d"])
mask = np.array([True, False, True, False])

msg = "Cannot setitem on a Categorical with a new category"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is an odd message for this operation.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yah thats a side-effect of de-duplicating all the validators. the datetimelike one is nicer, so im planning on doing a pass to make the categorical one behave more like the datetimelike one. that'll be a big diff since it will involve changing a lot of tests, so planning to do that in a dedicated PR

with pytest.raises(ValueError, match=msg):
ci.where(mask, 2)

with pytest.raises(ValueError, match=msg):
# Test the Categorical method directly
ci._data.where(mask, 2)


class TestContains:
def test_contains(self):
Expand Down