diff --git a/doc/source/whatsnew/v1.5.0.rst b/doc/source/whatsnew/v1.5.0.rst index 502e37705abfb..aace7ae221304 100644 --- a/doc/source/whatsnew/v1.5.0.rst +++ b/doc/source/whatsnew/v1.5.0.rst @@ -962,6 +962,7 @@ Indexing - Bug in :meth:`DataFrame.sum` min_count changes dtype if input contains NaNs (:issue:`46947`) - Bug in :class:`IntervalTree` that lead to an infinite recursion. (:issue:`46658`) - Bug in :class:`PeriodIndex` raising ``AttributeError`` when indexing on ``NA``, rather than putting ``NaT`` in its place. (:issue:`46673`) +- Bug in :meth:`DataFrame.loc` where enlarging a :class:`Series` of dtype :class:`CategoricalDtype` with a scalar did not preserve the categorical dtype when the scalar was an existing category or a missing value (:issue:`47677`) - Missing diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 769656d1c4755..00b18ef154f27 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -591,7 +591,9 @@ def _maybe_promote(dtype: np.dtype, fill_value=np.nan): return dtype, fv elif isna(fill_value): - dtype = _dtype_obj + # preserve dtype in case of categoricaldtype + if not isinstance(dtype, CategoricalDtype): + dtype = _dtype_obj if fill_value is None: # but we retain e.g. 
pd.NA fill_value = np.nan @@ -646,6 +648,12 @@ def _maybe_promote(dtype: np.dtype, fill_value=np.nan): return np.dtype("object"), fill_value + elif isinstance(dtype, CategoricalDtype): + if fill_value in dtype.categories: + return dtype, fill_value + else: + return object, ensure_object(fill_value) + elif is_float(fill_value): if issubclass(dtype.type, np.bool_): dtype = np.dtype(np.object_) diff --git a/pandas/tests/indexing/test_loc.py b/pandas/tests/indexing/test_loc.py index 4c38a2219372d..60c4ee5518047 100644 --- a/pandas/tests/indexing/test_loc.py +++ b/pandas/tests/indexing/test_loc.py @@ -18,6 +18,7 @@ import pandas as pd from pandas import ( Categorical, + CategoricalDtype, CategoricalIndex, DataFrame, DatetimeIndex, @@ -1820,6 +1821,46 @@ def test_loc_getitem_sorted_index_level_with_duplicates(self): result = df.loc[("foo", "bar")] tm.assert_frame_equal(result, expected) + def test_additional_element_to_categorical_series_loc(self): + # GH#47677 + result = Series(["a", "b", "c"], dtype="category") + result.loc[3] = 0 + expected = Series(["a", "b", "c", 0], dtype="object") + tm.assert_series_equal(result, expected) + + def test_additional_categorical_element_loc(self): + # GH#47677 + result = Series(["a", "b", "c"], dtype="category") + result.loc[3] = "a" + expected = Series(["a", "b", "c", "a"], dtype="category") + tm.assert_series_equal(result, expected) + + def test_loc_set_nan_in_categorical_series(self, any_numeric_ea_dtype): + # GH#47677 + srs = Series( + [1, 2, 3], + dtype=CategoricalDtype(Index([1, 2, 3], dtype=any_numeric_ea_dtype)), + ) + # enlarge + srs.loc[3] = np.nan + assert srs.values.dtype._categories.dtype == any_numeric_ea_dtype + # set into + srs.loc[1] = np.nan + assert srs.values.dtype._categories.dtype == any_numeric_ea_dtype + + @pytest.mark.parametrize("na", (np.nan, pd.NA, None)) + def test_loc_consistency_series_enlarge_set_into(self, na): + # GH#47677 + srs_enlarge = Series(["a", "b", "c"], dtype="category") + srs_enlarge.loc[3] 
= na + + srs_setinto = Series(["a", "b", "c", "a"], dtype="category") + srs_setinto.loc[3] = na + + tm.assert_series_equal(srs_enlarge, srs_setinto) + expected = Series(["a", "b", "c", na], dtype="category") + tm.assert_series_equal(srs_enlarge, expected) + def test_loc_getitem_preserves_index_level_category_dtype(self): # GH#15166 df = DataFrame(