@@ -784,6 +784,8 @@ def _try_cast(self, result, obj, numeric_only=False):
784
784
elif is_extension_array_dtype (dtype ):
785
785
# The function can return something of any type, so check
786
786
# if the type is compatible with the calling EA.
787
+
788
+ # return the same type (Series) as our caller
787
789
try :
788
790
result = obj ._values ._from_sequence (result , dtype = dtype )
789
791
except Exception :
@@ -1155,7 +1157,8 @@ def mean(self, *args, **kwargs):
1155
1157
"""
1156
1158
nv .validate_groupby_func ('mean' , args , kwargs , ['numeric_only' ])
1157
1159
try :
1158
- return self ._cython_agg_general ('mean' , ** kwargs )
1160
+ return self ._cython_agg_general (
1161
+ 'mean' , alt = lambda x , axis : Series (x ).mean (), ** kwargs )
1159
1162
except GroupByError :
1160
1163
raise
1161
1164
except Exception : # pragma: no cover
@@ -1177,7 +1180,8 @@ def median(self, **kwargs):
1177
1180
Median of values within each group.
1178
1181
"""
1179
1182
try :
1180
- return self ._cython_agg_general ('median' , ** kwargs )
1183
+ return self ._cython_agg_general (
1184
+ 'median' , alt = lambda x , axis : Series (x ).median (), ** kwargs )
1181
1185
except GroupByError :
1182
1186
raise
1183
1187
except Exception : # pragma: no cover
@@ -1233,7 +1237,10 @@ def var(self, ddof=1, *args, **kwargs):
1233
1237
nv .validate_groupby_func ('var' , args , kwargs )
1234
1238
if ddof == 1 :
1235
1239
try :
1236
- return self ._cython_agg_general ('var' , ** kwargs )
1240
+ return self ._cython_agg_general (
1241
+ 'var' ,
1242
+ alt = lambda x , axis : Series (x ).var (ddof = ddof ),
1243
+ ** kwargs )
1237
1244
except Exception :
1238
1245
f = lambda x : x .var (ddof = ddof , ** kwargs )
1239
1246
with _group_selection_context (self ):
@@ -1261,7 +1268,6 @@ def sem(self, ddof=1):
1261
1268
Series or DataFrame
1262
1269
Standard error of the mean of values within each group.
1263
1270
"""
1264
-
1265
1271
return self .std (ddof = ddof ) / np .sqrt (self .count ())
1266
1272
1267
1273
@Substitution (name = 'groupby' )
@@ -1318,6 +1324,16 @@ def f(self, **kwargs):
1318
1324
except Exception :
1319
1325
result = self .aggregate (
1320
1326
lambda x : npfunc (x , axis = self .axis ))
1327
+
1328
+ # coerce the columns if we can
1329
+ if isinstance (result , DataFrame ):
1330
+ for col in result .columns :
1331
+ result [col ] = self ._try_cast (
1332
+ result [col ], self .obj [col ])
1333
+ else :
1334
+ result = self ._try_cast (
1335
+ result , self .obj )
1336
+
1321
1337
if _convert :
1322
1338
result = result ._convert (datetime = True )
1323
1339
return result
0 commit comments