2
2
import numpy as np
3
3
4
4
def pivot_table (data , values = None , rows = None , cols = None , aggfunc = np .mean ,
5
- fill_value = None ):
5
+ fill_value = None , margins = False ):
6
6
"""
7
7
Create a spreadsheet-style pivot table as a DataFrame. The levels in the
8
8
pivot table will be stored in MultiIndex objects (hierarchical indexes) on
@@ -19,6 +19,8 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc=np.mean,
19
19
aggfunc : function, default numpy.mean
20
20
fill_value : scalar, default None
21
21
Value to replace missing values with
22
+ margins : boolean, default False
23
+ Add all row / columns (e.g. for subtotal / grand totals)
22
24
23
25
Examples
24
26
--------
@@ -59,15 +61,14 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc=np.mean,
59
61
else :
60
62
values_multi = False
61
63
values = [values ]
64
+ else :
65
+ values = list (data .columns .drop (keys ))
62
66
63
67
if values_passed :
64
68
data = data [keys + values ]
65
69
66
70
grouped = data .groupby (keys )
67
71
68
- if values_passed and not values_multi :
69
- grouped = grouped [values [0 ]]
70
-
71
72
agged = grouped .agg (aggfunc )
72
73
73
74
table = agged
@@ -77,10 +78,61 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc=np.mean,
77
78
if fill_value is not None :
78
79
table = table .fillna (value = fill_value )
79
80
81
+ if margins :
82
+ table = _add_margins (table , data , values , rows = rows ,
83
+ cols = cols , aggfunc = aggfunc )
84
+
85
+ # discard the top level
86
+ if values_passed and not values_multi :
87
+ table = table [values [0 ]]
88
+
80
89
return table
81
90
82
91
DataFrame .pivot_table = pivot_table
83
92
93
+ def _add_margins (table , data , values , rows = None , cols = None , aggfunc = np .mean ):
94
+ if rows is not None :
95
+ col_margin = data [rows + values ].groupby (rows ).agg (aggfunc )
96
+
97
+ # need to "interleave" the margins
98
+
99
+ table_pieces = []
100
+ margin_keys = []
101
+ for key , piece in table .groupby (level = 0 , axis = 1 ):
102
+ all_key = (key , 'All' ) + ('' ,) * (len (cols ) - 1 )
103
+ piece [all_key ] = col_margin [key ]
104
+ table_pieces .append (piece )
105
+ margin_keys .append (all_key )
106
+
107
+ result = table_pieces [0 ]
108
+ for piece in table_pieces [1 :]:
109
+ result = result .join (piece )
110
+ else :
111
+ result = table
112
+ margin_keys = []
113
+
114
+ grand_margin = data [values ].apply (aggfunc )
115
+
116
+ if cols is not None :
117
+ row_margin = data [cols + values ].groupby (cols ).agg (aggfunc )
118
+ row_margin = row_margin .stack ()
119
+
120
+ # slight hack
121
+ new_order = [len (cols )] + range (len (cols ))
122
+ row_margin .index = row_margin .index .reorder_levels (new_order )
123
+
124
+ key = ('All' ,) + ('' ,) * (len (rows ) - 1 )
125
+
126
+ row_margin = row_margin .reindex (result .columns )
127
+ # populate grand margin
128
+ for k in margin_keys :
129
+ row_margin [k ] = grand_margin [k [0 ]]
130
+
131
+ margin_dummy = DataFrame (row_margin , columns = [key ]).T
132
+ result = result .append (margin_dummy )
133
+
134
+ return result
135
+
84
136
def _convert_by (by ):
85
137
if by is None :
86
138
by = []
0 commit comments