From 97b5049437487a79a8446f1b1466d479f19e51a7 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler Date: Thu, 29 Sep 2022 10:40:13 +0200 Subject: [PATCH] BUG: add_categories losing dtype information --- doc/source/whatsnew/v1.6.0.rst | 2 +- pandas/core/arrays/categorical.py | 19 +++++++++++++++++-- pandas/tests/arrays/categorical/test_api.py | 20 ++++++++++++++++++++ 3 files changed, 38 insertions(+), 3 deletions(-) diff --git a/doc/source/whatsnew/v1.6.0.rst b/doc/source/whatsnew/v1.6.0.rst index 9f793532e5e6b..3511f1afe8cfc 100644 --- a/doc/source/whatsnew/v1.6.0.rst +++ b/doc/source/whatsnew/v1.6.0.rst @@ -158,7 +158,7 @@ Bug fixes Categorical ^^^^^^^^^^^ -- +- Bug in :meth:`Categorical.set_categories` losing dtype information (:issue:`48812`) - Datetimelike diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index ee995a0f9d0b7..80a789828506d 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -56,7 +56,10 @@ from pandas.util._exceptions import find_stack_level from pandas.util._validators import validate_bool_kwarg -from pandas.core.dtypes.cast import coerce_indexer_dtype +from pandas.core.dtypes.cast import ( + coerce_indexer_dtype, + find_common_type, +) from pandas.core.dtypes.common import ( ensure_int64, ensure_platform_int, @@ -1292,7 +1295,19 @@ def add_categories( raise ValueError( f"new categories must not include old categories: {already_included}" ) - new_categories = list(self.dtype.categories) + list(new_categories) + + if hasattr(new_categories, "dtype"): + from pandas import Series + + dtype = find_common_type( + [self.dtype.categories.dtype, new_categories.dtype] + ) + new_categories = Series( + list(self.dtype.categories) + list(new_categories), dtype=dtype + ) + else: + new_categories = list(self.dtype.categories) + list(new_categories) + new_dtype = CategoricalDtype(new_categories, self.ordered) cat = self if inplace else self.copy() diff --git a/pandas/tests/arrays/categorical/test_api.py b/pandas/tests/arrays/categorical/test_api.py index f0669f52acee2..377ab530d8733 100644 --- a/pandas/tests/arrays/categorical/test_api.py +++ b/pandas/tests/arrays/categorical/test_api.py @@ -11,6 +11,7 @@ DataFrame, Index, Series, + StringDtype, ) import pandas._testing as tm from pandas.core.arrays.categorical import recode_for_categories @@ -237,6 +238,25 @@ def test_add_categories_existing_raises(self): with pytest.raises(ValueError, match=msg): cat.add_categories(["d"]) + def test_add_categories_losing_dtype_information(self): + # GH#48812 + cat = Categorical(Series([1, 2], dtype="Int64")) + ser = Series([4], dtype="Int64") + result = cat.add_categories(ser) + expected = Categorical( + Series([1, 2], dtype="Int64"), categories=Series([1, 2, 4], dtype="Int64") + ) + tm.assert_categorical_equal(result, expected) + + cat = Categorical(Series(["a", "b", "a"], dtype=StringDtype())) + ser = Series(["d"], dtype=StringDtype()) + result = cat.add_categories(ser) + expected = Categorical( + Series(["a", "b", "a"], dtype=StringDtype()), + categories=Series(["a", "b", "d"], dtype=StringDtype()), + ) + tm.assert_categorical_equal(result, expected) + def test_set_categories(self): cat = Categorical(["a", "b", "c", "a"], ordered=True) exp_categories = Index(["c", "b", "a"])