Skip to content

Commit c239f54

Browse files
authored
BUG: groupby.agg doesn't include grouping columns in result when selected (#51398)
1 parent 129108f commit c239f54

File tree

3 files changed

+18
-10
lines changed

3 files changed

+18
-10
lines changed

doc/source/whatsnew/v2.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -1377,6 +1377,7 @@ Groupby/resample/rolling
13771377
- Bug in :meth:`.DataFrameGroupBy.agg` with ``engine="numba"`` failing to respect ``as_index=False`` (:issue:`51228`)
13781378
- Bug in :meth:`.DataFrameGroupBy.agg`, :meth:`.SeriesGroupBy.agg`, and :meth:`.Resampler.agg` would ignore arguments when passed a list of functions (:issue:`50863`)
13791379
- Bug in :meth:`.DataFrameGroupBy.ohlc` ignoring ``as_index=False`` (:issue:`51413`)
1380+
- Bug in :meth:`.DataFrameGroupBy.agg` after subsetting columns (e.g. ``.groupby(...)[["a", "b"]]``) would not include groupings in the result (:issue:`51186`)
13801381

13811382
Reshaping
13821383
^^^^^^^^^

pandas/core/groupby/generic.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -1359,21 +1359,15 @@ def _python_agg_general(self, func, *args, **kwargs):
13591359
return self._wrap_aggregated_output(res)
13601360

13611361
def _iterate_slices(self) -> Iterable[Series]:
1362-
obj = self._selected_obj
1362+
obj = self._obj_with_exclusions
13631363
if self.axis == 1:
13641364
obj = obj.T
13651365

1366-
if isinstance(obj, Series) and obj.name not in self.exclusions:
1366+
if isinstance(obj, Series):
13671367
# Occurs when doing DataFrameGroupBy(...)["X"]
13681368
yield obj
13691369
else:
13701370
for label, values in obj.items():
1371-
if label in self.exclusions:
1372-
# Note: if we tried to just iterate over _obj_with_exclusions,
1373-
# we would break test_wrap_agg_out by yielding a column
1374-
# that is skipped here but not dropped from obj_with_exclusions
1375-
continue
1376-
13771371
yield values
13781372

13791373
def _aggregate_frame(self, func, *args, **kwargs) -> DataFrame:

pandas/tests/groupby/aggregate/test_aggregate.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -321,8 +321,8 @@ def func(ser):
321321

322322
with pytest.raises(TypeError, match="Test error message"):
323323
grouped.aggregate(func)
324-
result = grouped[[c for c in three_group if c != "C"]].aggregate(func)
325-
exp_grouped = three_group.loc[:, three_group.columns != "C"]
324+
result = grouped[["D", "E", "F"]].aggregate(func)
325+
exp_grouped = three_group.loc[:, ["A", "B", "D", "E", "F"]]
326326
expected = exp_grouped.groupby(["A", "B"]).aggregate(func)
327327
tm.assert_frame_equal(result, expected)
328328

@@ -1521,3 +1521,16 @@ def foo2(x, b=2, c=0):
15211521
[[8, 8], [9, 9], [10, 10]], index=Index([1, 2, 3]), columns=["foo1", "foo2"]
15221522
)
15231523
tm.assert_frame_equal(result, expected)
1524+
1525+
1526+
def test_agg_groupings_selection():
1527+
# GH#51186 - a selected grouping should be in the output of agg
1528+
df = DataFrame({"a": [1, 1, 2], "b": [3, 3, 4], "c": [5, 6, 7]})
1529+
gb = df.groupby(["a", "b"])
1530+
selected_gb = gb[["b", "c"]]
1531+
result = selected_gb.agg(lambda x: x.sum())
1532+
index = MultiIndex(
1533+
levels=[[1, 2], [3, 4]], codes=[[0, 1], [0, 1]], names=["a", "b"]
1534+
)
1535+
expected = DataFrame({"b": [6, 4], "c": [11, 7]}, index=index)
1536+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)