Skip to content

Commit 859e4eb

Browse files
authored
BUG/PERF: Series(category).replace (pandas-dev#50857)
* bug/perf: Series(category).replace
* fixes
* add test for replace with null
* add test for GH46884
1 parent bf1d008 commit 859e4eb

File tree

4 files changed

+48
-45
lines changed

4 files changed

+48
-45
lines changed

doc/source/whatsnew/v2.0.0.rst

+3
Original file line number | Diff line number | Diff line change
@@ -902,6 +902,7 @@ Performance improvements
902902
- Performance improvement in :func:`merge` and :meth:`DataFrame.join` when joining on a sorted :class:`MultiIndex` (:issue:`48504`)
903903
- Performance improvement in :func:`to_datetime` when parsing strings with timezone offsets (:issue:`50107`)
904904
- Performance improvement in :meth:`DataFrame.loc` and :meth:`Series.loc` for tuple-based indexing of a :class:`MultiIndex` (:issue:`48384`)
905+
- Performance improvement for :meth:`Series.replace` with categorical dtype (:issue:`49404`)
905906
- Performance improvement for :meth:`MultiIndex.unique` (:issue:`48335`)
906907
- Performance improvement for :func:`concat` with extension array backed indexes (:issue:`49128`, :issue:`49178`)
907908
- Reduce memory usage of :meth:`DataFrame.to_pickle`/:meth:`Series.to_pickle` when using BZ2 or LZMA (:issue:`49068`)
@@ -946,6 +947,8 @@ Bug fixes
946947
Categorical
947948
^^^^^^^^^^^
948949
- Bug in :meth:`Categorical.set_categories` losing dtype information (:issue:`48812`)
950+
- Bug in :meth:`Series.replace` with categorical dtype when ``to_replace`` values overlap with new values (:issue:`49404`)
951+
- Bug in :meth:`Series.replace` with categorical dtype losing nullable dtypes of underlying categories (:issue:`49404`)
949952
- Bug in :meth:`DataFrame.groupby` and :meth:`Series.groupby` would reorder categories when used as a grouper (:issue:`48749`)
950953
- Bug in :class:`Categorical` constructor when constructing from a :class:`Categorical` object and ``dtype="category"`` losing ordered-ness (:issue:`49309`)
951954
-

pandas/core/arrays/categorical.py

+21-40
Original file line number | Diff line number | Diff line change
@@ -1137,14 +1137,9 @@ def remove_categories(self, removals):
11371137
if not is_list_like(removals):
11381138
removals = [removals]
11391139

1140-
removal_set = set(removals)
1141-
not_included = removal_set - set(self.dtype.categories)
1142-
new_categories = [c for c in self.dtype.categories if c not in removal_set]
1143-
1144-
# GH 10156
1145-
if any(isna(removals)):
1146-
not_included = {x for x in not_included if notna(x)}
1147-
new_categories = [x for x in new_categories if notna(x)]
1140+
removals = {x for x in set(removals) if notna(x)}
1141+
new_categories = self.dtype.categories.difference(removals)
1142+
not_included = removals.difference(self.dtype.categories)
11481143

11491144
if len(not_included) != 0:
11501145
raise ValueError(f"removals must all be in old categories: {not_included}")
@@ -2273,42 +2268,28 @@ def isin(self, values) -> npt.NDArray[np.bool_]:
22732268
return algorithms.isin(self.codes, code_values)
22742269

22752270
def _replace(self, *, to_replace, value, inplace: bool = False):
2271+
from pandas import Index
2272+
22762273
inplace = validate_bool_kwarg(inplace, "inplace")
22772274
cat = self if inplace else self.copy()
22782275

2279-
# other cases, like if both to_replace and value are list-like or if
2280-
# to_replace is a dict, are handled separately in NDFrame
2281-
if not is_list_like(to_replace):
2282-
to_replace = [to_replace]
2283-
2284-
categories = cat.categories.tolist()
2285-
removals = set()
2286-
for replace_value in to_replace:
2287-
if value == replace_value:
2288-
continue
2289-
if replace_value not in cat.categories:
2290-
continue
2291-
if isna(value):
2292-
removals.add(replace_value)
2293-
continue
2294-
2295-
index = categories.index(replace_value)
2296-
2297-
if value in cat.categories:
2298-
value_index = categories.index(value)
2299-
cat._codes[cat._codes == index] = value_index
2300-
removals.add(replace_value)
2301-
else:
2302-
categories[index] = value
2303-
cat._set_categories(categories)
2276+
mask = isna(np.asarray(value))
2277+
if mask.any():
2278+
removals = np.asarray(to_replace)[mask]
2279+
removals = cat.categories[cat.categories.isin(removals)]
2280+
new_cat = cat.remove_categories(removals)
2281+
NDArrayBacked.__init__(cat, new_cat.codes, new_cat.dtype)
23042282

2305-
if len(removals):
2306-
new_categories = [c for c in categories if c not in removals]
2307-
new_dtype = CategoricalDtype(new_categories, ordered=self.dtype.ordered)
2308-
codes = recode_for_categories(
2309-
cat.codes, cat.categories, new_dtype.categories
2310-
)
2311-
NDArrayBacked.__init__(cat, codes, new_dtype)
2283+
ser = cat.categories.to_series()
2284+
ser = ser.replace(to_replace=to_replace, value=value)
2285+
2286+
all_values = Index(ser)
2287+
new_categories = Index(ser.drop_duplicates(keep="first"))
2288+
new_codes = recode_for_categories(
2289+
cat._codes, all_values, new_categories, copy=False
2290+
)
2291+
new_dtype = CategoricalDtype(new_categories, ordered=self.dtype.ordered)
2292+
NDArrayBacked.__init__(cat, new_codes, new_dtype)
23122293

23132294
if not inplace:
23142295
return cat

pandas/core/internals/blocks.py

+11-5
Original file line number | Diff line number | Diff line change
@@ -536,12 +536,10 @@ def replace(
536536

537537
if isinstance(values, Categorical):
538538
# TODO: avoid special-casing
539+
# GH49404
539540
blk = self if inplace else self.copy()
540-
# error: Item "ExtensionArray" of "Union[ndarray[Any, Any],
541-
# ExtensionArray]" has no attribute "_replace"
542-
blk.values._replace( # type: ignore[union-attr]
543-
to_replace=to_replace, value=value, inplace=True
544-
)
541+
values = cast(Categorical, blk.values)
542+
values._replace(to_replace=to_replace, value=value, inplace=True)
545543
return [blk]
546544

547545
if not self._can_hold_element(to_replace):
@@ -651,6 +649,14 @@ def replace_list(
651649
"""
652650
values = self.values
653651

652+
if isinstance(values, Categorical):
653+
# TODO: avoid special-casing
654+
# GH49404
655+
blk = self if inplace else self.copy()
656+
values = cast(Categorical, blk.values)
657+
values._replace(to_replace=src_list, value=dest_list, inplace=True)
658+
return [blk]
659+
654660
# Exclude anything that we know we won't contain
655661
pairs = [
656662
(x, y) for x, y in zip(src_list, dest_list) if self._can_hold_element(x)

pandas/tests/arrays/categorical/test_replace.py

+13
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,11 @@
2121
((5, 6), 2, [1, 2, 3], False),
2222
([1], [2], [2, 2, 3], False),
2323
([1, 4], [5, 2], [5, 2, 3], False),
24+
# GH49404: overlap between to_replace and value
25+
([1, 2, 3], [2, 3, 4], [2, 3, 4], False),
26+
# GH50872, GH46884: replace with null
27+
(1, None, [None, 2, 3], False),
28+
(1, pd.NA, [None, 2, 3], False),
2429
# check_categorical sorts categories, which crashes on mixed dtypes
2530
(3, "4", [1, 2, "4"], False),
2631
([1, 2, "3"], "5", ["5", "5", 3], True),
@@ -65,3 +70,11 @@ def test_replace_categorical(to_replace, value, result, expected_error_msg):
6570

6671
pd.Series(cat).replace(to_replace, value, inplace=True)
6772
tm.assert_categorical_equal(cat, expected)
73+
74+
75+
def test_replace_categorical_ea_dtype():
76+
# GH49404
77+
cat = Categorical(pd.array(["a", "b"], dtype="string"))
78+
result = pd.Series(cat).replace(["a", "b"], ["c", pd.NA])._values
79+
expected = Categorical(pd.array(["c", pd.NA], dtype="string"))
80+
tm.assert_categorical_equal(result, expected)

0 commit comments

Comments
 (0)