Skip to content

Commit 1e21130

Browse files
authored
BUG: CategoricalIndex.where nulling out non-categories (#37977)
* BUG: CategoricalIndex.where nulling out non-categories * whatsnew
1 parent 3d5e65d commit 1e21130

File tree

6 files changed

+38
-20
lines changed

6 files changed

+38
-20
lines changed

doc/source/whatsnew/v1.2.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,8 @@ Categorical
518518
- :meth:`Categorical.fillna` will always return a copy, will validate a passed fill value regardless of whether there are any NAs to fill, and will disallow a ``NaT`` as a fill value for numeric categories (:issue:`36530`)
519519
- Bug in :meth:`Categorical.__setitem__` that incorrectly raised when trying to set a tuple value (:issue:`20439`)
520520
- Bug in :meth:`CategoricalIndex.equals` incorrectly casting non-category entries to ``np.nan`` (:issue:`37667`)
521+
- Bug in :meth:`CategoricalIndex.where` incorrectly setting non-category entries to ``np.nan`` instead of raising ``TypeError`` (:issue:`37977`)
522+
-
521523

522524
Datetimelike
523525
^^^^^^^^^^^^

pandas/core/arrays/_mixins.py

+19
Original file line numberDiff line numberDiff line change
@@ -321,3 +321,22 @@ def putmask(self, mask, value):
321321
value = self._validate_setitem_value(value)
322322

323323
np.putmask(self._ndarray, mask, value)
324+
325+
def where(self, mask, value):
326+
"""
327+
Analogue to np.where(mask, self, value)
328+
329+
Parameters
330+
----------
331+
mask : np.ndarray[bool]
332+
value : scalar or listlike
333+
334+
Raises
335+
------
336+
TypeError
337+
If value cannot be cast to self.dtype.
338+
"""
339+
value = self._validate_setitem_value(value)
340+
341+
res_values = np.where(mask, self._ndarray, value)
342+
return self._from_backing_data(res_values)

pandas/core/indexes/category.py

-12
Original file line numberDiff line numberDiff line change
@@ -403,18 +403,6 @@ def _to_safe_for_reshape(self):
403403
""" convert to object if we are a categorical """
404404
return self.astype("object")
405405

406-
@doc(Index.where)
407-
def where(self, cond, other=None):
408-
# TODO: Investigate an alternative implementation with
409-
# 1. copy the underlying Categorical
410-
# 2. setitem with `cond` and `other`
411-
# 3. Rebuild CategoricalIndex.
412-
if other is None:
413-
other = self._na_value
414-
values = np.where(cond, self._values, other)
415-
cat = Categorical(values, dtype=self.dtype)
416-
return type(self)._simple_new(cat, name=self.name)
417-
418406
def reindex(self, target, method=None, level=None, limit=None, tolerance=None):
419407
"""
420408
Create index with target's values (move/add/delete values as necessary)

pandas/core/indexes/datetimelike.py

-8
Original file line numberDiff line numberDiff line change
@@ -529,14 +529,6 @@ def isin(self, values, level=None):
529529

530530
return algorithms.isin(self.asi8, values.asi8)
531531

532-
@Appender(Index.where.__doc__)
533-
def where(self, cond, other=None):
534-
other = self._data._validate_setitem_value(other)
535-
536-
result = np.where(cond, self._data._ndarray, other)
537-
arr = self._data._from_backing_data(result)
538-
return type(self)._simple_new(arr, name=self.name)
539-
540532
def shift(self, periods=1, freq=None):
541533
"""
542534
Shift index by desired number of time frequency increments.

pandas/core/indexes/extension.py

+5
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,11 @@ def insert(self, loc: int, item):
378378
new_arr = arr._from_backing_data(new_vals)
379379
return type(self)._simple_new(new_arr, name=self.name)
380380

381+
@doc(Index.where)
382+
def where(self, cond, other=None):
383+
res_values = self._data.where(cond, other)
384+
return type(self)._simple_new(res_values, name=self.name)
385+
381386
def putmask(self, mask, value):
382387
res_values = self._data.copy()
383388
try:

pandas/tests/indexes/categorical/test_indexing.py

+12
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,18 @@ def test_where(self, klass):
290290
result = i.where(klass(cond))
291291
tm.assert_index_equal(result, expected)
292292

293+
def test_where_non_categories(self):
294+
ci = CategoricalIndex(["a", "b", "c", "d"])
295+
mask = np.array([True, False, True, False])
296+
297+
msg = "Cannot setitem on a Categorical with a new category"
298+
with pytest.raises(ValueError, match=msg):
299+
ci.where(mask, 2)
300+
301+
with pytest.raises(ValueError, match=msg):
302+
# Test the Categorical method directly
303+
ci._data.where(mask, 2)
304+
293305

294306
class TestContains:
295307
def test_contains(self):

0 commit comments

Comments
 (0)