diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst index dc05745c8c0e5..873c7aabde785 100644 --- a/doc/source/whatsnew/v2.0.0.rst +++ b/doc/source/whatsnew/v2.0.0.rst @@ -883,6 +883,7 @@ Performance improvements - Performance improvement in :func:`merge` and :meth:`DataFrame.join` when joining on a sorted :class:`MultiIndex` (:issue:`48504`) - Performance improvement in :func:`to_datetime` when parsing strings with timezone offsets (:issue:`50107`) - Performance improvement in :meth:`DataFrame.loc` and :meth:`Series.loc` for tuple-based indexing of a :class:`MultiIndex` (:issue:`48384`) +- Performance improvement for :meth:`Series.replace` with categorical dtype (:issue:`49404`) - Performance improvement for :meth:`MultiIndex.unique` (:issue:`48335`) - Performance improvement for :func:`concat` with extension array backed indexes (:issue:`49128`, :issue:`49178`) - Reduce memory usage of :meth:`DataFrame.to_pickle`/:meth:`Series.to_pickle` when using BZ2 or LZMA (:issue:`49068`) @@ -927,6 +928,8 @@ Bug fixes Categorical ^^^^^^^^^^^ - Bug in :meth:`Categorical.set_categories` losing dtype information (:issue:`48812`) +- Bug in :meth:`Series.replace` with categorical dtype when ``to_replace`` values overlap with new values (:issue:`49404`) +- Bug in :meth:`Series.replace` with categorical dtype losing nullable dtypes of underlying categories (:issue:`49404`) - Bug in :meth:`DataFrame.groupby` and :meth:`Series.groupby` would reorder categories when used as a grouper (:issue:`48749`) - Bug in :class:`Categorical` constructor when constructing from a :class:`Categorical` object and ``dtype="category"`` losing ordered-ness (:issue:`49309`) - diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index 14f334d72dbb1..5b61695410474 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -1137,14 +1137,9 @@ def remove_categories(self, removals): if not is_list_like(removals): removals = [removals] - removal_set = set(removals) - not_included = removal_set - set(self.dtype.categories) - new_categories = [c for c in self.dtype.categories if c not in removal_set] - - # GH 10156 - if any(isna(removals)): - not_included = {x for x in not_included if notna(x)} - new_categories = [x for x in new_categories if notna(x)] + removals = {x for x in set(removals) if notna(x)} + new_categories = self.dtype.categories.difference(removals) + not_included = removals.difference(self.dtype.categories) if len(not_included) != 0: raise ValueError(f"removals must all be in old categories: {not_included}") @@ -2273,42 +2268,28 @@ def isin(self, values) -> npt.NDArray[np.bool_]: return algorithms.isin(self.codes, code_values) def _replace(self, *, to_replace, value, inplace: bool = False): + from pandas import Index + inplace = validate_bool_kwarg(inplace, "inplace") cat = self if inplace else self.copy() - # other cases, like if both to_replace and value are list-like or if - # to_replace is a dict, are handled separately in NDFrame - if not is_list_like(to_replace): - to_replace = [to_replace] - - categories = cat.categories.tolist() - removals = set() - for replace_value in to_replace: - if value == replace_value: - continue - if replace_value not in cat.categories: - continue - if isna(value): - removals.add(replace_value) - continue - - index = categories.index(replace_value) - - if value in cat.categories: - value_index = categories.index(value) - cat._codes[cat._codes == index] = value_index - removals.add(replace_value) - else: - categories[index] = value - cat._set_categories(categories) + mask = isna(np.asarray(value)) + if mask.any(): + removals = np.asarray(to_replace)[mask] + removals = cat.categories[cat.categories.isin(removals)] + new_cat = cat.remove_categories(removals) + NDArrayBacked.__init__(cat, new_cat.codes, new_cat.dtype) - if len(removals): - new_categories = [c for c in categories if c not in removals] - new_dtype = CategoricalDtype(new_categories, ordered=self.dtype.ordered) - codes = recode_for_categories( - cat.codes, cat.categories, new_dtype.categories - ) - NDArrayBacked.__init__(cat, codes, new_dtype) + ser = cat.categories.to_series() + ser = ser.replace(to_replace=to_replace, value=value) + + all_values = Index(ser) + new_categories = Index(ser.drop_duplicates(keep="first")) + new_codes = recode_for_categories( + cat._codes, all_values, new_categories, copy=False + ) + new_dtype = CategoricalDtype(new_categories, ordered=self.dtype.ordered) + NDArrayBacked.__init__(cat, new_codes, new_dtype) if not inplace: return cat diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index 00ab9d02cee00..8fb6a18ca137a 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -536,12 +536,10 @@ def replace( if isinstance(values, Categorical): # TODO: avoid special-casing + # GH49404 blk = self if inplace else self.copy() - # error: Item "ExtensionArray" of "Union[ndarray[Any, Any], - # ExtensionArray]" has no attribute "_replace" - blk.values._replace( # type: ignore[union-attr] - to_replace=to_replace, value=value, inplace=True - ) + values = cast(Categorical, blk.values) + values._replace(to_replace=to_replace, value=value, inplace=True) return [blk] if not self._can_hold_element(to_replace): @@ -651,6 +649,14 @@ def replace_list( """ values = self.values + if isinstance(values, Categorical): + # TODO: avoid special-casing + # GH49404 + blk = self if inplace else self.copy() + values = cast(Categorical, blk.values) + values._replace(to_replace=src_list, value=dest_list, inplace=True) + return [blk] + # Exclude anything that we know we won't contain pairs = [ (x, y) for x, y in zip(src_list, dest_list) if self._can_hold_element(x) diff --git a/pandas/tests/arrays/categorical/test_replace.py b/pandas/tests/arrays/categorical/test_replace.py index a3ba420c84a17..c25f1d9c9feac 100644 --- a/pandas/tests/arrays/categorical/test_replace.py +++ b/pandas/tests/arrays/categorical/test_replace.py @@ -21,6 +21,11 @@ ((5, 6), 2, [1, 2, 3], False), ([1], [2], [2, 2, 3], False), ([1, 4], [5, 2], [5, 2, 3], False), + # GH49404: overlap between to_replace and value + ([1, 2, 3], [2, 3, 4], [2, 3, 4], False), + # GH50872, GH46884: replace with null + (1, None, [None, 2, 3], False), + (1, pd.NA, [None, 2, 3], False), # check_categorical sorts categories, which crashes on mixed dtypes (3, "4", [1, 2, "4"], False), ([1, 2, "3"], "5", ["5", "5", 3], True), @@ -65,3 +70,11 @@ def test_replace_categorical(to_replace, value, result, expected_error_msg): pd.Series(cat).replace(to_replace, value, inplace=True) tm.assert_categorical_equal(cat, expected) + + +def test_replace_categorical_ea_dtype(): + # GH49404 + cat = Categorical(pd.array(["a", "b"], dtype="string")) + result = pd.Series(cat).replace(["a", "b"], ["c", pd.NA])._values + expected = Categorical(pd.array(["c", pd.NA], dtype="string")) + tm.assert_categorical_equal(result, expected)