4
4
5
5
from pandas import DataFrame
6
6
from pandas .tools .pivot import pivot_table
7
- from pandas .util .testing import assert_frame_equal
7
+ import pandas .util .testing as tm
8
8
9
9
class TestPivotTable (unittest .TestCase ):
10
10
@@ -28,7 +28,7 @@ def test_pivot_table(self):
28
28
table = pivot_table (self .data , values = 'D' , rows = rows , cols = cols )
29
29
30
30
table2 = self .data .pivot_table (values = 'D' , rows = rows , cols = cols )
31
- assert_frame_equal (table , table2 )
31
+ tm . assert_frame_equal (table , table2 )
32
32
33
33
# this works
34
34
pivot_table (self .data , values = 'D' , rows = rows )
@@ -44,24 +44,62 @@ def test_pivot_table(self):
44
44
self .assertEqual (table .columns .name , cols [0 ])
45
45
46
46
expected = self .data .groupby (rows + [cols ])['D' ].agg (np .mean ).unstack ()
47
- assert_frame_equal (table , expected )
47
+ tm . assert_frame_equal (table , expected )
48
48
49
49
def test_pivot_table_multiple (self ):
50
50
rows = ['A' , 'B' ]
51
51
cols = 'C'
52
52
table = pivot_table (self .data , rows = rows , cols = cols )
53
53
expected = self .data .groupby (rows + [cols ]).agg (np .mean ).unstack ()
54
- assert_frame_equal (table , expected )
54
+ tm . assert_frame_equal (table , expected )
55
55
56
56
def test_pivot_multi_values (self ):
57
57
result = pivot_table (self .data , values = ['D' , 'E' ],
58
58
rows = 'A' , cols = ['B' , 'C' ], fill_value = 0 )
59
59
expected = pivot_table (self .data .drop (['F' ], axis = 1 ),
60
60
rows = 'A' , cols = ['B' , 'C' ], fill_value = 0 )
61
- assert_frame_equal (result , expected )
61
+ tm . assert_frame_equal (result , expected )
62
62
63
63
def test_margins (self ):
64
- pass
64
+ def _check_output (res , col , rows = ['A' , 'B' ], cols = ['C' ]):
65
+ cmarg = res ['All' ][:- 1 ]
66
+ exp = self .data .groupby (rows )[col ].mean ()
67
+ tm .assert_series_equal (cmarg , exp )
68
+
69
+ rmarg = res .xs (('All' , '' ))[:- 1 ]
70
+ exp = self .data .groupby (cols )[col ].mean ()
71
+ tm .assert_series_equal (rmarg , exp )
72
+
73
+ gmarg = res ['All' ]['All' , '' ]
74
+ exp = self .data [col ].mean ()
75
+ self .assertEqual (gmarg , exp )
76
+
77
+ # column specified
78
+ table = self .data .pivot_table ('D' , rows = ['A' , 'B' ], cols = 'C' ,
79
+ margins = True , aggfunc = np .mean )
80
+ _check_output (table , 'D' )
81
+
82
+ # no column specified
83
+ table = self .data .pivot_table (rows = ['A' , 'B' ], cols = 'C' ,
84
+ margins = True , aggfunc = np .mean )
85
+ for valcol in table .columns .levels [0 ]:
86
+ _check_output (table [valcol ], valcol )
87
+
88
+ # no col
89
+ table = self .data .pivot_table (rows = ['A' , 'B' ], margins = True ,
90
+ aggfunc = np .mean )
91
+ for valcol in table .columns :
92
+ gmarg = table [valcol ]['All' , '' ]
93
+ self .assertEqual (gmarg , self .data [valcol ].mean ())
94
+
95
+ # doesn't quite work yet
96
+
97
+ # # no rows
98
+ # table = self.data.pivot_table(cols=['A', 'B'], margins=True,
99
+ # aggfunc=np.mean)
100
+ # for valcol in table.columns:
101
+ # gmarg = table[valcol]['All', '']
102
+ # self.assertEqual(gmarg, self.data[valcol].mean())
65
103
66
104
if __name__ == '__main__' :
67
105
import nose
0 commit comments