Skip to content

Commit f57770c

Browse files
committed
ENH: first cut at margins for pivot_table. testing still needed, #114
1 parent 6a079a0 commit f57770c

File tree

3 files changed

+82
-4
lines changed

3 files changed

+82
-4
lines changed

pandas/core/index.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,6 +1202,26 @@ def swaplevel(self, i, j):
12021202
return MultiIndex(levels=new_levels, labels=new_labels,
12031203
names=new_names)
12041204

1205+
def reorder_levels(self, order):
    """
    Rearrange levels using input order. May not drop or duplicate levels

    Parameters
    ----------
    order : sequence of int
        Positions of the existing levels, listed in the order they should
        appear in the result. Must be a permutation of
        ``range(self.nlevels)``.

    Returns
    -------
    MultiIndex
        New index built from this index's levels, labels and names,
        rearranged according to ``order``.

    Raises
    ------
    ValueError
        If ``order`` drops, duplicates or references a non-existent level.
    """
    order = list(order)

    # A bare set comparison would accept duplicates such as [0, 0, 1], so
    # the length is checked as well. An explicit raise (rather than an
    # assert) keeps the validation active under ``python -O``.
    is_permutation = (len(order) == self.nlevels and
                      set(order) == set(range(self.nlevels)))
    if not is_permutation:
        raise ValueError('New order must be permutation of range(%d)' %
                         self.nlevels)

    new_levels = [self.levels[i] for i in order]
    new_labels = [self.labels[i] for i in order]
    new_names = [self.names[i] for i in order]

    return MultiIndex(levels=new_levels, labels=new_labels,
                      names=new_names)
1224+
12051225
def __getslice__(self, i, j):
    """Handle ``self[i:j]`` by forwarding a slice object to ``__getitem__``."""
    bounds = slice(i, j)
    return self.__getitem__(bounds)
12071227

pandas/tools/pivot.py

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33

44
def pivot_table(data, values=None, rows=None, cols=None, aggfunc=np.mean,
5-
fill_value=None):
5+
fill_value=None, margins=False):
66
"""
77
Create a spreadsheet-style pivot table as a DataFrame. The levels in the
88
pivot table will be stored in MultiIndex objects (hierarchical indexes) on
@@ -19,6 +19,8 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc=np.mean,
1919
aggfunc : function, default numpy.mean
2020
fill_value : scalar, default None
2121
Value to replace missing values with
22+
margins : boolean, default False
23+
Add all row / columns (e.g. for subtotal / grand totals)
2224
2325
Examples
2426
--------
@@ -59,15 +61,14 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc=np.mean,
5961
else:
6062
values_multi = False
6163
values = [values]
64+
else:
65+
values = list(data.columns.drop(keys))
6266

6367
if values_passed:
6468
data = data[keys + values]
6569

6670
grouped = data.groupby(keys)
6771

68-
if values_passed and not values_multi:
69-
grouped = grouped[values[0]]
70-
7172
agged = grouped.agg(aggfunc)
7273

7374
table = agged
@@ -77,10 +78,61 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc=np.mean,
7778
if fill_value is not None:
7879
table = table.fillna(value=fill_value)
7980

81+
if margins:
82+
table = _add_margins(table, data, values, rows=rows,
83+
cols=cols, aggfunc=aggfunc)
84+
85+
# discard the top level
86+
if values_passed and not values_multi:
87+
table = table[values[0]]
88+
8089
return table
8190

8291
DataFrame.pivot_table = pivot_table
8392

93+
def _add_margins(table, data, values, rows=None, cols=None, aggfunc=np.mean):
    """
    Add 'All' margin columns and an 'All' margin row to an aggregated table.

    Parameters
    ----------
    table : DataFrame
        Aggregated table; its columns are grouped on their first level
    data : DataFrame
        Data that is re-aggregated here to compute the margins
    values : list
        Column names in `data` that are aggregated
    rows : list, default None
        Column names in `data` used as row keys. If None, no 'All'
        columns are added
    cols : list, default None
        Column names in `data` used as column keys. If None, no 'All'
        row is appended
    aggfunc : function, default numpy.mean

    Returns
    -------
    DataFrame
        `table` itself when both `rows` and `cols` are None, otherwise a
        new frame with the margins added
    """
    # None is treated like an empty key list when padding the 'All' labels,
    # so a missing rows/cols argument does not raise TypeError in len().
    n_rows = 0 if rows is None else len(rows)
    n_cols = 0 if cols is None else len(cols)

    if rows is not None:
        col_margin = data[rows + values].groupby(rows).agg(aggfunc)

        # need to "interleave" the margins: every first-level column group
        # gets its own 'All' column, padded with '' for remaining levels
        table_pieces = []
        margin_keys = []
        for key, piece in table.groupby(level=0, axis=1):
            all_key = (key, 'All') + ('',) * (n_cols - 1)
            piece[all_key] = col_margin[key]
            table_pieces.append(piece)
            margin_keys.append(all_key)

        result = table_pieces[0]
        for piece in table_pieces[1:]:
            result = result.join(piece)
    else:
        result = table
        margin_keys = []

    grand_margin = data[values].apply(aggfunc)

    if cols is not None:
        row_margin = data[cols + values].groupby(cols).agg(aggfunc)
        row_margin = row_margin.stack()

        # slight hack: stack() leaves the value names as the last index
        # level; move that level to the front before reindexing against
        # result.columns. list(range(...)) keeps the concatenation valid on
        # Python 3, where range() is not a list.
        new_order = [n_cols] + list(range(n_cols))
        row_margin.index = row_margin.index.reorder_levels(new_order)

        key = ('All',) + ('',) * (n_rows - 1)

        row_margin = row_margin.reindex(result.columns)
        # populate grand margin
        for k in margin_keys:
            row_margin[k] = grand_margin[k[0]]

        margin_dummy = DataFrame(row_margin, columns=[key]).T
        result = result.append(margin_dummy)

    return result
135+
84136
def _convert_by(by):
85137
if by is None:
86138
by = []

pandas/tools/tests/test_pivot.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ def test_pivot_table(self):
2727
cols= 'C'
2828
table = pivot_table(self.data, values='D', rows=rows, cols=cols)
2929

30+
table2 = self.data.pivot_table(values='D', rows=rows, cols=cols)
31+
assert_frame_equal(table, table2)
32+
3033
# this works
3134
pivot_table(self.data, values='D', rows=rows)
3235

@@ -57,6 +60,9 @@ def test_pivot_multi_values(self):
5760
rows='A', cols=['B', 'C'], fill_value=0)
5861
assert_frame_equal(result, expected)
5962

63+
def test_margins(self):
    # TODO: placeholder -- pivot_table(margins=True) has no assertions yet
    pass
65+
6066
if __name__ == '__main__':
6167
import nose
6268
nose.runmodule(argv=[__file__,'-vvs','-x','--pdb', '--pdb-failure'],

0 commit comments

Comments
 (0)