Skip to content

Commit 5e52303

Browse files
Backport PR pandas-dev#32040: BUG: GroupBy aggregation of DataFrame with MultiIndex columns breaks with custom function (pandas-dev#32648)
Co-authored-by: Marco Gorelli <[email protected]>
1 parent abadc4f commit 5e52303

File tree

3 files changed

+19
-3
lines changed

3 files changed

+19
-3
lines changed

doc/source/whatsnew/v1.0.2.rst

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Fixed regressions
1717

1818
- Fixed regression in :meth:`DataFrame.to_excel` when ``columns`` kwarg is passed (:issue:`31677`)
1919
- Fixed regression in :meth:`Series.align` when ``other`` is a DataFrame and ``method`` is not None (:issue:`31785`)
20+
- Fixed regression in :meth:`groupby(..).agg() <pandas.core.groupby.GroupBy.agg>` which was failing on frames with MultiIndex columns and a custom function (:issue:`31777`)
2021
- Fixed regression in ``groupby(..).rolling(..).apply()`` (``RollingGroupby``) where the ``raw`` parameter was ignored (:issue:`31754`)
2122
- Fixed regression in :meth:`rolling(..).corr() <pandas.core.window.rolling.Rolling.corr>` when using a time offset (:issue:`31789`)
2223
- Fixed regression in :meth:`groupby(..).nunique() <pandas.core.groupby.DataFrameGroupBy.nunique>` which was modifying the original values if ``NaN`` values were present (:issue:`31950`)

pandas/core/groupby/generic.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -952,9 +952,11 @@ def aggregate(self, func=None, *args, **kwargs):
952952
raise
953953
result = self._aggregate_frame(func)
954954
else:
955-
result.columns = Index(
956-
result.columns.levels[0], name=self._selected_obj.columns.name
957-
)
955+
# select everything except for the last level, which is the one
956+
# containing the name of the function(s), see GH 32040
957+
result.columns = result.columns.rename(
958+
[self._selected_obj.columns.name] * result.columns.nlevels
959+
).droplevel(-1)
958960

959961
if not self.as_index:
960962
self._insert_inaxis_grouper_inplace(result)

pandas/tests/groupby/aggregate/test_aggregate.py

+13
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,19 @@ def test_agg_relabel_multiindex_duplicates():
692692
tm.assert_frame_equal(result, expected)
693693

694694

695+
@pytest.mark.parametrize(
696+
"func", [lambda s: s.mean(), lambda s: np.mean(s), lambda s: np.nanmean(s)]
697+
)
698+
def test_multiindex_custom_func(func):
699+
# GH 31777
700+
data = [[1, 4, 2], [5, 7, 1]]
701+
df = pd.DataFrame(data, columns=pd.MultiIndex.from_arrays([[1, 1, 2], [3, 4, 3]]))
702+
result = df.groupby(np.array([0, 1])).agg(func)
703+
expected_dict = {(1, 3): {0: 1, 1: 5}, (1, 4): {0: 4, 1: 7}, (2, 3): {0: 2, 1: 1}}
704+
expected = pd.DataFrame(expected_dict)
705+
tm.assert_frame_equal(result, expected)
706+
707+
695708
def myfunc(s):
696709
return np.percentile(s, q=0.90)
697710

0 commit comments

Comments
 (0)