diff --git a/doc/source/whatsnew/v1.0.2.rst b/doc/source/whatsnew/v1.0.2.rst index 0c012e5c1417b..805f87b63eef8 100644 --- a/doc/source/whatsnew/v1.0.2.rst +++ b/doc/source/whatsnew/v1.0.2.rst @@ -29,6 +29,7 @@ Bug fixes **Categorical** - Fixed bug where :meth:`Categorical.from_codes` improperly raised a ``ValueError`` when passed nullable integer codes. (:issue:`31779`) +- Fixed bug in :class:`Categorical` where calling :meth:`Series.replace` with a list-like ``to_replace`` was either ignored or raised an error. (:issue:`31720`) **I/O** diff --git a/pandas/_testing.py b/pandas/_testing.py index 1fdc5d478aaf6..ca378e5ce8f77 100644 --- a/pandas/_testing.py +++ b/pandas/_testing.py @@ -1056,6 +1056,7 @@ def assert_series_equal( check_exact=False, check_datetimelike_compat=False, check_categorical=True, + check_category_order=True, obj="Series", ): """ @@ -1090,6 +1091,10 @@ def assert_series_equal( Compare datetime-like which is comparable ignoring dtype. check_categorical : bool, default True Whether to compare internal Categorical exactly. + check_category_order : bool, default True + Whether to compare category order of internal Categoricals. + + .. versionadded:: 1.0.2 obj : str, default 'Series' Specify object name being compared, internally used to show appropriate assertion message. 
@@ -1192,7 +1197,12 @@ def assert_series_equal( if check_categorical: if is_categorical_dtype(left) or is_categorical_dtype(right): - assert_categorical_equal(left.values, right.values, obj=f"{obj} category") + assert_categorical_equal( + left.values, + right.values, + obj=f"{obj} category", + check_category_order=check_category_order, + ) # This could be refactored to use the NDFrame.equals method diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index 52d9df0c2d508..d8a2fbdd58382 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -2447,18 +2447,30 @@ def replace(self, to_replace, value, inplace: bool = False): """ inplace = validate_bool_kwarg(inplace, "inplace") cat = self if inplace else self.copy() - if to_replace in cat.categories: - if isna(value): - cat.remove_categories(to_replace, inplace=True) - else: + + # build a dict of (to replace -> value) pairs + if is_list_like(to_replace): + # if to_replace is list-like and value is scalar + replace_dict = {replace_value: value for replace_value in to_replace} + else: + # if both to_replace and value are scalar + replace_dict = {to_replace: value} + + # other cases, like if both to_replace and value are list-like or if + # to_replace is a dict, are handled separately in NDFrame + for replace_value, new_value in replace_dict.items(): + if replace_value in cat.categories: + if isna(new_value): + cat.remove_categories(replace_value, inplace=True) + continue categories = cat.categories.tolist() - index = categories.index(to_replace) - if value in cat.categories: - value_index = categories.index(value) + index = categories.index(replace_value) + if new_value in cat.categories: + value_index = categories.index(new_value) cat._codes[cat._codes == index] = value_index - cat.remove_categories(to_replace, inplace=True) + cat.remove_categories(replace_value, inplace=True) else: - categories[index] = value + categories[index] = new_value 
cat.rename_categories(categories, inplace=True) if not inplace: return cat diff --git a/pandas/tests/arrays/categorical/test_replace.py b/pandas/tests/arrays/categorical/test_replace.py new file mode 100644 index 0000000000000..52530123bd52f --- /dev/null +++ b/pandas/tests/arrays/categorical/test_replace.py @@ -0,0 +1,48 @@ +import pytest + +import pandas as pd +import pandas._testing as tm + + +@pytest.mark.parametrize( + "to_replace,value,expected,check_types,check_categorical", + [ + # one-to-one + (1, 2, [2, 2, 3], True, True), + (1, 4, [4, 2, 3], True, True), + (4, 1, [1, 2, 3], True, True), + (5, 6, [1, 2, 3], True, True), + # many-to-one + ([1], 2, [2, 2, 3], True, True), + ([1, 2], 3, [3, 3, 3], True, True), + ([1, 2], 4, [4, 4, 3], True, True), + ((1, 2, 4), 5, [5, 5, 3], True, True), + ((5, 6), 2, [1, 2, 3], True, True), + # many-to-many, handled outside of Categorical and results in separate dtype + ([1], [2], [2, 2, 3], False, False), + ([1, 4], [5, 2], [5, 2, 3], False, False), + # check_categorical sorts categories, which crashes on mixed dtypes + (3, "4", [1, 2, "4"], True, False), + ([1, 2, "3"], "5", ["5", "5", 3], True, False), + ], +) +def test_replace(to_replace, value, expected, check_types, check_categorical): + # GH 31720 + s = pd.Series([1, 2, 3], dtype="category") + result = s.replace(to_replace, value) + expected = pd.Series(expected, dtype="category") + s.replace(to_replace, value, inplace=True) + tm.assert_series_equal( + expected, + result, + check_dtype=check_types, + check_categorical=check_categorical, + check_category_order=False, + ) + tm.assert_series_equal( + expected, + s, + check_dtype=check_types, + check_categorical=check_categorical, + check_category_order=False, + )