diff --git a/doc/source/whatsnew/v0.20.3.txt b/doc/source/whatsnew/v0.20.3.txt index 52f7701724f18..8d145033ddf89 100644 --- a/doc/source/whatsnew/v0.20.3.txt +++ b/doc/source/whatsnew/v0.20.3.txt @@ -65,6 +65,7 @@ Plotting Groupby/Resample/Rolling ^^^^^^^^^^^^^^^^^^^^^^^^ +- Bug in groupby logic causing MultiIndex column levels to be lost (:issue:`16231`) Sparse diff --git a/pandas/core/base.py b/pandas/core/base.py index 97c4c8626dcbb..d7b6acc021dcd 100644 --- a/pandas/core/base.py +++ b/pandas/core/base.py @@ -658,6 +658,7 @@ def _aggregate_multiple_funcs(self, arg, _level, _axis): # degenerate case if obj.ndim == 1: + names = obj.index.names for a in arg: try: colg = self._gotitem(obj.name, ndim=1, subset=obj) @@ -673,6 +674,7 @@ def _aggregate_multiple_funcs(self, arg, _level, _axis): # multiples else: + names = obj.columns.names for col in obj: try: colg = self._gotitem(col, ndim=1, subset=obj[col]) @@ -691,7 +693,7 @@ def _aggregate_multiple_funcs(self, arg, _level, _axis): raise ValueError("no results") try: - return concat(results, keys=keys, axis=1) + return concat(results, keys=keys, axis=1, names=names) except TypeError: # we are concatting non-NDFrame objects, diff --git a/pandas/core/groupby.py b/pandas/core/groupby.py index 9d6d2297f6ea0..ace3bfc7948ed 100644 --- a/pandas/core/groupby.py +++ b/pandas/core/groupby.py @@ -3481,9 +3481,9 @@ def aggregate(self, arg, *args, **kwargs): assert not args and not kwargs result = self._aggregate_multiple_funcs( [arg], _level=_level, _axis=self.axis) - result.columns = Index( - result.columns.levels[0], - name=self._selected_obj.columns.name) + result.columns = result.columns.droplevel(-1) + if result.columns.nlevels == 1: + result.columns.name = self._selected_obj.columns.name except: result = self._aggregate_generic(arg, *args, **kwargs) diff --git a/pandas/tests/groupby/test_groupby.py b/pandas/tests/groupby/test_groupby.py index 19124a33bdbcb..5a8cd7d4b4f05 100644 --- a/pandas/tests/groupby/test_groupby.py +++ b/pandas/tests/groupby/test_groupby.py @@ -3626,6 +3626,22 @@ def test_func(x): tm.assert_frame_equal(result1, expected1) tm.assert_frame_equal(result2, expected2) + @pytest.mark.parametrize('nlevel', range(1, 6)) + @pytest.mark.parametrize('as_index', [False, True]) + def test_groupby_aggregate_preserves_multiindex_columns(self, nlevel, + as_index): + # GH 16231 + cols = pd.MultiIndex.from_tuples([[i] * nlevel for i in range(2)], + names=['lev_{}'.format(lev) + for lev in range(nlevel)]) + df = pd.DataFrame(np.random.randn(10, len(cols)), columns=cols) + + grouped = df.groupby(df.index % 3, as_index=as_index) + via_direct = grouped.sum() + via_agg = grouped.aggregate(lambda x: x.sum()) + + tm.assert_frame_equal(via_direct, via_agg) + def test_groupby_preserves_sort(self): # Test to ensure that groupby always preserves sort order of original # object. Issue #8588 and #9651 diff --git a/pandas/tests/test_resample.py b/pandas/tests/test_resample.py index 959e3d2f459ce..cab5ec1c64fc1 100644 --- a/pandas/tests/test_resample.py +++ b/pandas/tests/test_resample.py @@ -2972,6 +2972,37 @@ def f(x): result = g.apply(f) assert_frame_equal(result, expected) + def test_apply_preserves_multiindex_columns(self): + # GH 16231 + # the original failing case + cols = pd.MultiIndex.from_tuples([('A', 'a', '', 'one'), + ('B', 'b', 'i', 'two')]) + ind = pd.DatetimeIndex(start='2017-01-01', freq='15Min', periods=8) + df = pd.DataFrame(np.random.randn(8, 2), index=ind, columns=cols) + + agg_dict = {col: (np.sum if col[3] == 'one' else np.mean) + for col in df.columns} + resampled = df.resample('H').apply(lambda x: agg_dict[x.name](x)) + assert isinstance(resampled.columns, pd.MultiIndex) + + @pytest.mark.parametrize('nlevel', range(1, 6)) + @pytest.mark.parametrize('ncol', [1, 2]) + @pytest.mark.parametrize('freq', ['D', '360Min']) + def test_apply_preserves_multiindex_columns_grid(self, nlevel, ncol, freq): + # GH 16231 + cols = pd.MultiIndex.from_tuples([[i] * nlevel for i in range(ncol)], + names=['lev_{}'.format(lev) + for lev in range(nlevel)]) + idx = pd.date_range('2000-01-01', freq="H", periods=50) + df = pd.DataFrame(np.random.randn(len(idx), len(cols)), + columns=cols, index=idx) + + resampled = df.resample(freq) + + via_direct = resampled.sum() + via_apply = resampled.apply(lambda x: x.sum()) + tm.assert_frame_equal(via_direct, via_apply) + def test_resample_groupby_with_label(self): # GH 13235 index = date_range('2000-01-01', freq='2D', periods=5)