From 27cc0e20e88a1608eadb83de2af9713c081f8d92 Mon Sep 17 00:00:00 2001 From: AshmitGupta <67188586+AshmitGupta@users.noreply.github.com> Date: Sat, 31 Aug 2024 22:18:58 -0700 Subject: [PATCH 1/3] FIX: Preserve DataFrame subclass type in groupby().agg() --- pandas/core/groupby/generic.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index c112d9b6a4b54..f65d721b8ae77 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1585,6 +1585,10 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs) result = self._insert_inaxis_grouper(result) result.index = default_index(len(result)) + # GH#59667: re-wrap only when the result lost the caller's DataFrame subclass + if not isinstance(result, type(self.obj)): + result = self.obj._constructor(result) + return result agg = aggregate From 14de84b240aff6a660f5b354b8f74701b5bf1051 Mon Sep 17 00:00:00 2001 From: AshmitGupta <67188586+AshmitGupta@users.noreply.github.com> Date: Sat, 31 Aug 2024 23:18:56 -0700 Subject: [PATCH 2/3] Added tests and whats new --- doc/source/whatsnew/v2.3.0.rst | 2 +- pandas/tests/groupby/test_groupby.py | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v2.3.0.rst b/doc/source/whatsnew/v2.3.0.rst index 8a64aa7c609d6..d49985882f500 100644 --- a/doc/source/whatsnew/v2.3.0.rst +++ b/doc/source/whatsnew/v2.3.0.rst @@ -142,7 +142,7 @@ Plotting Groupby/resample/rolling ^^^^^^^^^^^^^^^^^^^^^^^^ -- +- Bug in :meth:`DataFrame.groupby` followed by :meth:`DataFrameGroupBy.agg` not preserving subclass type of the original DataFrame (:issue:`59667`) - Reshaping diff --git a/pandas/tests/groupby/test_groupby.py b/pandas/tests/groupby/test_groupby.py index 11b874d0b1608..54df1bdb3ffd3 100644 --- a/pandas/tests/groupby/test_groupby.py +++ b/pandas/tests/groupby/test_groupby.py @@ -3004,6 +3004,30 @@ def test_groupby_agg_namedagg_with_duplicate_columns(): tm.assert_frame_equal(result, expected) +class
MyDataFrame(DataFrame): + @property + def _constructor(self): + return MyDataFrame + +@pytest.mark.parametrize("data, agg_dict, expected", [ + pytest.param( + {"A": [1, 1, 2, 2], "B": [1, 2, 3, 4]}, + {"B": "sum"}, + DataFrame({"B": [3, 7]}, index=Index([1, 2], name="A")) + ), + pytest.param( + {"A": [1, 1, 2, 2], "B": [1, 2, 3, 4], "C": [4, 3, 2, 1]}, + {"B": "sum", "C": "mean"}, + DataFrame({"B": [3, 7], "C": [3.5, 1.5]}, index=Index([1, 2], name="A")) + ), +]) +def test_groupby_agg_preserves_subclass(data, agg_dict, expected): + # GH#59667 + df = MyDataFrame(data) + result = df.groupby("A").agg(agg_dict) + + assert isinstance(result, MyDataFrame) + tm.assert_frame_equal(result, expected) def test_groupby_multi_index_codes(): # GH#54347 df = DataFrame( From 208d2fced0120fa7cffb42dcddda5fc2c294e035 Mon Sep 17 00:00:00 2001 From: AshmitGupta <67188586+AshmitGupta@users.noreply.github.com> Date: Sun, 1 Sep 2024 01:48:13 -0700 Subject: [PATCH 3/3] Code style chore (Pre-commit hook) --- pandas/tests/groupby/test_groupby.py | 32 +++++++++++++++++----------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/pandas/tests/groupby/test_groupby.py b/pandas/tests/groupby/test_groupby.py index 54df1bdb3ffd3..9adc817324a21 100644 --- a/pandas/tests/groupby/test_groupby.py +++ b/pandas/tests/groupby/test_groupby.py @@ -3004,31 +3004,37 @@ def test_groupby_agg_namedagg_with_duplicate_columns(): tm.assert_frame_equal(result, expected) + class MyDataFrame(DataFrame): @property def _constructor(self): return MyDataFrame -@pytest.mark.parametrize("data, agg_dict, expected", [ - pytest.param( - {"A": [1, 1, 2, 2], "B": [1, 2, 3, 4]}, - {"B": "sum"}, - DataFrame({"B": [3, 7]}, index=Index([1, 2], name="A")) - ), - pytest.param( - {"A": [1, 1, 2, 2], "B": [1, 2, 3, 4], "C": [4, 3, 2, 1]}, - {"B": "sum", "C": "mean"}, - DataFrame({"B": [3, 7], "C": [3.5, 1.5]}, index=Index([1, 2], name="A")) - ), -]) + +@pytest.mark.parametrize( + "data, agg_dict, expected", + [ + pytest.param( + {"A": [1,
1, 2, 2], "B": [1, 2, 3, 4]}, + {"B": "sum"}, + DataFrame({"B": [3, 7]}, index=Index([1, 2], name="A")), + ), + pytest.param( + {"A": [1, 1, 2, 2], "B": [1, 2, 3, 4], "C": [4, 3, 2, 1]}, + {"B": "sum", "C": "mean"}, + DataFrame({"B": [3, 7], "C": [3.5, 1.5]}, index=Index([1, 2], name="A")), + ), + ], +) def test_groupby_agg_preserves_subclass(data, agg_dict, expected): # GH#59667 df = MyDataFrame(data) result = df.groupby("A").agg(agg_dict) - + assert isinstance(result, MyDataFrame) tm.assert_frame_equal(result, expected) + def test_groupby_multi_index_codes(): # GH#54347 df = DataFrame(