Skip to content

Commit c7623b3

Browse files
authored
BUG: add_categories losing dtype information (#48847)
1 parent 944ee5e commit c7623b3

File tree

3 files changed

+38
-3
lines changed

3 files changed

+38
-3
lines changed

doc/source/whatsnew/v1.6.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ Bug fixes
160160

161161
Categorical
162162
^^^^^^^^^^^
163-
-
163+
- Bug in :meth:`Categorical.add_categories` losing dtype information (:issue:`48812`)
164164
-
165165

166166
Datetimelike

pandas/core/arrays/categorical.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,10 @@
5656
from pandas.util._exceptions import find_stack_level
5757
from pandas.util._validators import validate_bool_kwarg
5858

59-
from pandas.core.dtypes.cast import coerce_indexer_dtype
59+
from pandas.core.dtypes.cast import (
60+
coerce_indexer_dtype,
61+
find_common_type,
62+
)
6063
from pandas.core.dtypes.common import (
6164
ensure_int64,
6265
ensure_platform_int,
@@ -1292,7 +1295,19 @@ def add_categories(
12921295
raise ValueError(
12931296
f"new categories must not include old categories: {already_included}"
12941297
)
1295-
new_categories = list(self.dtype.categories) + list(new_categories)
1298+
1299+
if hasattr(new_categories, "dtype"):
1300+
from pandas import Series
1301+
1302+
dtype = find_common_type(
1303+
[self.dtype.categories.dtype, new_categories.dtype]
1304+
)
1305+
new_categories = Series(
1306+
list(self.dtype.categories) + list(new_categories), dtype=dtype
1307+
)
1308+
else:
1309+
new_categories = list(self.dtype.categories) + list(new_categories)
1310+
12961311
new_dtype = CategoricalDtype(new_categories, self.ordered)
12971312

12981313
cat = self if inplace else self.copy()

pandas/tests/arrays/categorical/test_api.py

+20
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
DataFrame,
1212
Index,
1313
Series,
14+
StringDtype,
1415
)
1516
import pandas._testing as tm
1617
from pandas.core.arrays.categorical import recode_for_categories
@@ -237,6 +238,25 @@ def test_add_categories_existing_raises(self):
237238
with pytest.raises(ValueError, match=msg):
238239
cat.add_categories(["d"])
239240

241+
def test_add_categories_losing_dtype_information(self):
242+
# GH#48812
243+
cat = Categorical(Series([1, 2], dtype="Int64"))
244+
ser = Series([4], dtype="Int64")
245+
result = cat.add_categories(ser)
246+
expected = Categorical(
247+
Series([1, 2], dtype="Int64"), categories=Series([1, 2, 4], dtype="Int64")
248+
)
249+
tm.assert_categorical_equal(result, expected)
250+
251+
cat = Categorical(Series(["a", "b", "a"], dtype=StringDtype()))
252+
ser = Series(["d"], dtype=StringDtype())
253+
result = cat.add_categories(ser)
254+
expected = Categorical(
255+
Series(["a", "b", "a"], dtype=StringDtype()),
256+
categories=Series(["a", "b", "d"], dtype=StringDtype()),
257+
)
258+
tm.assert_categorical_equal(result, expected)
259+
240260
def test_set_categories(self):
241261
cat = Categorical(["a", "b", "c", "a"], ordered=True)
242262
exp_categories = Index(["c", "b", "a"])

0 commit comments

Comments
 (0)