From c1846ab7aff4fc72122245b1fb34a4f7582dc30f Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 16 Sep 2020 09:32:45 -0700 Subject: [PATCH] REF: implement putmask for CI/DTI/TDI/PI --- pandas/core/arrays/categorical.py | 5 +++++ pandas/core/indexes/base.py | 3 --- pandas/core/indexes/category.py | 11 +++++++++++ pandas/core/indexes/datetimelike.py | 13 ++++++++++++- pandas/tests/indexes/common.py | 7 ++++--- 5 files changed, 32 insertions(+), 7 deletions(-) diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index 25073282ec0f6..418140c82da08 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -1171,6 +1171,11 @@ def map(self, mapper): # ------------------------------------------------------------- # Validators; ideally these can be de-duplicated + def _validate_where_value(self, value): + if is_scalar(value): + return self._validate_fill_value(value) + return self._validate_listlike(value) + def _validate_insert_value(self, value) -> int: code = self.categories.get_indexer([value]) if (code == -1) and not (is_scalar(value) and isna(value)): diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 15944565cb254..a2f11160b2fdc 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -4232,9 +4232,6 @@ def putmask(self, mask, value): try: converted = self._validate_fill_value(value) np.putmask(values, mask, converted) - if is_period_dtype(self.dtype): - # .values cast to object, so we need to cast back - values = type(self)(values)._data return self._shallow_copy(values) except (ValueError, TypeError) as err: if is_object_dtype(self): diff --git a/pandas/core/indexes/category.py b/pandas/core/indexes/category.py index 9e4714060e23e..d73b36eff69f3 100644 --- a/pandas/core/indexes/category.py +++ b/pandas/core/indexes/category.py @@ -422,6 +422,17 @@ def where(self, cond, other=None): cat = Categorical(values, dtype=self.dtype) return type(self)._simple_new(cat, name=self.name) + def putmask(self, mask, value): + try: + code_value = self._data._validate_where_value(value) + except (TypeError, ValueError): + return self.astype(object).putmask(mask, value) + + codes = self._data._ndarray.copy() + np.putmask(codes, mask, code_value) + cat = self._data._from_backing_data(codes) + return type(self)._simple_new(cat, name=self.name) + def reindex(self, target, method=None, level=None, limit=None, tolerance=None): """ Create index with target's values (move/add/delete values as necessary) diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index 5ba5732c710f7..0226857f3eab7 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -484,7 +484,18 @@ def where(self, cond, other=None): raise TypeError(f"Where requires matching dtype, not {oth}") from err result = np.where(cond, values, other).astype("i8") - arr = type(self._data)._simple_new(result, dtype=self.dtype) + arr = self._data._from_backing_data(result) + return type(self)._simple_new(arr, name=self.name) + + def putmask(self, mask, value): + try: + value = self._data._validate_where_value(value) + except (TypeError, ValueError): + return self.astype(object).putmask(mask, value) + + result = self._data._ndarray.copy() + np.putmask(result, mask, value) + arr = self._data._from_backing_data(result) return type(self)._simple_new(arr, name=self.name) def _summary(self, name=None) -> str: diff --git a/pandas/tests/indexes/common.py b/pandas/tests/indexes/common.py index 11dc232af8de4..0e9e5c0b32d18 100644 --- a/pandas/tests/indexes/common.py +++ b/pandas/tests/indexes/common.py @@ -846,16 +846,17 @@ def test_map_str(self): def test_putmask_with_wrong_mask(self): # GH18368 index = self.create_index() + fill = index[0] msg = "putmask: mask and data must be the same size" with pytest.raises(ValueError, match=msg): - index.putmask(np.ones(len(index) + 1, np.bool_), 1) + index.putmask(np.ones(len(index) + 1, np.bool_), fill) with pytest.raises(ValueError, match=msg): - index.putmask(np.ones(len(index) - 1, np.bool_), 1) + index.putmask(np.ones(len(index) - 1, np.bool_), fill) with pytest.raises(ValueError, match=msg): - index.putmask("foo", 1) + index.putmask("foo", fill) @pytest.mark.parametrize("copy", [True, False]) @pytest.mark.parametrize("name", [None, "foo"])