Skip to content

Commit a99c057

Browse files
committed
ENH: can pass list of functions to pivot_table
1 parent 3440a9c commit a99c057

File tree

4 files changed

+59
-19
lines changed

4 files changed

+59
-19
lines changed

pandas/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,5 @@
2929
from pandas.io.pytables import HDFStore
3030
from pandas.util.testing import debug
3131

32-
from pandas.tools.pivot import pivot_table
3332
from pandas.tools.merge import merge, concat
33+
from pandas.tools.pivot import pivot_table

pandas/tools/merge.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -934,25 +934,31 @@ def _make_concat_multiindex(indexes, keys, levels=None, names=None):
934934

935935
new_index = indexes[0]
936936
n = len(new_index)
937+
kpieces = len(indexes)
937938

938939
# also copies
939-
names = names + [indexes[0].name]
940+
new_names = list(names)
941+
new_levels = list(levels)
940942

941-
new_levels = levels
943+
# construct labels
944+
new_labels = []
942945

943946
# do something a bit more speedy
944-
new_levels.append(new_index)
945-
946-
# construct labels
947-
labels = []
948947

949-
for hlevel, level in zip(zipped, new_levels[:-1]):
948+
for hlevel, level in zip(zipped, levels):
950949
mapped = level.get_indexer(hlevel)
951-
labels.append(np.repeat(mapped, n))
950+
new_labels.append(np.repeat(mapped, n))
951+
952+
if isinstance(new_index, MultiIndex):
953+
new_levels.extend(new_index.levels)
954+
new_labels.extend([np.tile(lab, kpieces) for lab in new_index.labels])
955+
new_names.extend(new_index.names)
956+
else:
957+
new_levels.append(new_index)
958+
new_names.append(new_index.name)
959+
new_labels.append(np.tile(np.arange(n), kpieces))
952960

953-
# last labels for the new level
954-
labels.append(np.tile(np.arange(n), len(indexes)))
955-
return MultiIndex(levels=new_levels, labels=labels, names=names)
961+
return MultiIndex(levels=new_levels, labels=new_labels, names=new_names)
956962

957963
def _get_consensus_names(indexes):
958964
consensus_name = indexes[0].names

pandas/tools/pivot.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
# pylint: disable=E1103
2+
13
from pandas import Series, DataFrame
4+
from pandas.tools.merge import concat
25
import numpy as np
36

47
def pivot_table(data, values=None, rows=None, cols=None, aggfunc=np.mean,
@@ -16,7 +19,10 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc=np.mean,
1619
Columns to group on the x-axis of the pivot table
1720
cols : list
1821
Columns to group on the x-axis of the pivot table
19-
aggfunc : function, default numpy.mean
22+
aggfunc : function, default numpy.mean, or list of functions
23+
If list of functions passed, the resulting pivot table will have
24+
hierarchical columns whose top level are the function names (inferred
25+
from the function objects themselves)
2026
fill_value : scalar, default None
2127
Value to replace missing values with
2228
margins : boolean, default False
@@ -52,6 +58,17 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc=np.mean,
5258
rows = _convert_by(rows)
5359
cols = _convert_by(cols)
5460

61+
if isinstance(aggfunc, list):
62+
pieces = []
63+
keys = []
64+
for func in aggfunc:
65+
table = pivot_table(data, values=values, rows=rows, cols=cols,
66+
fill_value=fill_value, aggfunc=func,
67+
margins=margins)
68+
pieces.append(table)
69+
keys.append(func.__name__)
70+
return concat(pieces, keys=keys, axis=1)
71+
5572
keys = rows + cols
5673

5774
values_passed = values is not None
@@ -68,7 +85,6 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc=np.mean,
6885
data = data[keys + values]
6986

7087
grouped = data.groupby(keys)
71-
7288
agged = grouped.agg(aggfunc)
7389

7490
table = agged
@@ -95,7 +111,6 @@ def _add_margins(table, data, values, rows=None, cols=None, aggfunc=np.mean):
95111
col_margin = data[rows + values].groupby(rows).agg(aggfunc)
96112

97113
# need to "interleave" the margins
98-
99114
table_pieces = []
100115
margin_keys = []
101116
for key, piece in table.groupby(level=0, axis=1):
@@ -104,9 +119,7 @@ def _add_margins(table, data, values, rows=None, cols=None, aggfunc=np.mean):
104119
table_pieces.append(piece)
105120
margin_keys.append(all_key)
106121

107-
result = table_pieces[0]
108-
for piece in table_pieces[1:]:
109-
result = result.join(piece)
122+
result = concat(table_pieces, axis=1)
110123
else:
111124
result = table
112125
margin_keys = table.columns

pandas/tools/tests/test_pivot.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import numpy as np
44

5-
from pandas import DataFrame
5+
from pandas import DataFrame, concat
66
from pandas.tools.pivot import pivot_table
77
import pandas.util.testing as tm
88

@@ -60,6 +60,27 @@ def test_pivot_multi_values(self):
6060
rows='A', cols=['B', 'C'], fill_value=0)
6161
tm.assert_frame_equal(result, expected)
6262

63+
def test_pivot_multi_functions(self):
64+
f = lambda func: pivot_table(self.data, values=['D', 'E'],
65+
rows=['A', 'B'], cols='C',
66+
aggfunc=func)
67+
result = f([np.mean, np.std])
68+
means = f(np.mean)
69+
stds = f(np.std)
70+
expected = concat([means, stds], keys=['mean', 'std'], axis=1)
71+
tm.assert_frame_equal(result, expected)
72+
73+
# margins not supported??
74+
f = lambda func: pivot_table(self.data, values=['D', 'E'],
75+
rows=['A', 'B'], cols='C',
76+
aggfunc=func, margins=True)
77+
result = f([np.mean, np.std])
78+
means = f(np.mean)
79+
stds = f(np.std)
80+
expected = concat([means, stds], keys=['mean', 'std'], axis=1)
81+
tm.assert_frame_equal(result, expected)
82+
83+
6384
def test_margins(self):
6485
def _check_output(res, col, rows=['A', 'B'], cols=['C']):
6586
cmarg = res['All'][:-1]

0 commit comments

Comments
 (0)