Skip to content

Commit 078a0ce

Browse files
committed
BUG: keep column levels when using apply after grouping (pandas-dev#16231)
1 parent ceaf852 commit 078a0ce

File tree

5 files changed

+54
-4
lines changed

5 files changed

+54
-4
lines changed

doc/source/whatsnew/v0.20.3.txt

+1
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ Plotting
6565
Groupby/Resample/Rolling
6666
^^^^^^^^^^^^^^^^^^^^^^^^
6767

68+
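- Bug in ``DataFrame.groupby(...).aggregate()`` and ``DataFrame.resample(...).apply()`` where the levels of ``MultiIndex`` columns were lost when a single function was applied (:issue:`16231`)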
6869

6970

7071
Sparse

pandas/core/base.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,7 @@ def _aggregate_multiple_funcs(self, arg, _level, _axis):
658658

659659
# degenerate case
660660
if obj.ndim == 1:
661+
names = obj.index.names
661662
for a in arg:
662663
try:
663664
colg = self._gotitem(obj.name, ndim=1, subset=obj)
@@ -673,6 +674,7 @@ def _aggregate_multiple_funcs(self, arg, _level, _axis):
673674

674675
# multiples
675676
else:
677+
names = obj.columns.names
676678
for col in obj:
677679
try:
678680
colg = self._gotitem(col, ndim=1, subset=obj[col])
@@ -691,7 +693,7 @@ def _aggregate_multiple_funcs(self, arg, _level, _axis):
691693
raise ValueError("no results")
692694

693695
try:
694-
return concat(results, keys=keys, axis=1)
696+
return concat(results, keys=keys, axis=1, names=names)
695697
except TypeError:
696698

697699
# we are concatting non-NDFrame objects,

pandas/core/groupby.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -3481,9 +3481,9 @@ def aggregate(self, arg, *args, **kwargs):
34813481
assert not args and not kwargs
34823482
result = self._aggregate_multiple_funcs(
34833483
[arg], _level=_level, _axis=self.axis)
3484-
result.columns = Index(
3485-
result.columns.levels[0],
3486-
name=self._selected_obj.columns.name)
3484+
result.columns = result.columns.droplevel(-1)
3485+
if result.columns.nlevels == 1:
3486+
result.columns.name = self._selected_obj.columns.name
34873487
except:
34883488
result = self._aggregate_generic(arg, *args, **kwargs)
34893489

pandas/tests/groupby/test_groupby.py

+16
Original file line numberDiff line numberDiff line change
@@ -3626,6 +3626,22 @@ def test_func(x):
36263626
tm.assert_frame_equal(result1, expected1)
36273627
tm.assert_frame_equal(result2, expected2)
36283628

3629+
@pytest.mark.parametrize('nlevel', range(1, 6))
3630+
@pytest.mark.parametrize('as_index', [False, True])
3631+
def test_groupby_aggregate_preserves_multiindex_columns(self, nlevel,
3632+
as_index):
3633+
# GH 16231
3634+
cols = pd.MultiIndex.from_tuples([[i] * nlevel for i in range(2)],
3635+
names=['lev_{}'.format(lev)
3636+
for lev in range(nlevel)])
3637+
df = pd.DataFrame(np.random.randn(10, len(cols)), columns=cols)
3638+
3639+
grouped = df.groupby(df.index % 3, as_index=as_index)
3640+
via_direct = grouped.sum()
3641+
via_agg = grouped.aggregate(lambda x: x.sum())
3642+
3643+
tm.assert_frame_equal(via_direct, via_agg)
3644+
36293645
def test_groupby_preserves_sort(self):
36303646
# Test to ensure that groupby always preserves sort order of original
36313647
# object. Issue #8588 and #9651

pandas/tests/test_resample.py

+31
Original file line numberDiff line numberDiff line change
@@ -2972,6 +2972,37 @@ def f(x):
29722972
result = g.apply(f)
29732973
assert_frame_equal(result, expected)
29742974

2975+
def test_apply_preserves_multiindex_columns(self):
2976+
# GH 16231
2977+
# the original failing case
2978+
cols = pd.MultiIndex.from_tuples([('A', 'a', '', 'one'),
2979+
('B', 'b', 'i', 'two')])
2980+
ind = pd.DatetimeIndex(start='2017-01-01', freq='15Min', periods=8)
2981+
df = pd.DataFrame(np.random.randn(8, 2), index=ind, columns=cols)
2982+
2983+
agg_dict = {col: (np.sum if col[3] == 'one' else np.mean)
2984+
for col in df.columns}
2985+
resampled = df.resample('H').apply(lambda x: agg_dict[x.name](x))
2986+
assert isinstance(resampled.columns, pd.MultiIndex)
2987+
2988+
@pytest.mark.parametrize('nlevel', range(1, 6))
2989+
@pytest.mark.parametrize('ncol', [1, 2])
2990+
@pytest.mark.parametrize('freq', ['D', '360Min'])
2991+
def test_apply_preserves_multiindex_columns_grid(self, nlevel, ncol, freq):
2992+
# GH 16231
2993+
cols = pd.MultiIndex.from_tuples([[i] * nlevel for i in range(ncol)],
2994+
names=['lev_{}'.format(lev)
2995+
for lev in range(nlevel)])
2996+
idx = pd.date_range('2000-01-01', freq="H", periods=50)
2997+
df = pd.DataFrame(np.random.randn(len(idx), len(cols)),
2998+
columns=cols, index=idx)
2999+
3000+
resampled = df.resample(freq)
3001+
3002+
via_direct = resampled.sum()
3003+
via_apply = resampled.apply(lambda x: x.sum())
3004+
tm.assert_frame_equal(via_direct, via_apply)
3005+
29753006
def test_resample_groupby_with_label(self):
29763007
# GH 13235
29773008
index = date_range('2000-01-01', freq='2D', periods=5)

0 commit comments

Comments
 (0)