@@ -786,6 +786,8 @@ def _try_cast(self, result, obj, numeric_only=False):
786
786
elif is_extension_array_dtype (dtype ):
787
787
# The function can return something of any type, so check
788
788
# if the type is compatible with the calling EA.
789
+
790
+ # return the same type (Series) as our caller
789
791
try :
790
792
result = obj ._values ._from_sequence (result , dtype = dtype )
791
793
except Exception :
@@ -1157,7 +1159,8 @@ def mean(self, *args, **kwargs):
1157
1159
"""
1158
1160
nv .validate_groupby_func ('mean' , args , kwargs , ['numeric_only' ])
1159
1161
try :
1160
- return self ._cython_agg_general ('mean' , ** kwargs )
1162
+ return self ._cython_agg_general (
1163
+ 'mean' , alt = lambda x , axis : Series (x ).mean (** kwargs ), ** kwargs )
1161
1164
except GroupByError :
1162
1165
raise
1163
1166
except Exception : # pragma: no cover
@@ -1179,7 +1182,11 @@ def median(self, **kwargs):
1179
1182
Median of values within each group.
1180
1183
"""
1181
1184
try :
1182
- return self ._cython_agg_general ('median' , ** kwargs )
1185
+ return self ._cython_agg_general (
1186
+ 'median' ,
1187
+ alt = lambda x ,
1188
+ axis : Series (x ).median (** kwargs ),
1189
+ ** kwargs )
1183
1190
except GroupByError :
1184
1191
raise
1185
1192
except Exception : # pragma: no cover
@@ -1235,7 +1242,10 @@ def var(self, ddof=1, *args, **kwargs):
1235
1242
nv .validate_groupby_func ('var' , args , kwargs )
1236
1243
if ddof == 1 :
1237
1244
try :
1238
- return self ._cython_agg_general ('var' , ** kwargs )
1245
+ return self ._cython_agg_general (
1246
+ 'var' ,
1247
+ alt = lambda x , axis : Series (x ).var (ddof = ddof , ** kwargs ),
1248
+ ** kwargs )
1239
1249
except Exception :
1240
1250
f = lambda x : x .var (ddof = ddof , ** kwargs )
1241
1251
with _group_selection_context (self ):
@@ -1263,7 +1273,6 @@ def sem(self, ddof=1):
1263
1273
Series or DataFrame
1264
1274
Standard error of the mean of values within each group.
1265
1275
"""
1266
-
1267
1276
return self .std (ddof = ddof ) / np .sqrt (self .count ())
1268
1277
1269
1278
@Substitution (name = 'groupby' )
@@ -1320,6 +1329,16 @@ def f(self, **kwargs):
1320
1329
except Exception :
1321
1330
result = self .aggregate (
1322
1331
lambda x : npfunc (x , axis = self .axis ))
1332
+
1333
+ # coerce the columns if we can
1334
+ if isinstance (result , DataFrame ):
1335
+ for col in result .columns :
1336
+ result [col ] = self ._try_cast (
1337
+ result [col ], self .obj [col ])
1338
+ else :
1339
+ result = self ._try_cast (
1340
+ result , self .obj )
1341
+
1323
1342
if _convert :
1324
1343
result = result ._convert (datetime = True )
1325
1344
return result
0 commit comments