Skip to content

Commit fd0d736

Browse files
committed
BUG: list-like to_replace on Categorical.replace is ignored or crash
1 parent 980ab6b commit fd0d736

File tree

4 files changed

+85
-13
lines changed

4 files changed

+85
-13
lines changed

doc/source/whatsnew/v1.0.2.rst

+4
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ Fixed regressions
2626
Bug fixes
2727
~~~~~~~~~
2828

29+
**Categorical**
30+
31+
- Bug in :class:`Categorical` that would ignore or crash when calling :meth:`Series.replace` with a list-like ``to_replace`` (:issue:`31720`)
32+
2933
**I/O**
3034

3135
- Using ``pd.NA`` with :meth:`DataFrame.to_json` now correctly outputs a null value instead of an empty object (:issue:`31615`)

pandas/_testing.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -1070,6 +1070,7 @@ def assert_series_equal(
10701070
check_exact=False,
10711071
check_datetimelike_compat=False,
10721072
check_categorical=True,
1073+
check_category_order=True,
10731074
obj="Series",
10741075
):
10751076
"""
@@ -1104,6 +1105,8 @@ def assert_series_equal(
11041105
Compare datetime-like which is comparable ignoring dtype.
11051106
check_categorical : bool, default True
11061107
Whether to compare internal Categorical exactly.
1108+
check_category_order : bool, default True
1109+
Whether to compare category order of internal Categoricals
11071110
obj : str, default 'Series'
11081111
Specify object name being compared, internally used to show appropriate
11091112
assertion message.
@@ -1206,7 +1209,12 @@ def assert_series_equal(
12061209

12071210
if check_categorical:
12081211
if is_categorical_dtype(left) or is_categorical_dtype(right):
1209-
assert_categorical_equal(left.values, right.values, obj=f"{obj} category")
1212+
assert_categorical_equal(
1213+
left.values,
1214+
right.values,
1215+
obj=f"{obj} category",
1216+
check_category_order=check_category_order,
1217+
)
12101218

12111219

12121220
# This could be refactored to use the NDFrame.equals method

pandas/core/arrays/categorical.py

+24-12
Original file line numberDiff line numberDiff line change
@@ -2441,19 +2441,31 @@ def replace(self, to_replace, value, inplace: bool = False):
24412441
"""
24422442
inplace = validate_bool_kwarg(inplace, "inplace")
24432443
cat = self if inplace else self.copy()
2444-
if to_replace in cat.categories:
2445-
if isna(value):
2446-
cat.remove_categories(to_replace, inplace=True)
2447-
else:
2448-
categories = cat.categories.tolist()
2449-
index = categories.index(to_replace)
2450-
if value in cat.categories:
2451-
value_index = categories.index(value)
2452-
cat._codes[cat._codes == index] = value_index
2453-
cat.remove_categories(to_replace, inplace=True)
2444+
2445+
# build a dict of (to replace -> value) pairs
2446+
if is_list_like(to_replace):
2447+
# if to_replace is list-like and value is scalar
2448+
replace_dict = {replace_value: value for replace_value in to_replace}
2449+
else:
2450+
# if both to_replace and value are scalar
2451+
replace_dict = {to_replace: value}
2452+
# other cases, like if both to_replace and value are list-like or if
2453+
# to_replace is a dict, are handled separately in NDFrame
2454+
2455+
for replace_value, new_value in replace_dict.items():
2456+
if replace_value in cat.categories:
2457+
if isna(new_value):
2458+
cat.remove_categories(replace_value, inplace=True)
24542459
else:
2455-
categories[index] = value
2456-
cat.rename_categories(categories, inplace=True)
2460+
categories = cat.categories.tolist()
2461+
index = categories.index(replace_value)
2462+
if new_value in cat.categories:
2463+
value_index = categories.index(new_value)
2464+
cat._codes[cat._codes == index] = value_index
2465+
cat.remove_categories(replace_value, inplace=True)
2466+
else:
2467+
categories[index] = new_value
2468+
cat.rename_categories(categories, inplace=True)
24572469
if not inplace:
24582470
return cat
24592471

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import pytest
2+
3+
import pandas as pd
4+
import pandas._testing as tm
5+
6+
7+
@pytest.mark.parametrize(
8+
"to_replace,value,expected,check_types,check_categorical",
9+
[
10+
# one-to-one
11+
(1, 2, [2, 2, 3], True, True),
12+
(1, 4, [4, 2, 3], True, True),
13+
(4, 1, [1, 2, 3], True, True),
14+
(5, 6, [1, 2, 3], True, True),
15+
# many-to-one
16+
([1], 2, [2, 2, 3], True, True),
17+
([1, 2], 3, [3, 3, 3], True, True),
18+
([1, 2], 4, [4, 4, 3], True, True),
19+
((1, 2, 4), 5, [5, 5, 3], True, True),
20+
((5, 6), 2, [1, 2, 3], True, True),
21+
# many-to-many, handled outside of Categorical and results in separate dtype
22+
([1], [2], [2, 2, 3], False, False),
23+
([1, 4], [5, 2], [5, 2, 3], False, False),
24+
# check_categorical sorts categories, which crashes on mixed dtypes
25+
(3, "4", [1, 2, "4"], True, False),
26+
([1, 2, "3"], "5", ["5", "5", 3], True, False),
27+
],
28+
)
29+
def test_replace(to_replace, value, expected, check_types, check_categorical):
30+
# GH 31720
31+
s = pd.Series([1, 2, 3], dtype="category")
32+
result = s.replace(to_replace, value)
33+
expected = pd.Series(expected, dtype="category")
34+
s.replace(to_replace, value, inplace=True)
35+
tm.assert_series_equal(
36+
expected,
37+
result,
38+
check_dtype=check_types,
39+
check_categorical=check_categorical,
40+
check_category_order=False,
41+
)
42+
tm.assert_series_equal(
43+
expected,
44+
s,
45+
check_dtype=check_types,
46+
check_categorical=check_categorical,
47+
check_category_order=False,
48+
)

0 commit comments

Comments
 (0)