@@ -91,48 +91,51 @@ def setUp(self):
91
91
'F' : np .random .randn (11 )})
92
92
93
93
def test_basic (self ):
94
- data = Series (np .arange (9 ) // 3 , index = np .arange (9 ))
95
94
96
- index = np .arange (9 )
97
- np .random .shuffle (index )
98
- data = data .reindex (index )
95
+ def checkit (dtype ):
96
+ data = Series (np .arange (9 ) // 3 , index = np .arange (9 ), dtype = dtype )
99
97
100
- grouped = data .groupby (lambda x : x // 3 )
98
+ index = np .arange (9 )
99
+ np .random .shuffle (index )
100
+ data = data .reindex (index )
101
101
102
- for k , v in grouped :
103
- self .assertEqual (len (v ), 3 )
102
+ grouped = data .groupby (lambda x : x // 3 )
104
103
105
- agged = grouped . aggregate ( np . mean )
106
- self .assertEqual (agged [ 1 ], 1 )
104
+ for k , v in grouped :
105
+ self .assertEqual (len ( v ), 3 )
107
106
108
- assert_series_equal ( agged , grouped .agg (np .mean )) # shorthand
109
- assert_series_equal (agged , grouped . mean () )
107
+ agged = grouped .aggregate (np .mean )
108
+ self . assertEqual (agged [ 1 ], 1 )
110
109
111
- # Cython only returning floating point for now...
112
- assert_series_equal (grouped .agg ( np . sum ). astype ( float ),
113
- grouped .sum ())
110
+ assert_series_equal ( agged , grouped . agg ( np . mean )) # shorthand
111
+ assert_series_equal (agged , grouped .mean ())
112
+ assert_series_equal ( grouped . agg ( np . sum ), grouped .sum ())
114
113
115
- transformed = grouped .transform (lambda x : x * x .sum ())
116
- self .assertEqual (transformed [7 ], 12 )
114
+ transformed = grouped .transform (lambda x : x * x .sum ())
115
+ self .assertEqual (transformed [7 ], 12 )
117
116
118
- value_grouped = data .groupby (data )
119
- assert_series_equal (value_grouped .aggregate (np .mean ), agged )
117
+ value_grouped = data .groupby (data )
118
+ assert_series_equal (value_grouped .aggregate (np .mean ), agged )
120
119
121
- # complex agg
122
- agged = grouped .aggregate ([np .mean , np .std ])
123
- agged = grouped .aggregate ({'one' : np .mean ,
124
- 'two' : np .std })
120
+ # complex agg
121
+ agged = grouped .aggregate ([np .mean , np .std ])
122
+ agged = grouped .aggregate ({'one' : np .mean ,
123
+ 'two' : np .std })
124
+
125
+ group_constants = {
126
+ 0 : 10 ,
127
+ 1 : 20 ,
128
+ 2 : 30
129
+ }
130
+ agged = grouped .agg (lambda x : group_constants [x .name ] + x .mean ())
131
+ self .assertEqual (agged [1 ], 21 )
125
132
126
- group_constants = {
127
- 0 : 10 ,
128
- 1 : 20 ,
129
- 2 : 30
130
- }
131
- agged = grouped .agg (lambda x : group_constants [x .name ] + x .mean ())
132
- self .assertEqual (agged [1 ], 21 )
133
+ # corner cases
134
+ self .assertRaises (Exception , grouped .aggregate , lambda x : x * 2 )
133
135
134
- # corner cases
135
- self .assertRaises (Exception , grouped .aggregate , lambda x : x * 2 )
136
+
137
+ for dtype in ['int64' ,'int32' ,'float64' ,'float32' ]:
138
+ checkit (dtype )
136
139
137
140
def test_first_last_nth (self ):
138
141
# tests for first / last / nth
@@ -185,6 +188,14 @@ def test_first_last_nth_dtypes(self):
185
188
expected .index = ['bar' , 'foo' ]
186
189
assert_frame_equal (nth , expected , check_names = False )
187
190
191
+ # GH 2763, first/last shifting dtypes
192
+ idx = range (10 )
193
+ idx .append (9 )
194
+ s = Series (data = range (11 ), index = idx , name = 'IntCol' )
195
+ self .assert_ (s .dtype == 'int64' )
196
+ f = s .groupby (level = 0 ).first ()
197
+ self .assert_ (f .dtype == 'int64' )
198
+
188
199
def test_grouper_iter (self ):
189
200
self .assertEqual (sorted (self .df .groupby ('A' ).grouper ), ['bar' , 'foo' ])
190
201
0 commit comments