diff --git a/doc/source/whatsnew/v1.2.0.rst b/doc/source/whatsnew/v1.2.0.rst index bc5229d4b4296..f59155c595af4 100644 --- a/doc/source/whatsnew/v1.2.0.rst +++ b/doc/source/whatsnew/v1.2.0.rst @@ -518,6 +518,8 @@ Categorical - :meth:`Categorical.fillna` will always return a copy, will validate a passed fill value regardless of whether there are any NAs to fill, and will disallow a ``NaT`` as a fill value for numeric categories (:issue:`36530`) - Bug in :meth:`Categorical.__setitem__` that incorrectly raised when trying to set a tuple value (:issue:`20439`) - Bug in :meth:`CategoricalIndex.equals` incorrectly casting non-category entries to ``np.nan`` (:issue:`37667`) +- Bug in :meth:`CatgoricalIndex.where` incorrectly setting non-category entries to ``np.nan`` instead of raising ``TypeError`` (:issue:`37977`) +- Datetimelike ^^^^^^^^^^^^ diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py index b40531bd42af8..5cc6525dc3c9b 100644 --- a/pandas/core/arrays/_mixins.py +++ b/pandas/core/arrays/_mixins.py @@ -321,3 +321,22 @@ def putmask(self, mask, value): value = self._validate_setitem_value(value) np.putmask(self._ndarray, mask, value) + + def where(self, mask, value): + """ + Analogue to np.where(mask, self, value) + + Parameters + ---------- + mask : np.ndarray[bool] + value : scalar or listlike + + Raises + ------ + TypeError + If value cannot be cast to self.dtype. + """ + value = self._validate_setitem_value(value) + + res_values = np.where(mask, self._ndarray, value) + return self._from_backing_data(res_values) diff --git a/pandas/core/indexes/category.py b/pandas/core/indexes/category.py index 413c8f6b45275..e2507aeaeb652 100644 --- a/pandas/core/indexes/category.py +++ b/pandas/core/indexes/category.py @@ -403,18 +403,6 @@ def _to_safe_for_reshape(self): """ convert to object if we are a categorical """ return self.astype("object") - @doc(Index.where) - def where(self, cond, other=None): - # TODO: Investigate an alternative implementation with - # 1. copy the underlying Categorical - # 2. setitem with `cond` and `other` - # 3. Rebuild CategoricalIndex. - if other is None: - other = self._na_value - values = np.where(cond, self._values, other) - cat = Categorical(values, dtype=self.dtype) - return type(self)._simple_new(cat, name=self.name) - def reindex(self, target, method=None, level=None, limit=None, tolerance=None): """ Create index with target's values (move/add/delete values as necessary) diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index bca6661f54900..40e27709df841 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -529,14 +529,6 @@ def isin(self, values, level=None): return algorithms.isin(self.asi8, values.asi8) - @Appender(Index.where.__doc__) - def where(self, cond, other=None): - other = self._data._validate_setitem_value(other) - - result = np.where(cond, self._data._ndarray, other) - arr = self._data._from_backing_data(result) - return type(self)._simple_new(arr, name=self.name) - def shift(self, periods=1, freq=None): """ Shift index by desired number of time frequency increments. diff --git a/pandas/core/indexes/extension.py b/pandas/core/indexes/extension.py index 0aa4b7732c048..6c35b882b5d67 100644 --- a/pandas/core/indexes/extension.py +++ b/pandas/core/indexes/extension.py @@ -378,6 +378,11 @@ def insert(self, loc: int, item): new_arr = arr._from_backing_data(new_vals) return type(self)._simple_new(new_arr, name=self.name) + @doc(Index.where) + def where(self, cond, other=None): + res_values = self._data.where(cond, other) + return type(self)._simple_new(res_values, name=self.name) + def putmask(self, mask, value): res_values = self._data.copy() try: diff --git a/pandas/tests/indexes/categorical/test_indexing.py b/pandas/tests/indexes/categorical/test_indexing.py index cf9360821d37f..617ffdb48b3b7 100644 --- a/pandas/tests/indexes/categorical/test_indexing.py +++ b/pandas/tests/indexes/categorical/test_indexing.py @@ -290,6 +290,18 @@ def test_where(self, klass): result = i.where(klass(cond)) tm.assert_index_equal(result, expected) + def test_where_non_categories(self): + ci = CategoricalIndex(["a", "b", "c", "d"]) + mask = np.array([True, False, True, False]) + + msg = "Cannot setitem on a Categorical with a new category" + with pytest.raises(ValueError, match=msg): + ci.where(mask, 2) + + with pytest.raises(ValueError, match=msg): + # Test the Categorical method directly + ci._data.where(mask, 2) + class TestContains: def test_contains(self):