Skip to content

Commit 8444453

Browse files
BUG: list-like to_replace on Categorical.replace is ignored or crashes (#31734)
1 parent 267d2d8 commit 8444453

File tree

4 files changed

+81
-10
lines changed

4 files changed

+81
-10
lines changed

doc/source/whatsnew/v1.0.2.rst

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ Bug fixes
3131
**Categorical**
3232

3333
- Fixed bug where :meth:`Categorical.from_codes` improperly raised a ``ValueError`` when passed nullable integer codes. (:issue:`31779`)
34+
- Fixed bug in :class:`Categorical` where calling :meth:`Series.replace` with a list-like ``to_replace`` was either ignored or raised an error (:issue:`31720`)
3435

3536
**I/O**
3637

pandas/_testing.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1074,6 +1074,7 @@ def assert_series_equal(
10741074
check_exact=False,
10751075
check_datetimelike_compat=False,
10761076
check_categorical=True,
1077+
check_category_order=True,
10771078
obj="Series",
10781079
):
10791080
"""
@@ -1108,6 +1109,10 @@ def assert_series_equal(
11081109
Compare datetime-like which is comparable ignoring dtype.
11091110
check_categorical : bool, default True
11101111
Whether to compare internal Categorical exactly.
1112+
check_category_order : bool, default True
1113+
Whether to compare the category order of internal Categoricals.
1114+
1115+
.. versionadded:: 1.0.2
11111116
obj : str, default 'Series'
11121117
Specify object name being compared, internally used to show appropriate
11131118
assertion message.
@@ -1210,7 +1215,12 @@ def assert_series_equal(
12101215

12111216
if check_categorical:
12121217
if is_categorical_dtype(left) or is_categorical_dtype(right):
1213-
assert_categorical_equal(left.values, right.values, obj=f"{obj} category")
1218+
assert_categorical_equal(
1219+
left.values,
1220+
right.values,
1221+
obj=f"{obj} category",
1222+
check_category_order=check_category_order,
1223+
)
12141224

12151225

12161226
# This could be refactored to use the NDFrame.equals method

pandas/core/arrays/categorical.py

+21-9
Original file line numberDiff line numberDiff line change
@@ -2441,18 +2441,30 @@ def replace(self, to_replace, value, inplace: bool = False):
24412441
"""
24422442
inplace = validate_bool_kwarg(inplace, "inplace")
24432443
cat = self if inplace else self.copy()
2444-
if to_replace in cat.categories:
2445-
if isna(value):
2446-
cat.remove_categories(to_replace, inplace=True)
2447-
else:
2444+
2445+
# build a dict of (to replace -> value) pairs
2446+
if is_list_like(to_replace):
2447+
# if to_replace is list-like and value is scalar
2448+
replace_dict = {replace_value: value for replace_value in to_replace}
2449+
else:
2450+
# if both to_replace and value are scalar
2451+
replace_dict = {to_replace: value}
2452+
2453+
# other cases, like if both to_replace and value are list-like or if
2454+
# to_replace is a dict, are handled separately in NDFrame
2455+
for replace_value, new_value in replace_dict.items():
2456+
if replace_value in cat.categories:
2457+
if isna(new_value):
2458+
cat.remove_categories(replace_value, inplace=True)
2459+
continue
24482460
categories = cat.categories.tolist()
2449-
index = categories.index(to_replace)
2450-
if value in cat.categories:
2451-
value_index = categories.index(value)
2461+
index = categories.index(replace_value)
2462+
if new_value in cat.categories:
2463+
value_index = categories.index(new_value)
24522464
cat._codes[cat._codes == index] = value_index
2453-
cat.remove_categories(to_replace, inplace=True)
2465+
cat.remove_categories(replace_value, inplace=True)
24542466
else:
2455-
categories[index] = value
2467+
categories[index] = new_value
24562468
cat.rename_categories(categories, inplace=True)
24572469
if not inplace:
24582470
return cat
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import pytest
2+
3+
import pandas as pd
4+
import pandas._testing as tm
5+
6+
7+
@pytest.mark.parametrize(
    "to_replace,value,expected,check_types,check_categorical",
    [
        # one-to-one
        (1, 2, [2, 2, 3], True, True),
        (1, 4, [4, 2, 3], True, True),
        (4, 1, [1, 2, 3], True, True),
        (5, 6, [1, 2, 3], True, True),
        # many-to-one
        ([1], 2, [2, 2, 3], True, True),
        ([1, 2], 3, [3, 3, 3], True, True),
        ([1, 2], 4, [4, 4, 3], True, True),
        ((1, 2, 4), 5, [5, 5, 3], True, True),
        ((5, 6), 2, [1, 2, 3], True, True),
        # many-to-many, handled outside of Categorical and results in separate dtype
        ([1], [2], [2, 2, 3], False, False),
        ([1, 4], [5, 2], [5, 2, 3], False, False),
        # check_categorical sorts categories, which crashes on mixed dtypes
        (3, "4", [1, 2, "4"], True, False),
        ([1, 2, "3"], "5", ["5", "5", 3], True, False),
    ],
)
def test_replace(to_replace, value, expected, check_types, check_categorical):
    """
    Check ``Series.replace`` on a categorical Series, copying and in place.

    Parameters
    ----------
    to_replace : scalar or list-like
        Value(s) to look for in the Series.
    value : scalar or list-like
        Replacement value(s).
    expected : list
        Values the Series should hold afterwards; converted to a categorical
        Series before comparison.
    check_types : bool
        Forwarded to ``assert_series_equal`` as ``check_dtype``.
    check_categorical : bool
        Forwarded to ``assert_series_equal`` as ``check_categorical``.
    """
    # GH 31720
    s = pd.Series([1, 2, 3], dtype="category")
    result = s.replace(to_replace, value)
    expected = pd.Series(expected, dtype="category")
    s.replace(to_replace, value, inplace=True)

    # The observed value goes first (left) and the expected value second
    # (right), so a failure report labels the two sides correctly.
    # Category order is not compared for either variant.
    tm.assert_series_equal(
        result,
        expected,
        check_dtype=check_types,
        check_categorical=check_categorical,
        check_category_order=False,
    )
    # The in-place variant must produce the same values as the copying one.
    tm.assert_series_equal(
        s,
        expected,
        check_dtype=check_types,
        check_categorical=check_categorical,
        check_category_order=False,
    )

0 commit comments

Comments
 (0)