diff --git a/doc/source/whatsnew/v1.0.4.rst b/doc/source/whatsnew/v1.0.4.rst index d4f4e9c0ae8de..4f122b0dea5f4 100644 --- a/doc/source/whatsnew/v1.0.4.rst +++ b/doc/source/whatsnew/v1.0.4.rst @@ -21,6 +21,7 @@ Fixed regressions - Fix performance regression in ``memory_usage(deep=True)`` for object dtype (:issue:`33012`) - Bug where :meth:`Categorical.replace` would replace with ``NaN`` whenever the new value and replacement value were equal (:issue:`33288`) - Bug where an ordered :class:`Categorical` containing only ``NaN`` values would raise rather than returning ``NaN`` when taking the minimum or maximum (:issue:`33450`) +- Bug in :meth:`DataFrameGroupBy.agg` with dictionary input losing ``ExtensionArray`` dtypes (:issue:`32194`) - Fix to preserve the ability to index with the "nearest" method with xarray's CFTimeIndex, an :class:`Index` subclass (`pydata/xarray#3751 `_, :issue:`32905`). - @@ -28,6 +29,7 @@ Fixed regressions Bug fixes ~~~~~~~~~ +- Bug in :meth:`SeriesGroupBy.first`, :meth:`SeriesGroupBy.last`, :meth:`SeriesGroupBy.min`, and :meth:`SeriesGroupBy.max` returning floats when applied to nullable Booleans (:issue:`33071`) - Bug in :meth:`Rolling.min` and :meth:`Rolling.max`: Growing memory usage after multiple calls when using a fixed window (:issue:`30726`) - diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 57d865a051c31..478239b1bcff0 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -42,6 +42,7 @@ class providing the base-class of operations. from pandas.core.dtypes.cast import maybe_downcast_to_dtype from pandas.core.dtypes.common import ( ensure_float, + is_categorical_dtype, is_datetime64_dtype, is_extension_array_dtype, is_integer_dtype, @@ -807,15 +808,15 @@ def _try_cast(self, result, obj, numeric_only: bool = False): dtype = obj.dtype if not is_scalar(result): - if is_extension_array_dtype(dtype) and dtype.kind != "M": - # The function can return something of any type, so check - # if the type is compatible with the calling EA. - # datetime64tz is handled correctly in agg_series, - # so is excluded here. - - if len(result) and isinstance(result[0], dtype.type): - cls = dtype.construct_array_type() - result = try_cast_to_ea(cls, result, dtype=dtype) + if ( + is_extension_array_dtype(dtype) + and not is_categorical_dtype(dtype) + and dtype.kind != "M" + ): + # We have to special case categorical so as not to upcast + # things like counts back to categorical + cls = dtype.construct_array_type() + result = try_cast_to_ea(cls, result, dtype=dtype) elif numeric_only and is_numeric_dtype(dtype) or not numeric_only: result = maybe_downcast_to_dtype(result, dtype) diff --git a/pandas/tests/groupby/test_nth.py b/pandas/tests/groupby/test_nth.py index 7b2e3f7a7e06a..947907caf5cbc 100644 --- a/pandas/tests/groupby/test_nth.py +++ b/pandas/tests/groupby/test_nth.py @@ -395,6 +395,32 @@ def test_first_last_tz_multi_column(method, ts, alpha): tm.assert_frame_equal(result, expected) +@pytest.mark.parametrize( + "values", + [ + pd.array([True, False], dtype="boolean"), + pd.array([1, 2], dtype="Int64"), + pd.to_datetime(["2020-01-01", "2020-02-01"]), + pd.to_timedelta([1, 2], unit="D"), + ], +) +@pytest.mark.parametrize("function", ["first", "last", "min", "max"]) +def test_first_last_extension_array_keeps_dtype(values, function): + # https://github.com/pandas-dev/pandas/issues/33071 + # https://github.com/pandas-dev/pandas/issues/32194 + df = DataFrame({"a": [1, 2], "b": values}) + grouped = df.groupby("a") + idx = Index([1, 2], name="a") + expected_series = Series(values, name="b", index=idx) + expected_frame = DataFrame({"b": values}, index=idx) + + result_series = getattr(grouped["b"], function)() + tm.assert_series_equal(result_series, expected_series) + + result_frame = grouped.agg({"b": function}) + tm.assert_frame_equal(result_frame, expected_frame) + + def test_nth_multi_index_as_expected(): # PR 9090, related to issue 8979 # test nth on MultiIndex diff --git a/pandas/tests/resample/test_datetime_index.py b/pandas/tests/resample/test_datetime_index.py index 3ad82b9e075a8..ab6985b11ba9a 100644 --- a/pandas/tests/resample/test_datetime_index.py +++ b/pandas/tests/resample/test_datetime_index.py @@ -122,9 +122,7 @@ def test_resample_integerarray(): result = ts.resample("3T").mean() expected = Series( - [1, 4, 7], - index=pd.date_range("1/1/2000", periods=3, freq="3T"), - dtype="float64", + [1, 4, 7], index=pd.date_range("1/1/2000", periods=3, freq="3T"), dtype="Int64", ) tm.assert_series_equal(result, expected)