diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst index a8d6f3fce5bb7..2fba47d41f539 100644 --- a/doc/source/whatsnew/v2.0.0.rst +++ b/doc/source/whatsnew/v2.0.0.rst @@ -1377,6 +1377,7 @@ Groupby/resample/rolling - Bug in :meth:`.DataFrameGroupBy.agg` with ``engine="numba"`` failing to respect ``as_index=False`` (:issue:`51228`) - Bug in :meth:`.DataFrameGroupBy.agg`, :meth:`.SeriesGroupBy.agg`, and :meth:`.Resampler.agg` would ignore arguments when passed a list of functions (:issue:`50863`) - Bug in :meth:`.DataFrameGroupBy.ohlc` ignoring ``as_index=False`` (:issue:`51413`) +- Bug in :meth:`.DataFrameGroupBy.agg` after subsetting columns (e.g. ``.groupby(...)[["a", "b"]]``) would not include selected groupings in the result (:issue:`51186`) Reshaping ^^^^^^^^^ diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 67188d91bca70..499bef2b61046 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1359,21 +1359,15 @@ def _python_agg_general(self, func, *args, **kwargs): return self._wrap_aggregated_output(res) def _iterate_slices(self) -> Iterable[Series]: - obj = self._selected_obj + obj = self._obj_with_exclusions if self.axis == 1: obj = obj.T - if isinstance(obj, Series) and obj.name not in self.exclusions: + if isinstance(obj, Series): # Occurs when doing DataFrameGroupBy(...)["X"] yield obj else: for label, values in obj.items(): - if label in self.exclusions: - # Note: if we tried to just iterate over _obj_with_exclusions, - # we would break test_wrap_agg_out by yielding a column - # that is skipped here but not dropped from obj_with_exclusions - continue - yield values def _aggregate_frame(self, func, *args, **kwargs) -> DataFrame: diff --git a/pandas/tests/groupby/aggregate/test_aggregate.py b/pandas/tests/groupby/aggregate/test_aggregate.py index 6c591616e8469..d658de4a7d7c3 100644 --- a/pandas/tests/groupby/aggregate/test_aggregate.py +++ 
b/pandas/tests/groupby/aggregate/test_aggregate.py @@ -321,8 +321,8 @@ def func(ser): with pytest.raises(TypeError, match="Test error message"): grouped.aggregate(func) - result = grouped[[c for c in three_group if c != "C"]].aggregate(func) - exp_grouped = three_group.loc[:, three_group.columns != "C"] + result = grouped[["D", "E", "F"]].aggregate(func) + exp_grouped = three_group.loc[:, ["A", "B", "D", "E", "F"]] expected = exp_grouped.groupby(["A", "B"]).aggregate(func) tm.assert_frame_equal(result, expected) @@ -1521,3 +1521,16 @@ def foo2(x, b=2, c=0): [[8, 8], [9, 9], [10, 10]], index=Index([1, 2, 3]), columns=["foo1", "foo2"] ) tm.assert_frame_equal(result, expected) + + +def test_agg_groupings_selection(): + # GH#51186 - a selected grouping should be in the output of agg + df = DataFrame({"a": [1, 1, 2], "b": [3, 3, 4], "c": [5, 6, 7]}) + gb = df.groupby(["a", "b"]) + selected_gb = gb[["b", "c"]] + result = selected_gb.agg(lambda x: x.sum()) + index = MultiIndex( + levels=[[1, 2], [3, 4]], codes=[[0, 1], [0, 1]], names=["a", "b"] + ) + expected = DataFrame({"b": [6, 4], "c": [11, 7]}, index=index) + tm.assert_frame_equal(result, expected)