diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py
index 44650500e0f65..d0fa6239fc1cb 100644
--- a/pandas/core/dtypes/cast.py
+++ b/pandas/core/dtypes/cast.py
@@ -88,6 +88,7 @@
     pandas_dtype,
 )
 from pandas.core.dtypes.dtypes import (
+    CategoricalDtype,
     DatetimeTZDtype,
     ExtensionDtype,
     IntervalDtype,
@@ -1849,13 +1850,16 @@ def ensure_nanosecond_dtype(dtype: DtypeObj) -> DtypeObj:
     return dtype
 
 
-def find_common_type(types: List[DtypeObj]) -> DtypeObj:
+def find_common_type(
+    types: List[DtypeObj], promote_categorical: Optional[bool] = False
+) -> DtypeObj:
     """
     Find a common data type among the given dtypes.
 
     Parameters
     ----------
     types : list of dtypes
+    promote_categorical : find if possible, a categorical dtype that fits all the dtypes
 
     Returns
     -------
@@ -1876,6 +1880,24 @@ def find_common_type(types: List[DtypeObj]) -> DtypeObj:
     if all(is_dtype_equal(first, t) for t in types[1:]):
         return first
 
+    # special case for categorical
+    if promote_categorical:
+        if any(is_categorical_dtype(t) for t in types):
+            cat_dtypes = []
+            for t in types:
+                if isinstance(t, CategoricalDtype) and t.categories is not None:
+                    if any(~isna(t.categories.values)):
+                        cat_values_dtype = t.categories.values.dtype
+                        if all(
+                            is_categorical_dtype(x) or np.can_cast(x, cat_values_dtype)
+                            for x in types
+                        ):
+                            cat_dtypes.append(t)
+            if len(cat_dtypes) > 0:
+                dtype_ref = cat_dtypes[0]
+                if all(is_dtype_equal(dtype, dtype_ref) for dtype in cat_dtypes[1:]):
+                    return dtype_ref
+
     # get unique types (dict.fromkeys is used as order-preserving set())
     types = list(dict.fromkeys(types).keys())
 
diff --git a/pandas/core/dtypes/concat.py b/pandas/core/dtypes/concat.py
index 614a637f2d904..1e27b93e5abf1 100644
--- a/pandas/core/dtypes/concat.py
+++ b/pandas/core/dtypes/concat.py
@@ -17,11 +17,14 @@
     is_extension_array_dtype,
     is_sparse,
 )
+from pandas.core.dtypes.dtypes import CategoricalDtype
 from pandas.core.dtypes.generic import (
     ABCCategoricalIndex,
     ABCSeries,
 )
+from pandas.core.dtypes.missing import isna
 
+from pandas.core.algorithms import unique1d
 from pandas.core.arrays import ExtensionArray
 from pandas.core.arrays.sparse import SparseArray
 from pandas.core.construction import (
@@ -35,6 +38,16 @@ def _cast_to_common_type(arr: ArrayLike, dtype: DtypeObj) -> ArrayLike:
     Helper function for `arr.astype(common_dtype)` but handling all special
     cases.
     """
+    if isinstance(dtype, CategoricalDtype):
+        # if casting an array to a categorical dtype, then we need to ensure
+        # that its unique values are predefined as categories in that dtype
+        unique_values = unique1d(arr[~isna(arr)])
+        if any(val not in dtype.categories for val in unique_values.tolist()):
+            raise ValueError(
+                "Cannot setitem on a Categorical with a new category, "
+                "set the categories first"
+            )
+
     if (
         is_categorical_dtype(arr.dtype)
         and isinstance(dtype, np.dtype)
@@ -116,12 +129,18 @@ def is_nonempty(x) -> bool:
     all_empty = not len(non_empties)
     single_dtype = len({x.dtype for x in to_concat}) == 1
     any_ea = any(is_extension_array_dtype(x.dtype) for x in to_concat)
+    first_ea = isinstance(to_concat[0], ExtensionArray)
+    arr_index_expansion = (
+        first_ea and len(to_concat) == 2 and to_concat[1].shape[0] == 1
+    )
 
     if any_ea:
         # we ignore axis here, as internally concatting with EAs is always
         # for axis=0
         if not single_dtype:
-            target_dtype = find_common_type([x.dtype for x in to_concat])
+            target_dtype = find_common_type(
+                [x.dtype for x in to_concat], promote_categorical=arr_index_expansion
+            )
             to_concat = [_cast_to_common_type(arr, target_dtype) for arr in to_concat]
 
         if isinstance(to_concat[0], ExtensionArray):
diff --git a/pandas/tests/series/test_categorical.py b/pandas/tests/series/test_categorical.py
new file mode 100644
index 0000000000000..9dba345c3db36
--- /dev/null
+++ b/pandas/tests/series/test_categorical.py
@@ -0,0 +1,54 @@
+import pytest
+
+import pandas as pd
+from pandas import Categorical
+import pandas._testing as tm
+
+
+class TestCategoricalSeries:
+    def test_setitem_undefined_category_raises(self):
+        ser = pd.Series(Categorical(["a", "b", "c"]))
+        msg = (
+            "Cannot setitem on a Categorical with a new category, "
+            "set the categories first"
+        )
+        with pytest.raises(ValueError, match=msg):
+            ser.loc[2] = "d"
+
+    def test_concat_undefined_category_raises(self):
+        ser = pd.Series(Categorical(["a", "b", "c"]))
+        msg = (
+            "Cannot setitem on a Categorical with a new category, "
+            "set the categories first"
+        )
+        with pytest.raises(ValueError, match=msg):
+            ser.loc[3] = "d"
+
+    def test_loc_category_dtype_retention(self):
+        # Case 1
+        df = pd.DataFrame(
+            {
+                "int": [0, 1, 2],
+                "cat": Categorical(["a", "b", "c"], categories=["a", "b", "c"]),
+            }
+        )
+        df.loc[3] = [3, "c"]
+        expected = pd.DataFrame(
+            {
+                "int": [0, 1, 2, 3],
+                "cat": Categorical(["a", "b", "c", "c"], categories=["a", "b", "c"]),
+            }
+        )
+        tm.assert_frame_equal(df, expected)
+
+        # Case 2
+        ser = pd.Series(Categorical(["a", "b", "c"]))
+        ser.loc[3] = "c"
+        expected = pd.Series(Categorical(["a", "b", "c", "c"]))
+        tm.assert_series_equal(ser, expected)
+
+        # Case 3
+        ser = pd.Series(Categorical([1, 2, 3]))
+        ser.loc[3] = 3
+        expected = pd.Series(Categorical([1, 2, 3, 3]))
+        tm.assert_series_equal(ser, expected)