diff --git a/doc/source/whatsnew/v2.3.0.rst b/doc/source/whatsnew/v2.3.0.rst
index 8a64aa7c609d6..d49985882f500 100644
--- a/doc/source/whatsnew/v2.3.0.rst
+++ b/doc/source/whatsnew/v2.3.0.rst
@@ -142,7 +142,7 @@ Plotting
 
 Groupby/resample/rolling
 ^^^^^^^^^^^^^^^^^^^^^^^^
--
+- Bug in :meth:`DataFrame.groupby` followed by :meth:`DataFrameGroupBy.agg` not preserving subclass type of the original DataFrame (:issue:`59667`)
 -
 
 Reshaping
diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py
index c112d9b6a4b54..f65d721b8ae77 100644
--- a/pandas/core/groupby/generic.py
+++ b/pandas/core/groupby/generic.py
@@ -1585,6 +1585,11 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
             result = self._insert_inaxis_grouper(result)
             result.index = default_index(len(result))
 
+        # GH#59667: re-wrap only when the result was downgraded to a base class
+        # of the original frame's type, so subclasses survive aggregation.
+        if isinstance(self.obj, type(result)) and type(result) is not type(self.obj):
+            result = self.obj._constructor(result)
+
         return result
 
     agg = aggregate
diff --git a/pandas/tests/groupby/test_groupby.py b/pandas/tests/groupby/test_groupby.py
index 11b874d0b1608..9adc817324a21 100644
--- a/pandas/tests/groupby/test_groupby.py
+++ b/pandas/tests/groupby/test_groupby.py
@@ -3005,6 +3005,36 @@ def test_groupby_agg_namedagg_with_duplicate_columns():
     tm.assert_frame_equal(result, expected)
 
 
+class MyDataFrame(DataFrame):
+    @property
+    def _constructor(self):
+        return MyDataFrame
+
+
+@pytest.mark.parametrize(
+    "data, agg_dict, expected",
+    [
+        pytest.param(
+            {"A": [1, 1, 2, 2], "B": [1, 2, 3, 4]},
+            {"B": "sum"},
+            DataFrame({"B": [3, 7]}, index=Index([1, 2], name="A")),
+        ),
+        pytest.param(
+            {"A": [1, 1, 2, 2], "B": [1, 2, 3, 4], "C": [4, 3, 2, 1]},
+            {"B": "sum", "C": "mean"},
+            DataFrame({"B": [3, 7], "C": [3.5, 1.5]}, index=Index([1, 2], name="A")),
+        ),
+    ],
+)
+def test_groupby_agg_preserves_subclass(data, agg_dict, expected):
+    # GH#59667
+    df = MyDataFrame(data)
+    result = df.groupby("A").agg(agg_dict)
+
+    assert isinstance(result, MyDataFrame)
+    tm.assert_frame_equal(result, expected)
+
+
 def test_groupby_multi_index_codes():
     # GH#54347
     df = DataFrame(