Skip to content

Commit 10fe47e

Browse files
committed
Merge pull request #11581 from lexual/issue_3335_pivot_handle_all_for_margins
ENH: #3335 Pivot table support for setting name of margins column.
2 parents 96c1f63 + 1ca006c commit 10fe47e

File tree

3 files changed: +142 additions, -69 deletions (lines changed)

doc/source/whatsnew/v0.17.1.txt

+2
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,8 @@ Enhancements
4646

4747
pd.Index([1, np.nan, 3]).fillna(2)
4848

49+
- ``pivot_table`` now has a ``margins_name`` argument so you can use something other than the default of 'All' (:issue:`3335`)
50+
4951
.. _whatsnew_0171.api:
5052

5153
API changes

pandas/tools/pivot.py

+49-21
Original file line number | Diff line number | Diff line change
@@ -1,20 +1,20 @@
11
# pylint: disable=E1103
22

3-
import warnings
43

54
from pandas import Series, DataFrame
65
from pandas.core.index import MultiIndex, Index
76
from pandas.core.groupby import Grouper
87
from pandas.tools.merge import concat
98
from pandas.tools.util import cartesian_product
109
from pandas.compat import range, lrange, zip
11-
from pandas.util.decorators import deprecate_kwarg
1210
from pandas import compat
1311
import pandas.core.common as com
1412
import numpy as np
1513

14+
1615
def pivot_table(data, values=None, index=None, columns=None, aggfunc='mean',
17-
fill_value=None, margins=False, dropna=True):
16+
fill_value=None, margins=False, dropna=True,
17+
margins_name='All'):
1818
"""
1919
Create a spreadsheet-style pivot table as a DataFrame. The levels in the
2020
pivot table will be stored in MultiIndex objects (hierarchical indexes) on
@@ -40,6 +40,9 @@ def pivot_table(data, values=None, index=None, columns=None, aggfunc='mean',
4040
Add all row / columns (e.g. for subtotal / grand totals)
4141
dropna : boolean, default True
4242
Do not include columns whose entries are all NaN
43+
margins_name : string, default 'All'
44+
Name of the row / column that will contain the totals
45+
when margins is True.
4346
4447
Examples
4548
--------
@@ -127,7 +130,7 @@ def pivot_table(data, values=None, index=None, columns=None, aggfunc='mean',
127130
m = MultiIndex.from_arrays(cartesian_product(table.columns.levels))
128131
table = table.reindex_axis(m, axis=1)
129132
except AttributeError:
130-
pass # it's a single level or a series
133+
pass # it's a single level or a series
131134

132135
if isinstance(table, DataFrame):
133136
if isinstance(table.columns, MultiIndex):
@@ -140,7 +143,8 @@ def pivot_table(data, values=None, index=None, columns=None, aggfunc='mean',
140143

141144
if margins:
142145
table = _add_margins(table, data, values, rows=index,
143-
cols=columns, aggfunc=aggfunc)
146+
cols=columns, aggfunc=aggfunc,
147+
margins_name=margins_name)
144148

145149
# discard the top level
146150
if values_passed and not values_multi:
@@ -155,29 +159,49 @@ def pivot_table(data, values=None, index=None, columns=None, aggfunc='mean',
155159
DataFrame.pivot_table = pivot_table
156160

157161

158-
def _add_margins(table, data, values, rows, cols, aggfunc):
162+
def _add_margins(table, data, values, rows, cols, aggfunc,
163+
margins_name='All'):
164+
if not isinstance(margins_name, compat.string_types):
165+
raise ValueError('margins_name argument must be a string')
166+
167+
exception_msg = 'Conflicting name "{0}" in margins'.format(margins_name)
168+
for level in table.index.names:
169+
if margins_name in table.index.get_level_values(level):
170+
raise ValueError(exception_msg)
159171

160-
grand_margin = _compute_grand_margin(data, values, aggfunc)
172+
grand_margin = _compute_grand_margin(data, values, aggfunc, margins_name)
173+
174+
# could be passed a Series object with no 'columns'
175+
if hasattr(table, 'columns'):
176+
for level in table.columns.names[1:]:
177+
if margins_name in table.columns.get_level_values(level):
178+
raise ValueError(exception_msg)
179+
180+
if len(rows) > 1:
181+
key = (margins_name,) + ('',) * (len(rows) - 1)
182+
else:
183+
key = margins_name
161184

162185
if not values and isinstance(table, Series):
163186
# If there are no values and the table is a series, then there is only
164187
# one column in the data. Compute grand margin and return it.
165-
row_key = ('All',) + ('',) * (len(rows) - 1) if len(rows) > 1 else 'All'
166-
return table.append(Series({row_key: grand_margin['All']}))
188+
return table.append(Series({key: grand_margin[margins_name]}))
167189

168190
if values:
169-
marginal_result_set = _generate_marginal_results(table, data, values, rows, cols, aggfunc, grand_margin)
191+
marginal_result_set = _generate_marginal_results(table, data, values,
192+
rows, cols, aggfunc,
193+
grand_margin,
194+
margins_name)
170195
if not isinstance(marginal_result_set, tuple):
171196
return marginal_result_set
172197
result, margin_keys, row_margin = marginal_result_set
173198
else:
174-
marginal_result_set = _generate_marginal_results_without_values(table, data, rows, cols, aggfunc)
199+
marginal_result_set = _generate_marginal_results_without_values(
200+
table, data, rows, cols, aggfunc, margins_name)
175201
if not isinstance(marginal_result_set, tuple):
176202
return marginal_result_set
177203
result, margin_keys, row_margin = marginal_result_set
178204

179-
key = ('All',) + ('',) * (len(rows) - 1) if len(rows) > 1 else 'All'
180-
181205
row_margin = row_margin.reindex(result.columns)
182206
# populate grand margin
183207
for k in margin_keys:
@@ -201,7 +225,8 @@ def _add_margins(table, data, values, rows, cols, aggfunc):
201225
return result
202226

203227

204-
def _compute_grand_margin(data, values, aggfunc):
228+
def _compute_grand_margin(data, values, aggfunc,
229+
margins_name='All'):
205230

206231
if values:
207232
grand_margin = {}
@@ -220,18 +245,19 @@ def _compute_grand_margin(data, values, aggfunc):
220245
pass
221246
return grand_margin
222247
else:
223-
return {'All': aggfunc(data.index)}
224-
248+
return {margins_name: aggfunc(data.index)}
225249

226-
def _generate_marginal_results(table, data, values, rows, cols, aggfunc, grand_margin):
227250

251+
def _generate_marginal_results(table, data, values, rows, cols, aggfunc,
252+
grand_margin,
253+
margins_name='All'):
228254
if len(cols) > 0:
229255
# need to "interleave" the margins
230256
table_pieces = []
231257
margin_keys = []
232258

233259
def _all_key(key):
234-
return (key, 'All') + ('',) * (len(cols) - 1)
260+
return (key, margins_name) + ('',) * (len(cols) - 1)
235261

236262
if len(rows) > 0:
237263
margin = data[rows + values].groupby(rows).agg(aggfunc)
@@ -282,15 +308,17 @@ def _all_key(key):
282308
return result, margin_keys, row_margin
283309

284310

285-
def _generate_marginal_results_without_values(table, data, rows, cols, aggfunc):
311+
def _generate_marginal_results_without_values(
312+
table, data, rows, cols, aggfunc,
313+
margins_name='All'):
286314
if len(cols) > 0:
287315
# need to "interleave" the margins
288316
margin_keys = []
289317

290318
def _all_key():
291319
if len(cols) == 1:
292-
return 'All'
293-
return ('All', ) + ('', ) * (len(cols) - 1)
320+
return margins_name
321+
return (margins_name, ) + ('', ) * (len(cols) - 1)
294322

295323
if len(rows) > 0:
296324
margin = data[rows].groupby(rows).apply(aggfunc)

pandas/tools/tests/test_pivot.py

+91-48
Original file line number | Diff line number | Diff line change
@@ -224,82 +224,106 @@ def test_pivot_with_tz(self):
224224
tm.assert_frame_equal(pv, expected)
225225

226226
def test_margins(self):
227-
def _check_output(res, col, index=['A', 'B'], columns=['C']):
228-
cmarg = res['All'][:-1]
229-
exp = self.data.groupby(index)[col].mean()
230-
tm.assert_series_equal(cmarg, exp, check_names=False)
231-
self.assertEqual(cmarg.name, 'All')
232-
233-
res = res.sortlevel()
234-
rmarg = res.xs(('All', ''))[:-1]
235-
exp = self.data.groupby(columns)[col].mean()
236-
tm.assert_series_equal(rmarg, exp, check_names=False)
237-
self.assertEqual(rmarg.name, ('All', ''))
238-
239-
gmarg = res['All']['All', '']
240-
exp = self.data[col].mean()
241-
self.assertEqual(gmarg, exp)
227+
def _check_output(result, values_col, index=['A', 'B'],
228+
columns=['C'],
229+
margins_col='All'):
230+
col_margins = result.ix[:-1, margins_col]
231+
expected_col_margins = self.data.groupby(index)[values_col].mean()
232+
tm.assert_series_equal(col_margins, expected_col_margins,
233+
check_names=False)
234+
self.assertEqual(col_margins.name, margins_col)
235+
236+
result = result.sortlevel()
237+
index_margins = result.ix[(margins_col, '')].iloc[:-1]
238+
expected_ix_margins = self.data.groupby(columns)[values_col].mean()
239+
tm.assert_series_equal(index_margins, expected_ix_margins,
240+
check_names=False)
241+
self.assertEqual(index_margins.name, (margins_col, ''))
242+
243+
grand_total_margins = result.loc[(margins_col, ''), margins_col]
244+
expected_total_margins = self.data[values_col].mean()
245+
self.assertEqual(grand_total_margins, expected_total_margins)
242246

243247
# column specified
244-
table = self.data.pivot_table('D', index=['A', 'B'], columns='C',
245-
margins=True, aggfunc=np.mean)
246-
_check_output(table, 'D')
248+
result = self.data.pivot_table(values='D', index=['A', 'B'],
249+
columns='C',
250+
margins=True, aggfunc=np.mean)
251+
_check_output(result, 'D')
252+
253+
# Set a different margins_name (not 'All')
254+
result = self.data.pivot_table(values='D', index=['A', 'B'],
255+
columns='C',
256+
margins=True, aggfunc=np.mean,
257+
margins_name='Totals')
258+
_check_output(result, 'D', margins_col='Totals')
247259

248260
# no column specified
249261
table = self.data.pivot_table(index=['A', 'B'], columns='C',
250262
margins=True, aggfunc=np.mean)
251-
for valcol in table.columns.levels[0]:
252-
_check_output(table[valcol], valcol)
263+
for value_col in table.columns.levels[0]:
264+
_check_output(table[value_col], value_col)
253265

254266
# no col
255267

256268
# to help with a buglet
257269
self.data.columns = [k * 2 for k in self.data.columns]
258270
table = self.data.pivot_table(index=['AA', 'BB'], margins=True,
259271
aggfunc=np.mean)
260-
for valcol in table.columns:
261-
gmarg = table[valcol]['All', '']
262-
self.assertEqual(gmarg, self.data[valcol].mean())
263-
264-
# this is OK
265-
table = self.data.pivot_table(index=['AA', 'BB'], margins=True,
266-
aggfunc='mean')
272+
for value_col in table.columns:
273+
totals = table.loc[('All', ''), value_col]
274+
self.assertEqual(totals, self.data[value_col].mean())
267275

268276
# no rows
269277
rtable = self.data.pivot_table(columns=['AA', 'BB'], margins=True,
270278
aggfunc=np.mean)
271279
tm.assertIsInstance(rtable, Series)
280+
281+
table = self.data.pivot_table(index=['AA', 'BB'], margins=True,
282+
aggfunc='mean')
272283
for item in ['DD', 'EE', 'FF']:
273-
gmarg = table[item]['All', '']
274-
self.assertEqual(gmarg, self.data[item].mean())
284+
totals = table.loc[('All', ''), item]
285+
self.assertEqual(totals, self.data[item].mean())
275286

276287
# issue number #8349: pivot_table with margins and dictionary aggfunc
288+
data = [
289+
{'JOB': 'Worker', 'NAME': 'Bob', 'YEAR': 2013,
290+
'MONTH': 12, 'DAYS': 3, 'SALARY': 17},
291+
{'JOB': 'Employ', 'NAME':
292+
'Mary', 'YEAR': 2013, 'MONTH': 12, 'DAYS': 5, 'SALARY': 23},
293+
{'JOB': 'Worker', 'NAME': 'Bob', 'YEAR': 2014,
294+
'MONTH': 1, 'DAYS': 10, 'SALARY': 100},
295+
{'JOB': 'Worker', 'NAME': 'Bob', 'YEAR': 2014,
296+
'MONTH': 1, 'DAYS': 11, 'SALARY': 110},
297+
{'JOB': 'Employ', 'NAME': 'Mary', 'YEAR': 2014,
298+
'MONTH': 1, 'DAYS': 15, 'SALARY': 200},
299+
{'JOB': 'Worker', 'NAME': 'Bob', 'YEAR': 2014,
300+
'MONTH': 2, 'DAYS': 8, 'SALARY': 80},
301+
{'JOB': 'Employ', 'NAME': 'Mary', 'YEAR': 2014,
302+
'MONTH': 2, 'DAYS': 5, 'SALARY': 190},
303+
]
277304

278-
df=DataFrame([ {'JOB':'Worker','NAME':'Bob' ,'YEAR':2013,'MONTH':12,'DAYS': 3,'SALARY': 17},
279-
{'JOB':'Employ','NAME':'Mary','YEAR':2013,'MONTH':12,'DAYS': 5,'SALARY': 23},
280-
{'JOB':'Worker','NAME':'Bob' ,'YEAR':2014,'MONTH': 1,'DAYS':10,'SALARY':100},
281-
{'JOB':'Worker','NAME':'Bob' ,'YEAR':2014,'MONTH': 1,'DAYS':11,'SALARY':110},
282-
{'JOB':'Employ','NAME':'Mary','YEAR':2014,'MONTH': 1,'DAYS':15,'SALARY':200},
283-
{'JOB':'Worker','NAME':'Bob' ,'YEAR':2014,'MONTH': 2,'DAYS': 8,'SALARY': 80},
284-
{'JOB':'Employ','NAME':'Mary','YEAR':2014,'MONTH': 2,'DAYS': 5,'SALARY':190} ])
285-
286-
df=df.set_index(['JOB','NAME','YEAR','MONTH'],drop=False,append=False)
287-
288-
rs=df.pivot_table( index=['JOB','NAME'],
289-
columns=['YEAR','MONTH'],
290-
values=['DAYS','SALARY'],
291-
aggfunc={'DAYS':'mean','SALARY':'sum'},
292-
margins=True)
305+
df = DataFrame(data)
293306

294-
ex=df.pivot_table(index=['JOB','NAME'],columns=['YEAR','MONTH'],values=['DAYS'],aggfunc='mean',margins=True)
307+
df = df.set_index(['JOB', 'NAME', 'YEAR', 'MONTH'], drop=False,
308+
append=False)
295309

296-
tm.assert_frame_equal(rs['DAYS'], ex['DAYS'])
310+
result = df.pivot_table(index=['JOB', 'NAME'],
311+
columns=['YEAR', 'MONTH'],
312+
values=['DAYS', 'SALARY'],
313+
aggfunc={'DAYS': 'mean', 'SALARY': 'sum'},
314+
margins=True)
297315

298-
ex=df.pivot_table(index=['JOB','NAME'],columns=['YEAR','MONTH'],values=['SALARY'],aggfunc='sum',margins=True)
316+
expected = df.pivot_table(index=['JOB', 'NAME'],
317+
columns=['YEAR', 'MONTH'], values=['DAYS'],
318+
aggfunc='mean', margins=True)
299319

300-
tm.assert_frame_equal(rs['SALARY'], ex['SALARY'])
320+
tm.assert_frame_equal(result['DAYS'], expected['DAYS'])
301321

322+
expected = df.pivot_table(index=['JOB', 'NAME'],
323+
columns=['YEAR', 'MONTH'], values=['SALARY'],
324+
aggfunc='sum', margins=True)
302325

326+
tm.assert_frame_equal(result['SALARY'], expected['SALARY'])
303327

304328
def test_pivot_integer_columns(self):
305329
# caused by upstream bug in unstack
@@ -402,6 +426,25 @@ def test_margins_no_values_two_row_two_cols(self):
402426
result = self.data[['A', 'B', 'C', 'D']].pivot_table(index=['A', 'B'], columns=['C', 'D'], aggfunc=len, margins=True)
403427
self.assertEqual(result.All.tolist(), [3.0, 1.0, 4.0, 3.0, 11.0])
404428

429+
def test_pivot_table_with_margins_set_margin_name(self):
430+
# GH 3335
431+
for margin_name in ['foo', 'one', 666, None, ['a', 'b']]:
432+
with self.assertRaises(ValueError):
433+
# multi-index index
434+
pivot_table(self.data, values='D', index=['A', 'B'],
435+
columns=['C'], margins=True,
436+
margins_name=margin_name)
437+
with self.assertRaises(ValueError):
438+
# multi-index column
439+
pivot_table(self.data, values='D', index=['C'],
440+
columns=['A', 'B'], margins=True,
441+
margins_name=margin_name)
442+
with self.assertRaises(ValueError):
443+
# non-multi-index index/column
444+
pivot_table(self.data, values='D', index=['A'],
445+
columns=['B'], margins=True,
446+
margins_name=margin_name)
447+
405448
def test_pivot_timegrouper(self):
406449
df = DataFrame({
407450
'Branch' : 'A A A A A A A B'.split(),

0 commit comments

Comments
 (0)