Skip to content

Commit e874b4c

Browse files
committed
TST: testing of margins in pivot_table, GH #114
1 parent a0c5090 commit e874b4c

File tree

2 files changed

+70
-43
lines changed

2 files changed

+70
-43
lines changed

pandas/tools/pivot.py

Lines changed: 26 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from pandas import DataFrame
1+
from pandas import Series, DataFrame
22
import numpy as np
33

44
def pivot_table(data, values=None, rows=None, cols=None, aggfunc=np.mean,
@@ -91,14 +91,20 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc=np.mean,
9191
DataFrame.pivot_table = pivot_table
9292

9393
def _add_margins(table, data, values, rows=None, cols=None, aggfunc=np.mean):
94-
if rows is not None:
94+
if len(cols) > 0:
9595
col_margin = data[rows + values].groupby(rows).agg(aggfunc)
9696

9797
# need to "interleave" the margins
9898

9999
table_pieces = []
100100
margin_keys = []
101-
for key, piece in table.groupby(level=0, axis=1):
101+
102+
if len(cols) > 0:
103+
grouper = table.groupby(level=0, axis=1)
104+
else:
105+
grouper = ((k, table[[k]]) for k in table.columns)
106+
107+
for key, piece in grouper:
102108
all_key = (key, 'All') + ('',) * (len(cols) - 1)
103109
piece[all_key] = col_margin[key]
104110
table_pieces.append(piece)
@@ -109,27 +115,34 @@ def _add_margins(table, data, values, rows=None, cols=None, aggfunc=np.mean):
109115
result = result.join(piece)
110116
else:
111117
result = table
112-
margin_keys = []
118+
margin_keys = table.columns
113119

114-
grand_margin = data[values].apply(aggfunc)
120+
grand_margin = {}
121+
for k, v in data[values].iteritems():
122+
try:
123+
grand_margin[k] = aggfunc(v)
124+
except TypeError:
125+
pass
115126

116-
if cols is not None:
127+
if len(cols) > 0:
117128
row_margin = data[cols + values].groupby(cols).agg(aggfunc)
118129
row_margin = row_margin.stack()
119130

120131
# slight hack
121132
new_order = [len(cols)] + range(len(cols))
122133
row_margin.index = row_margin.index.reorder_levels(new_order)
134+
else:
135+
row_margin = Series(np.nan, index=result.columns)
123136

124-
key = ('All',) + ('',) * (len(rows) - 1)
137+
key = ('All',) + ('',) * (len(rows) - 1)
125138

126-
row_margin = row_margin.reindex(result.columns)
127-
# populate grand margin
128-
for k in margin_keys:
129-
row_margin[k] = grand_margin[k[0]]
139+
row_margin = row_margin.reindex(result.columns)
140+
# populate grand margin
141+
for k in margin_keys:
142+
row_margin[k] = grand_margin[k[0]]
130143

131-
margin_dummy = DataFrame(row_margin, columns=[key]).T
132-
result = result.append(margin_dummy)
144+
margin_dummy = DataFrame(row_margin, columns=[key]).T
145+
result = result.append(margin_dummy)
133146

134147
return result
135148

@@ -142,27 +155,3 @@ def _convert_by(by):
142155
by = list(by)
143156
return by
144157

145-
146-
if __name__ == '__main__':
147-
def _sample(values, n):
148-
indexer = np.random.randint(0, len(values), n)
149-
return np.asarray(values).take(indexer)
150-
151-
levels = [['a', 'b', 'c', 'd'],
152-
['foo', 'bar', 'baz'],
153-
['one', 'two'],
154-
['US', 'JP', 'UK']]
155-
names = ['k1', 'k2', 'k3', 'k4']
156-
157-
n = 100000
158-
159-
data = {}
160-
for name, level in zip(names, levels):
161-
data[name] = _sample(level, n)
162-
163-
data['values'] = np.random.randn(n)
164-
data = DataFrame(data)
165-
166-
table = pivot_table(data, values='values',
167-
rows=['k1', 'k2'], cols=['k3', 'k4'])
168-

pandas/tools/tests/test_pivot.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from pandas import DataFrame
66
from pandas.tools.pivot import pivot_table
7-
from pandas.util.testing import assert_frame_equal
7+
import pandas.util.testing as tm
88

99
class TestPivotTable(unittest.TestCase):
1010

@@ -28,7 +28,7 @@ def test_pivot_table(self):
2828
table = pivot_table(self.data, values='D', rows=rows, cols=cols)
2929

3030
table2 = self.data.pivot_table(values='D', rows=rows, cols=cols)
31-
assert_frame_equal(table, table2)
31+
tm.assert_frame_equal(table, table2)
3232

3333
# this works
3434
pivot_table(self.data, values='D', rows=rows)
@@ -44,24 +44,62 @@ def test_pivot_table(self):
4444
self.assertEqual(table.columns.name, cols[0])
4545

4646
expected = self.data.groupby(rows + [cols])['D'].agg(np.mean).unstack()
47-
assert_frame_equal(table, expected)
47+
tm.assert_frame_equal(table, expected)
4848

4949
def test_pivot_table_multiple(self):
5050
rows = ['A', 'B']
5151
cols= 'C'
5252
table = pivot_table(self.data, rows=rows, cols=cols)
5353
expected = self.data.groupby(rows + [cols]).agg(np.mean).unstack()
54-
assert_frame_equal(table, expected)
54+
tm.assert_frame_equal(table, expected)
5555

5656
def test_pivot_multi_values(self):
5757
result = pivot_table(self.data, values=['D', 'E'],
5858
rows='A', cols=['B', 'C'], fill_value=0)
5959
expected = pivot_table(self.data.drop(['F'], axis=1),
6060
rows='A', cols=['B', 'C'], fill_value=0)
61-
assert_frame_equal(result, expected)
61+
tm.assert_frame_equal(result, expected)
6262

6363
def test_margins(self):
64-
pass
64+
def _check_output(res, col, rows=['A', 'B'], cols=['C']):
65+
cmarg = res['All'][:-1]
66+
exp = self.data.groupby(rows)[col].mean()
67+
tm.assert_series_equal(cmarg, exp)
68+
69+
rmarg = res.xs(('All', ''))[:-1]
70+
exp = self.data.groupby(cols)[col].mean()
71+
tm.assert_series_equal(rmarg, exp)
72+
73+
gmarg = res['All']['All', '']
74+
exp = self.data[col].mean()
75+
self.assertEqual(gmarg, exp)
76+
77+
# column specified
78+
table = self.data.pivot_table('D', rows=['A', 'B'], cols='C',
79+
margins=True, aggfunc=np.mean)
80+
_check_output(table, 'D')
81+
82+
# no column specified
83+
table = self.data.pivot_table(rows=['A', 'B'], cols='C',
84+
margins=True, aggfunc=np.mean)
85+
for valcol in table.columns.levels[0]:
86+
_check_output(table[valcol], valcol)
87+
88+
# no col
89+
table = self.data.pivot_table(rows=['A', 'B'], margins=True,
90+
aggfunc=np.mean)
91+
for valcol in table.columns:
92+
gmarg = table[valcol]['All', '']
93+
self.assertEqual(gmarg, self.data[valcol].mean())
94+
95+
# doesn't quite work yet
96+
97+
# # no rows
98+
# table = self.data.pivot_table(cols=['A', 'B'], margins=True,
99+
# aggfunc=np.mean)
100+
# for valcol in table.columns:
101+
# gmarg = table[valcol]['All', '']
102+
# self.assertEqual(gmarg, self.data[valcol].mean())
65103

66104
if __name__ == '__main__':
67105
import nose

0 commit comments

Comments
 (0)