From 56847f02ad7d3e5ec2bbb63ae1f3e420f79757e9 Mon Sep 17 00:00:00 2001 From: Daniel Saxton <2658661+dsaxton@users.noreply.github.com> Date: Tue, 7 Apr 2020 05:04:14 -0500 Subject: [PATCH] Backport PR #33089 on branch 1.0.x (BUG: Don't cast nullable Boolean to float in groupby) --- doc/source/whatsnew/v1.0.4.rst | 3 ++- pandas/core/groupby/groupby.py | 19 +++++++------- pandas/tests/groupby/test_nth.py | 26 ++++++++++++++++++++ pandas/tests/resample/test_datetime_index.py | 4 +-- 4 files changed, 39 insertions(+), 13 deletions(-) diff --git a/doc/source/whatsnew/v1.0.4.rst b/doc/source/whatsnew/v1.0.4.rst index 80f8976826bc8..f8366d2bcae4b 100644 --- a/doc/source/whatsnew/v1.0.4.rst +++ b/doc/source/whatsnew/v1.0.4.rst @@ -21,13 +21,14 @@ Fixed regressions - Fix performance regression in ``memory_usage(deep=True)`` for object dtype (:issue:`33012`) - Bug where :meth:`Categorical.replace` would replace with ``NaN`` whenever the new value and replacement value were equal (:issue:`33288`) - Bug where an ordered :class:`Categorical` containing only ``NaN`` values would raise rather than returning ``NaN`` when taking the minimum or maximum (:issue:`33450`) +- Bug in :meth:`DataFrameGroupBy.agg` with dictionary input losing ``ExtensionArray`` dtypes (:issue:`32194`) - .. _whatsnew_104.bug_fixes: Bug fixes ~~~~~~~~~ -- +- Bug in :meth:`SeriesGroupBy.first`, :meth:`SeriesGroupBy.last`, :meth:`SeriesGroupBy.min`, and :meth:`SeriesGroupBy.max` returning floats when applied to nullable Booleans (:issue:`33071`) - Contributors diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 57d865a051c31..478239b1bcff0 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -42,6 +42,7 @@ class providing the base-class of operations. from pandas.core.dtypes.cast import maybe_downcast_to_dtype from pandas.core.dtypes.common import ( ensure_float, + is_categorical_dtype, is_datetime64_dtype, is_extension_array_dtype, is_integer_dtype, @@ -807,15 +808,15 @@ def _try_cast(self, result, obj, numeric_only: bool = False): dtype = obj.dtype if not is_scalar(result): - if is_extension_array_dtype(dtype) and dtype.kind != "M": - # The function can return something of any type, so check - # if the type is compatible with the calling EA. - # datetime64tz is handled correctly in agg_series, - # so is excluded here. - - if len(result) and isinstance(result[0], dtype.type): - cls = dtype.construct_array_type() - result = try_cast_to_ea(cls, result, dtype=dtype) + if ( + is_extension_array_dtype(dtype) + and not is_categorical_dtype(dtype) + and dtype.kind != "M" + ): + # We have to special case categorical so as not to upcast + # things like counts back to categorical + cls = dtype.construct_array_type() + result = try_cast_to_ea(cls, result, dtype=dtype) elif numeric_only and is_numeric_dtype(dtype) or not numeric_only: result = maybe_downcast_to_dtype(result, dtype) diff --git a/pandas/tests/groupby/test_nth.py b/pandas/tests/groupby/test_nth.py index 7b2e3f7a7e06a..947907caf5cbc 100644 --- a/pandas/tests/groupby/test_nth.py +++ b/pandas/tests/groupby/test_nth.py @@ -395,6 +395,32 @@ def test_first_last_tz_multi_column(method, ts, alpha): tm.assert_frame_equal(result, expected) +@pytest.mark.parametrize( + "values", + [ + pd.array([True, False], dtype="boolean"), + pd.array([1, 2], dtype="Int64"), + pd.to_datetime(["2020-01-01", "2020-02-01"]), + pd.to_timedelta([1, 2], unit="D"), + ], +) +@pytest.mark.parametrize("function", ["first", "last", "min", "max"]) +def test_first_last_extension_array_keeps_dtype(values, function): + # https://github.com/pandas-dev/pandas/issues/33071 + # https://github.com/pandas-dev/pandas/issues/32194 + df = DataFrame({"a": [1, 2], "b": values}) + grouped = df.groupby("a") + idx = Index([1, 2], name="a") + expected_series = Series(values, name="b", index=idx) + expected_frame = DataFrame({"b": values}, index=idx) + + result_series = getattr(grouped["b"], function)() + tm.assert_series_equal(result_series, expected_series) + + result_frame = grouped.agg({"b": function}) + tm.assert_frame_equal(result_frame, expected_frame) + + def test_nth_multi_index_as_expected(): # PR 9090, related to issue 8979 # test nth on MultiIndex diff --git a/pandas/tests/resample/test_datetime_index.py b/pandas/tests/resample/test_datetime_index.py index 3ad82b9e075a8..ab6985b11ba9a 100644 --- a/pandas/tests/resample/test_datetime_index.py +++ b/pandas/tests/resample/test_datetime_index.py @@ -122,9 +122,7 @@ def test_resample_integerarray(): result = ts.resample("3T").mean() expected = Series( - [1, 4, 7], - index=pd.date_range("1/1/2000", periods=3, freq="3T"), - dtype="float64", + [1, 4, 7], index=pd.date_range("1/1/2000", periods=3, freq="3T"), dtype="Int64", ) tm.assert_series_equal(result, expected)