Skip to content

Commit a063af0

Browse files
authored
BUG/PERF: Series.replace with dtype="category" (#49404)
* refactor Categorical._replace
* gh refs
* cleanup
1 parent 9097263 commit a063af0

File tree

4 files changed

+38
-37
lines changed

4 files changed

+38
-37
lines changed

doc/source/whatsnew/v2.0.0.rst

+3
Original file line number | Diff line number | Diff line change
@@ -844,6 +844,7 @@ Performance improvements
844844
- Performance improvement in :func:`merge` and :meth:`DataFrame.join` when joining on a sorted :class:`MultiIndex` (:issue:`48504`)
845845
- Performance improvement in :func:`to_datetime` when parsing strings with timezone offsets (:issue:`50107`)
846846
- Performance improvement in :meth:`DataFrame.loc` and :meth:`Series.loc` for tuple-based indexing of a :class:`MultiIndex` (:issue:`48384`)
847+
- Performance improvement for :meth:`Series.replace` with categorical dtype (:issue:`49404`)
847848
- Performance improvement for :meth:`MultiIndex.unique` (:issue:`48335`)
848849
- Performance improvement for :func:`concat` with extension array backed indexes (:issue:`49128`, :issue:`49178`)
849850
- Reduce memory usage of :meth:`DataFrame.to_pickle`/:meth:`Series.to_pickle` when using BZ2 or LZMA (:issue:`49068`)
@@ -886,6 +887,8 @@ Bug fixes
886887
Categorical
887888
^^^^^^^^^^^
888889
- Bug in :meth:`Categorical.set_categories` losing dtype information (:issue:`48812`)
890+
- Bug in :meth:`Series.replace` with categorical dtype when ``to_replace`` values overlap with new values (:issue:`49404`)
891+
- Bug in :meth:`Series.replace` with categorical dtype losing nullable dtypes of underlying categories (:issue:`49404`)
889892
- Bug in :meth:`DataFrame.groupby` and :meth:`Series.groupby` would reorder categories when used as a grouper (:issue:`48749`)
890893
- Bug in :class:`Categorical` constructor when constructing from a :class:`Categorical` object and ``dtype="category"`` losing ordered-ness (:issue:`49309`)
891894
-

pandas/core/arrays/categorical.py

+14-32
Original file line number | Diff line number | Diff line change
@@ -2273,42 +2273,24 @@ def isin(self, values) -> npt.NDArray[np.bool_]:
22732273
return algorithms.isin(self.codes, code_values)
22742274

22752275
def _replace(self, *, to_replace, value, inplace: bool = False):
2276+
from pandas import (
2277+
Index,
2278+
Series,
2279+
)
2280+
22762281
inplace = validate_bool_kwarg(inplace, "inplace")
22772282
cat = self if inplace else self.copy()
22782283

2279-
# other cases, like if both to_replace and value are list-like or if
2280-
# to_replace is a dict, are handled separately in NDFrame
2281-
if not is_list_like(to_replace):
2282-
to_replace = [to_replace]
2283-
2284-
categories = cat.categories.tolist()
2285-
removals = set()
2286-
for replace_value in to_replace:
2287-
if value == replace_value:
2288-
continue
2289-
if replace_value not in cat.categories:
2290-
continue
2291-
if isna(value):
2292-
removals.add(replace_value)
2293-
continue
2294-
2295-
index = categories.index(replace_value)
2296-
2297-
if value in cat.categories:
2298-
value_index = categories.index(value)
2299-
cat._codes[cat._codes == index] = value_index
2300-
removals.add(replace_value)
2301-
else:
2302-
categories[index] = value
2303-
cat._set_categories(categories)
2284+
ser = Series(cat.categories, copy=True)
2285+
ser = ser.replace(to_replace=to_replace, value=value)
23042286

2305-
if len(removals):
2306-
new_categories = [c for c in categories if c not in removals]
2307-
new_dtype = CategoricalDtype(new_categories, ordered=self.dtype.ordered)
2308-
codes = recode_for_categories(
2309-
cat.codes, cat.categories, new_dtype.categories
2310-
)
2311-
NDArrayBacked.__init__(cat, codes, new_dtype)
2287+
all_values = Index(ser)
2288+
new_categories = Index(ser.dropna().drop_duplicates(keep="first"))
2289+
new_codes = recode_for_categories(
2290+
cat._codes, all_values, new_categories, copy=False
2291+
)
2292+
new_dtype = CategoricalDtype(new_categories, ordered=self.dtype.ordered)
2293+
NDArrayBacked.__init__(cat, new_codes, new_dtype)
23122294

23132295
if not inplace:
23142296
return cat

pandas/core/internals/blocks.py

+11-5
Original file line number | Diff line number | Diff line change
@@ -536,12 +536,10 @@ def replace(
536536

537537
if isinstance(values, Categorical):
538538
# TODO: avoid special-casing
539+
# GH49404
539540
blk = self if inplace else self.copy()
540-
# error: Item "ExtensionArray" of "Union[ndarray[Any, Any],
541-
# ExtensionArray]" has no attribute "_replace"
542-
blk.values._replace( # type: ignore[union-attr]
543-
to_replace=to_replace, value=value, inplace=True
544-
)
541+
values = cast(Categorical, blk.values)
542+
values._replace(to_replace=to_replace, value=value, inplace=True)
545543
return [blk]
546544

547545
if not self._can_hold_element(to_replace):
@@ -651,6 +649,14 @@ def replace_list(
651649
"""
652650
values = self.values
653651

652+
if isinstance(values, Categorical):
653+
# TODO: avoid special-casing
654+
# GH49404
655+
blk = self if inplace else self.copy()
656+
values = cast(Categorical, blk.values)
657+
values._replace(to_replace=src_list, value=dest_list, inplace=True)
658+
return [blk]
659+
654660
# Exclude anything that we know we won't contain
655661
pairs = [
656662
(x, y) for x, y in zip(src_list, dest_list) if self._can_hold_element(x)

pandas/tests/arrays/categorical/test_replace.py

+10
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,8 @@
2121
((5, 6), 2, [1, 2, 3], False),
2222
([1], [2], [2, 2, 3], False),
2323
([1, 4], [5, 2], [5, 2, 3], False),
24+
# GH49404
25+
([1, 2, 3], [2, 3, 4], [2, 3, 4], False),
2426
# check_categorical sorts categories, which crashes on mixed dtypes
2527
(3, "4", [1, 2, "4"], False),
2628
([1, 2, "3"], "5", ["5", "5", 3], True),
@@ -65,3 +67,11 @@ def test_replace_categorical(to_replace, value, result, expected_error_msg):
6567

6668
pd.Series(cat).replace(to_replace, value, inplace=True)
6769
tm.assert_categorical_equal(cat, expected)
70+
71+
72+
def test_replace_categorical_ea_dtype():
73+
# GH49404
74+
cat = Categorical(pd.array(["a", "b"], dtype="string"))
75+
result = pd.Series(cat).replace(["a", "b"], ["c", pd.NA])._values
76+
expected = Categorical(pd.array(["c", pd.NA], dtype="string"))
77+
tm.assert_categorical_equal(result, expected)

0 commit comments

Comments
 (0)