Skip to content

Commit 618110c

Browse files
committed
ENH: #3335 Pivot table support for setting name of margins column.
Ref #3335. Adds a margins_column parameter to pivot_table so that the user can set the name of the margins row/column to something other than 'All'. Raises a ValueError if the value of margins_column conflicts with one of the values already appearing in the index or column levels of the pivot table.
1 parent 031e3bc commit 618110c

File tree

2 files changed

+149
-73
lines changed

2 files changed

+149
-73
lines changed

pandas/tools/pivot.py

+52-19
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
# pylint: disable=E1103
22

3-
import warnings
43

54
from pandas import Series, DataFrame
65
from pandas.core.index import MultiIndex, Index
76
from pandas.core.groupby import Grouper
87
from pandas.tools.merge import concat
98
from pandas.tools.util import cartesian_product
109
from pandas.compat import range, lrange, zip
11-
from pandas.util.decorators import deprecate_kwarg
1210
from pandas import compat
1311
import pandas.core.common as com
1412
import numpy as np
1513

14+
DEFAULT_MARGIN_COLUMN_NAME = 'All'
15+
16+
1617
def pivot_table(data, values=None, index=None, columns=None, aggfunc='mean',
17-
fill_value=None, margins=False, dropna=True):
18+
fill_value=None, margins=False, dropna=True,
19+
margins_column=DEFAULT_MARGIN_COLUMN_NAME):
1820
"""
1921
Create a spreadsheet-style pivot table as a DataFrame. The levels in the
2022
pivot table will be stored in MultiIndex objects (hierarchical indexes) on
@@ -40,6 +42,9 @@ def pivot_table(data, values=None, index=None, columns=None, aggfunc='mean',
4042
Add all row / columns (e.g. for subtotal / grand totals)
4143
dropna : boolean, default True
4244
Do not include columns whose entries are all NaN
45+
margins_column : string, default 'All'
46+
Name of the row / column that will contain the totals
47+
when margins is True.
4348
4449
Examples
4550
--------
@@ -127,7 +132,7 @@ def pivot_table(data, values=None, index=None, columns=None, aggfunc='mean',
127132
m = MultiIndex.from_arrays(cartesian_product(table.columns.levels))
128133
table = table.reindex_axis(m, axis=1)
129134
except AttributeError:
130-
pass # it's a single level or a series
135+
pass # it's a single level or a series
131136

132137
if isinstance(table, DataFrame):
133138
if isinstance(table.columns, MultiIndex):
@@ -140,7 +145,8 @@ def pivot_table(data, values=None, index=None, columns=None, aggfunc='mean',
140145

141146
if margins:
142147
table = _add_margins(table, data, values, rows=index,
143-
cols=columns, aggfunc=aggfunc)
148+
cols=columns, aggfunc=aggfunc,
149+
margins_column=margins_column)
144150

145151
# discard the top level
146152
if values_passed and not values_multi:
@@ -155,28 +161,50 @@ def pivot_table(data, values=None, index=None, columns=None, aggfunc='mean',
155161
DataFrame.pivot_table = pivot_table
156162

157163

158-
def _add_margins(table, data, values, rows, cols, aggfunc):
164+
def _add_margins(table, data, values, rows, cols, aggfunc,
165+
margins_column=DEFAULT_MARGIN_COLUMN_NAME):
166+
exception_message = 'Must choose different value for margins_column'
167+
for level in table.index.names:
168+
if margins_column in table.index.get_level_values(level):
169+
raise ValueError(exception_message)
170+
# could be passed a Series object with no 'columns'
171+
if hasattr(table, 'columns'):
172+
for level in table.columns.names[1:]:
173+
if margins_column in table.columns.get_level_values(level):
174+
raise ValueError(exception_message)
159175

160-
grand_margin = _compute_grand_margin(data, values, aggfunc)
176+
grand_margin = _compute_grand_margin(data, values, aggfunc, margins_column)
161177

162178
if not values and isinstance(table, Series):
163179
# If there are no values and the table is a series, then there is only
164180
# one column in the data. Compute grand margin and return it.
165-
row_key = ('All',) + ('',) * (len(rows) - 1) if len(rows) > 1 else 'All'
166-
return table.append(Series({row_key: grand_margin['All']}))
181+
182+
if len(rows) > 1:
183+
row_key = (margins_column,) + ('',) * (len(rows) - 1)
184+
else:
185+
row_key = margins_column
186+
187+
return table.append(Series({row_key: grand_margin[margins_column]}))
167188

168189
if values:
169-
marginal_result_set = _generate_marginal_results(table, data, values, rows, cols, aggfunc, grand_margin)
190+
marginal_result_set = _generate_marginal_results(table, data, values,
191+
rows, cols, aggfunc,
192+
grand_margin,
193+
margins_column)
170194
if not isinstance(marginal_result_set, tuple):
171195
return marginal_result_set
172196
result, margin_keys, row_margin = marginal_result_set
173197
else:
174-
marginal_result_set = _generate_marginal_results_without_values(table, data, rows, cols, aggfunc)
198+
marginal_result_set = _generate_marginal_results_without_values(
199+
table, data, rows, cols, aggfunc, margins_column)
175200
if not isinstance(marginal_result_set, tuple):
176201
return marginal_result_set
177202
result, margin_keys, row_margin = marginal_result_set
178203

179-
key = ('All',) + ('',) * (len(rows) - 1) if len(rows) > 1 else 'All'
204+
if len(rows) > 1:
205+
key = (margins_column,) + ('',) * (len(rows) - 1)
206+
else:
207+
key = margins_column
180208

181209
row_margin = row_margin.reindex(result.columns)
182210
# populate grand margin
@@ -195,7 +223,8 @@ def _add_margins(table, data, values, rows, cols, aggfunc):
195223
return result
196224

197225

198-
def _compute_grand_margin(data, values, aggfunc):
226+
def _compute_grand_margin(data, values, aggfunc,
227+
margins_column=DEFAULT_MARGIN_COLUMN_NAME):
199228

200229
if values:
201230
grand_margin = {}
@@ -214,17 +243,19 @@ def _compute_grand_margin(data, values, aggfunc):
214243
pass
215244
return grand_margin
216245
else:
217-
return {'All': aggfunc(data.index)}
246+
return {margins_column: aggfunc(data.index)}
218247

219248

220-
def _generate_marginal_results(table, data, values, rows, cols, aggfunc, grand_margin):
249+
def _generate_marginal_results(table, data, values, rows, cols, aggfunc,
250+
grand_margin,
251+
margins_column=DEFAULT_MARGIN_COLUMN_NAME):
221252
if len(cols) > 0:
222253
# need to "interleave" the margins
223254
table_pieces = []
224255
margin_keys = []
225256

226257
def _all_key(key):
227-
return (key, 'All') + ('',) * (len(cols) - 1)
258+
return (key, margins_column) + ('',) * (len(cols) - 1)
228259

229260
if len(rows) > 0:
230261
margin = data[rows + values].groupby(rows).agg(aggfunc)
@@ -269,15 +300,17 @@ def _all_key(key):
269300
return result, margin_keys, row_margin
270301

271302

272-
def _generate_marginal_results_without_values(table, data, rows, cols, aggfunc):
303+
def _generate_marginal_results_without_values(
304+
table, data, rows, cols, aggfunc,
305+
margins_column=DEFAULT_MARGIN_COLUMN_NAME):
273306
if len(cols) > 0:
274307
# need to "interleave" the margins
275308
margin_keys = []
276309

277310
def _all_key():
278311
if len(cols) == 1:
279-
return 'All'
280-
return ('All', ) + ('', ) * (len(cols) - 1)
312+
return margins_column
313+
return (margins_column, ) + ('', ) * (len(cols) - 1)
281314

282315
if len(rows) > 0:
283316
margin = data[rows].groupby(rows).apply(aggfunc)

pandas/tools/tests/test_pivot.py

+97-54
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pandas import DataFrame, Series, Index, MultiIndex, Grouper
88
from pandas.tools.merge import concat
99
from pandas.tools.pivot import pivot_table, crosstab
10+
from pandas.tools.pivot import DEFAULT_MARGIN_COLUMN_NAME
1011
from pandas.compat import range, u, product
1112
import pandas.util.testing as tm
1213

@@ -224,82 +225,106 @@ def test_pivot_with_tz(self):
224225
tm.assert_frame_equal(pv, expected)
225226

226227
def test_margins(self):
227-
def _check_output(res, col, index=['A', 'B'], columns=['C']):
228-
cmarg = res['All'][:-1]
229-
exp = self.data.groupby(index)[col].mean()
230-
tm.assert_series_equal(cmarg, exp, check_names=False)
231-
self.assertEqual(cmarg.name, 'All')
232-
233-
res = res.sortlevel()
234-
rmarg = res.xs(('All', ''))[:-1]
235-
exp = self.data.groupby(columns)[col].mean()
236-
tm.assert_series_equal(rmarg, exp, check_names=False)
237-
self.assertEqual(rmarg.name, ('All', ''))
238-
239-
gmarg = res['All']['All', '']
240-
exp = self.data[col].mean()
241-
self.assertEqual(gmarg, exp)
228+
def _check_output(result, values_col, index=['A', 'B'],
229+
columns=['C'],
230+
margins_col=DEFAULT_MARGIN_COLUMN_NAME):
231+
col_margins = result.ix[:-1, margins_col]
232+
expected_col_margins = self.data.groupby(index)[values_col].mean()
233+
tm.assert_series_equal(col_margins, expected_col_margins,
234+
check_names=False)
235+
self.assertEqual(col_margins.name, margins_col)
236+
237+
result = result.sortlevel()
238+
index_margins = result.ix[(margins_col, '')].iloc[:-1]
239+
expected_ix_margins = self.data.groupby(columns)[values_col].mean()
240+
tm.assert_series_equal(index_margins, expected_ix_margins,
241+
check_names=False)
242+
self.assertEqual(index_margins.name, (margins_col, ''))
243+
244+
grand_total_margins = result.loc[(margins_col, ''), margins_col]
245+
expected_total_margins = self.data[values_col].mean()
246+
self.assertEqual(grand_total_margins, expected_total_margins)
242247

243248
# column specified
244-
table = self.data.pivot_table('D', index=['A', 'B'], columns='C',
245-
margins=True, aggfunc=np.mean)
246-
_check_output(table, 'D')
249+
result = self.data.pivot_table(values='D', index=['A', 'B'],
250+
columns='C',
251+
margins=True, aggfunc=np.mean)
252+
_check_output(result, 'D')
253+
254+
# Set a different margins_column (not 'All')
255+
result = self.data.pivot_table(values='D', index=['A', 'B'],
256+
columns='C',
257+
margins=True, aggfunc=np.mean,
258+
margins_column='Totals')
259+
_check_output(result, 'D', margins_col='Totals')
247260

248261
# no column specified
249262
table = self.data.pivot_table(index=['A', 'B'], columns='C',
250263
margins=True, aggfunc=np.mean)
251-
for valcol in table.columns.levels[0]:
252-
_check_output(table[valcol], valcol)
264+
for value_col in table.columns.levels[0]:
265+
_check_output(table[value_col], value_col)
253266

254267
# no col
255268

256269
# to help with a buglet
257270
self.data.columns = [k * 2 for k in self.data.columns]
258271
table = self.data.pivot_table(index=['AA', 'BB'], margins=True,
259272
aggfunc=np.mean)
260-
for valcol in table.columns:
261-
gmarg = table[valcol]['All', '']
262-
self.assertEqual(gmarg, self.data[valcol].mean())
263-
264-
# this is OK
265-
table = self.data.pivot_table(index=['AA', 'BB'], margins=True,
266-
aggfunc='mean')
273+
for value_col in table.columns:
274+
totals = table.loc[(DEFAULT_MARGIN_COLUMN_NAME, ''), value_col]
275+
self.assertEqual(totals, self.data[value_col].mean())
267276

268277
# no rows
269278
rtable = self.data.pivot_table(columns=['AA', 'BB'], margins=True,
270279
aggfunc=np.mean)
271280
tm.assert_isinstance(rtable, Series)
281+
282+
table = self.data.pivot_table(index=['AA', 'BB'], margins=True,
283+
aggfunc='mean')
272284
for item in ['DD', 'EE', 'FF']:
273-
gmarg = table[item]['All', '']
274-
self.assertEqual(gmarg, self.data[item].mean())
285+
totals = table.loc[(DEFAULT_MARGIN_COLUMN_NAME, ''), item]
286+
self.assertEqual(totals, self.data[item].mean())
275287

276288
# issue number #8349: pivot_table with margins and dictionary aggfunc
289+
data = [
290+
{'JOB': 'Worker', 'NAME': 'Bob', 'YEAR': 2013,
291+
'MONTH': 12, 'DAYS': 3, 'SALARY': 17},
292+
{'JOB': 'Employ', 'NAME':
293+
'Mary', 'YEAR': 2013, 'MONTH': 12, 'DAYS': 5, 'SALARY': 23},
294+
{'JOB': 'Worker', 'NAME': 'Bob', 'YEAR': 2014,
295+
'MONTH': 1, 'DAYS': 10, 'SALARY': 100},
296+
{'JOB': 'Worker', 'NAME': 'Bob', 'YEAR': 2014,
297+
'MONTH': 1, 'DAYS': 11, 'SALARY': 110},
298+
{'JOB': 'Employ', 'NAME': 'Mary', 'YEAR': 2014,
299+
'MONTH': 1, 'DAYS': 15, 'SALARY': 200},
300+
{'JOB': 'Worker', 'NAME': 'Bob', 'YEAR': 2014,
301+
'MONTH': 2, 'DAYS': 8, 'SALARY': 80},
302+
{'JOB': 'Employ', 'NAME': 'Mary', 'YEAR': 2014,
303+
'MONTH': 2, 'DAYS': 5, 'SALARY': 190},
304+
]
277305

278-
df=DataFrame([ {'JOB':'Worker','NAME':'Bob' ,'YEAR':2013,'MONTH':12,'DAYS': 3,'SALARY': 17},
279-
{'JOB':'Employ','NAME':'Mary','YEAR':2013,'MONTH':12,'DAYS': 5,'SALARY': 23},
280-
{'JOB':'Worker','NAME':'Bob' ,'YEAR':2014,'MONTH': 1,'DAYS':10,'SALARY':100},
281-
{'JOB':'Worker','NAME':'Bob' ,'YEAR':2014,'MONTH': 1,'DAYS':11,'SALARY':110},
282-
{'JOB':'Employ','NAME':'Mary','YEAR':2014,'MONTH': 1,'DAYS':15,'SALARY':200},
283-
{'JOB':'Worker','NAME':'Bob' ,'YEAR':2014,'MONTH': 2,'DAYS': 8,'SALARY': 80},
284-
{'JOB':'Employ','NAME':'Mary','YEAR':2014,'MONTH': 2,'DAYS': 5,'SALARY':190} ])
285-
286-
df=df.set_index(['JOB','NAME','YEAR','MONTH'],drop=False,append=False)
287-
288-
rs=df.pivot_table( index=['JOB','NAME'],
289-
columns=['YEAR','MONTH'],
290-
values=['DAYS','SALARY'],
291-
aggfunc={'DAYS':'mean','SALARY':'sum'},
292-
margins=True)
306+
df = DataFrame(data)
293307

294-
ex=df.pivot_table(index=['JOB','NAME'],columns=['YEAR','MONTH'],values=['DAYS'],aggfunc='mean',margins=True)
308+
df = df.set_index(['JOB', 'NAME', 'YEAR', 'MONTH'], drop=False,
309+
append=False)
295310

296-
tm.assert_frame_equal(rs['DAYS'], ex['DAYS'])
311+
result = df.pivot_table(index=['JOB', 'NAME'],
312+
columns=['YEAR', 'MONTH'],
313+
values=['DAYS', 'SALARY'],
314+
aggfunc={'DAYS': 'mean', 'SALARY': 'sum'},
315+
margins=True)
297316

298-
ex=df.pivot_table(index=['JOB','NAME'],columns=['YEAR','MONTH'],values=['SALARY'],aggfunc='sum',margins=True)
317+
expected = df.pivot_table(index=['JOB', 'NAME'],
318+
columns=['YEAR', 'MONTH'], values=['DAYS'],
319+
aggfunc='mean', margins=True)
299320

300-
tm.assert_frame_equal(rs['SALARY'], ex['SALARY'])
321+
tm.assert_frame_equal(result['DAYS'], expected['DAYS'])
301322

323+
expected = df.pivot_table(index=['JOB', 'NAME'],
324+
columns=['YEAR', 'MONTH'], values=['SALARY'],
325+
aggfunc='sum', margins=True)
302326

327+
tm.assert_frame_equal(result['SALARY'], expected['SALARY'])
303328

304329
def test_pivot_integer_columns(self):
305330
# caused by upstream bug in unstack
@@ -402,6 +427,24 @@ def test_margins_no_values_two_row_two_cols(self):
402427
result = self.data[['A', 'B', 'C', 'D']].pivot_table(index=['A', 'B'], columns=['C', 'D'], aggfunc=len, margins=True)
403428
self.assertEqual(result.All.tolist(), [3.0, 1.0, 4.0, 3.0, 11.0])
404429

430+
def test_pivot_table_with_margins_set_margin_column(self):
431+
for margin_column in ['foo', 'one']:
432+
with self.assertRaises(ValueError):
433+
# multi-index index
434+
pivot_table(self.data, values='D', index=['A', 'B'],
435+
columns=['C'], margins=True,
436+
margins_column=margin_column)
437+
with self.assertRaises(ValueError):
438+
# multi-index column
439+
pivot_table(self.data, values='D', index=['C'],
440+
columns=['A', 'B'], margins=True,
441+
margins_column=margin_column)
442+
with self.assertRaises(ValueError):
443+
# non-multi-index index/column
444+
pivot_table(self.data, values='D', index=['A'],
445+
columns=['B'], margins=True,
446+
margins_column=margin_column)
447+
405448
def test_pivot_timegrouper(self):
406449
df = DataFrame({
407450
'Branch' : 'A A A A A A A B'.split(),
@@ -678,17 +721,17 @@ def test_crosstab_margins(self):
678721
self.assertEqual(result.index.names, ('a',))
679722
self.assertEqual(result.columns.names, ['b', 'c'])
680723

681-
all_cols = result['All', '']
724+
all_cols = result[DEFAULT_MARGIN_COLUMN_NAME, '']
682725
exp_cols = df.groupby(['a']).size().astype('i8')
683-
exp_cols = exp_cols.append(Series([len(df)], index=['All']))
684-
exp_cols.name = ('All', '')
726+
exp_cols = exp_cols.append(Series([len(df)], index=[DEFAULT_MARGIN_COLUMN_NAME]))
727+
exp_cols.name = (DEFAULT_MARGIN_COLUMN_NAME, '')
685728

686729
tm.assert_series_equal(all_cols, exp_cols)
687730

688-
all_rows = result.ix['All']
731+
all_rows = result.ix[DEFAULT_MARGIN_COLUMN_NAME]
689732
exp_rows = df.groupby(['b', 'c']).size().astype('i8')
690-
exp_rows = exp_rows.append(Series([len(df)], index=[('All', '')]))
691-
exp_rows.name = 'All'
733+
exp_rows = exp_rows.append(Series([len(df)], index=[(DEFAULT_MARGIN_COLUMN_NAME, '')]))
734+
exp_rows.name = DEFAULT_MARGIN_COLUMN_NAME
692735

693736
exp_rows = exp_rows.reindex(all_rows.index)
694737
exp_rows = exp_rows.fillna(0).astype(np.int64)

0 commit comments

Comments
 (0)