Skip to content

Commit ea047aa

Browse files
JustinZhengBCmeeseeksmachine
authored andcommitted
Backport PR pandas-dev#31734: BUG: list-like to_replace on Categorical.replace is ignored or crash
1 parent 12a7d65 commit ea047aa

File tree

4 files changed

+81
-10
lines changed

4 files changed

+81
-10
lines changed

doc/source/whatsnew/v1.0.2.rst

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ Bug fixes
2929
**Categorical**
3030

3131
- Fixed bug where :meth:`Categorical.from_codes` improperly raised a ``ValueError`` when passed nullable integer codes. (:issue:`31779`)
32+
- Bug in :class:`Categorical` that would ignore or crash when calling :meth:`Series.replace` with a list-like ``to_replace`` (:issue:`31720`)
3233

3334
**I/O**
3435

pandas/_testing.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1056,6 +1056,7 @@ def assert_series_equal(
10561056
check_exact=False,
10571057
check_datetimelike_compat=False,
10581058
check_categorical=True,
1059+
check_category_order=True,
10591060
obj="Series",
10601061
):
10611062
"""
@@ -1090,6 +1091,10 @@ def assert_series_equal(
10901091
Compare datetime-like which is comparable ignoring dtype.
10911092
check_categorical : bool, default True
10921093
Whether to compare internal Categorical exactly.
1094+
check_category_order : bool, default True
1095+
Whether to compare category order of internal Categoricals
1096+
1097+
.. versionadded:: 1.0.2
10931098
obj : str, default 'Series'
10941099
Specify object name being compared, internally used to show appropriate
10951100
assertion message.
@@ -1192,7 +1197,12 @@ def assert_series_equal(
11921197

11931198
if check_categorical:
11941199
if is_categorical_dtype(left) or is_categorical_dtype(right):
1195-
assert_categorical_equal(left.values, right.values, obj=f"{obj} category")
1200+
assert_categorical_equal(
1201+
left.values,
1202+
right.values,
1203+
obj=f"{obj} category",
1204+
check_category_order=check_category_order,
1205+
)
11961206

11971207

11981208
# This could be refactored to use the NDFrame.equals method

pandas/core/arrays/categorical.py

+21-9
Original file line numberDiff line numberDiff line change
@@ -2447,18 +2447,30 @@ def replace(self, to_replace, value, inplace: bool = False):
24472447
"""
24482448
inplace = validate_bool_kwarg(inplace, "inplace")
24492449
cat = self if inplace else self.copy()
2450-
if to_replace in cat.categories:
2451-
if isna(value):
2452-
cat.remove_categories(to_replace, inplace=True)
2453-
else:
2450+
2451+
# build a dict of (to replace -> value) pairs
2452+
if is_list_like(to_replace):
2453+
# if to_replace is list-like and value is scalar
2454+
replace_dict = {replace_value: value for replace_value in to_replace}
2455+
else:
2456+
# if both to_replace and value are scalar
2457+
replace_dict = {to_replace: value}
2458+
2459+
# other cases, like if both to_replace and value are list-like or if
2460+
# to_replace is a dict, are handled separately in NDFrame
2461+
for replace_value, new_value in replace_dict.items():
2462+
if replace_value in cat.categories:
2463+
if isna(new_value):
2464+
cat.remove_categories(replace_value, inplace=True)
2465+
continue
24542466
categories = cat.categories.tolist()
2455-
index = categories.index(to_replace)
2456-
if value in cat.categories:
2457-
value_index = categories.index(value)
2467+
index = categories.index(replace_value)
2468+
if new_value in cat.categories:
2469+
value_index = categories.index(new_value)
24582470
cat._codes[cat._codes == index] = value_index
2459-
cat.remove_categories(to_replace, inplace=True)
2471+
cat.remove_categories(replace_value, inplace=True)
24602472
else:
2461-
categories[index] = value
2473+
categories[index] = new_value
24622474
cat.rename_categories(categories, inplace=True)
24632475
if not inplace:
24642476
return cat
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import pytest
2+
3+
import pandas as pd
4+
import pandas._testing as tm
5+
6+
7+
@pytest.mark.parametrize(
8+
"to_replace,value,expected,check_types,check_categorical",
9+
[
10+
# one-to-one
11+
(1, 2, [2, 2, 3], True, True),
12+
(1, 4, [4, 2, 3], True, True),
13+
(4, 1, [1, 2, 3], True, True),
14+
(5, 6, [1, 2, 3], True, True),
15+
# many-to-one
16+
([1], 2, [2, 2, 3], True, True),
17+
([1, 2], 3, [3, 3, 3], True, True),
18+
([1, 2], 4, [4, 4, 3], True, True),
19+
((1, 2, 4), 5, [5, 5, 3], True, True),
20+
((5, 6), 2, [1, 2, 3], True, True),
21+
# many-to-many, handled outside of Categorical and results in separate dtype
22+
([1], [2], [2, 2, 3], False, False),
23+
([1, 4], [5, 2], [5, 2, 3], False, False),
24+
# check_categorical sorts categories, which crashes on mixed dtypes
25+
(3, "4", [1, 2, "4"], True, False),
26+
([1, 2, "3"], "5", ["5", "5", 3], True, False),
27+
],
28+
)
29+
def test_replace(to_replace, value, expected, check_types, check_categorical):
30+
# GH 31720
31+
s = pd.Series([1, 2, 3], dtype="category")
32+
result = s.replace(to_replace, value)
33+
expected = pd.Series(expected, dtype="category")
34+
s.replace(to_replace, value, inplace=True)
35+
tm.assert_series_equal(
36+
expected,
37+
result,
38+
check_dtype=check_types,
39+
check_categorical=check_categorical,
40+
check_category_order=False,
41+
)
42+
tm.assert_series_equal(
43+
expected,
44+
s,
45+
check_dtype=check_types,
46+
check_categorical=check_categorical,
47+
check_category_order=False,
48+
)

0 commit comments

Comments
 (0)