@@ -786,6 +786,8 @@ def _try_cast(self, result, obj, numeric_only=False):
786
786
elif is_extension_array_dtype (dtype ):
787
787
# The function can return something of any type, so check
788
788
# if the type is compatible with the calling EA.
789
+
790
+ # return the same type (Series) as our caller
789
791
try :
790
792
result = obj ._values ._from_sequence (result , dtype = dtype )
791
793
except Exception :
@@ -1157,7 +1159,8 @@ def mean(self, *args, **kwargs):
1157
1159
"""
1158
1160
nv .validate_groupby_func ('mean' , args , kwargs , ['numeric_only' ])
1159
1161
try :
1160
- return self ._cython_agg_general ('mean' , ** kwargs )
1162
+ return self ._cython_agg_general (
1163
+ 'mean' , alt = lambda x , axis : Series (x ).mean (** kwargs ), ** kwargs )
1161
1164
except GroupByError :
1162
1165
raise
1163
1166
except Exception : # pragma: no cover
@@ -1179,7 +1182,11 @@ def median(self, **kwargs):
1179
1182
Median of values within each group.
1180
1183
"""
1181
1184
try :
1182
- return self ._cython_agg_general ('median' , ** kwargs )
1185
+ return self ._cython_agg_general (
1186
+ 'median' ,
1187
+ alt = lambda x ,
1188
+ axis : Series (x ).median (axis = axis , ** kwargs ),
1189
+ ** kwargs )
1183
1190
except GroupByError :
1184
1191
raise
1185
1192
except Exception : # pragma: no cover
@@ -1235,7 +1242,10 @@ def var(self, ddof=1, *args, **kwargs):
1235
1242
nv .validate_groupby_func ('var' , args , kwargs )
1236
1243
if ddof == 1 :
1237
1244
try :
1238
- return self ._cython_agg_general ('var' , ** kwargs )
1245
+ return self ._cython_agg_general (
1246
+ 'var' ,
1247
+ alt = lambda x , axis : Series (x ).var (ddof = ddof , ** kwargs ),
1248
+ ** kwargs )
1239
1249
except Exception :
1240
1250
f = lambda x : x .var (ddof = ddof , ** kwargs )
1241
1251
with _group_selection_context (self ):
@@ -1263,7 +1273,6 @@ def sem(self, ddof=1):
1263
1273
Series or DataFrame
1264
1274
Standard error of the mean of values within each group.
1265
1275
"""
1266
-
1267
1276
return self .std (ddof = ddof ) / np .sqrt (self .count ())
1268
1277
1269
1278
@Substitution (name = 'groupby' )
@@ -1290,7 +1299,7 @@ def _add_numeric_operations(cls):
1290
1299
"""
1291
1300
1292
1301
def groupby_function (name , alias , npfunc ,
1293
- numeric_only = True , _convert = False ,
1302
+ numeric_only = True ,
1294
1303
min_count = - 1 ):
1295
1304
1296
1305
_local_template = """
@@ -1312,17 +1321,30 @@ def f(self, **kwargs):
1312
1321
kwargs ['min_count' ] = min_count
1313
1322
1314
1323
self ._set_group_selection ()
1324
+
1325
+ # try a cython aggregation if we can
1315
1326
try :
1316
1327
return self ._cython_agg_general (
1317
1328
alias , alt = npfunc , ** kwargs )
1318
1329
except AssertionError as e :
1319
1330
raise SpecificationError (str (e ))
1320
1331
except Exception :
1321
- result = self .aggregate (
1322
- lambda x : npfunc (x , axis = self .axis ))
1323
- if _convert :
1324
- result = result ._convert (datetime = True )
1325
- return result
1332
+ pass
1333
+
1334
+ # apply a non-cython aggregation
1335
+ result = self .aggregate (
1336
+ lambda x : npfunc (x , axis = self .axis ))
1337
+
1338
+ # coerce the resulting columns if we can
1339
+ if isinstance (result , DataFrame ):
1340
+ for col in result .columns :
1341
+ result [col ] = self ._try_cast (
1342
+ result [col ], self .obj [col ])
1343
+ else :
1344
+ result = self ._try_cast (
1345
+ result , self .obj )
1346
+
1347
+ return result
1326
1348
1327
1349
set_function_name (f , name , cls )
1328
1350
0 commit comments