@@ -784,6 +784,8 @@ def _try_cast(self, result, obj, numeric_only=False):
784
784
elif is_extension_array_dtype (dtype ):
785
785
# The function can return something of any type, so check
786
786
# if the type is compatible with the calling EA.
787
+
788
+ # return the same type (Series) as our caller
787
789
try :
788
790
result = obj ._values ._from_sequence (result , dtype = dtype )
789
791
except Exception :
@@ -1155,7 +1157,8 @@ def mean(self, *args, **kwargs):
1155
1157
"""
1156
1158
nv .validate_groupby_func ('mean' , args , kwargs , ['numeric_only' ])
1157
1159
try :
1158
- return self ._cython_agg_general ('mean' , ** kwargs )
1160
+ return self ._cython_agg_general (
1161
+ 'mean' , alt = lambda x , axis : Series (x ).mean (** kwargs ), ** kwargs )
1159
1162
except GroupByError :
1160
1163
raise
1161
1164
except Exception : # pragma: no cover
@@ -1177,7 +1180,11 @@ def median(self, **kwargs):
1177
1180
Median of values within each group.
1178
1181
"""
1179
1182
try :
1180
- return self ._cython_agg_general ('median' , ** kwargs )
1183
+ return self ._cython_agg_general (
1184
+ 'median' ,
1185
+ alt = lambda x ,
1186
+ axis : Series (x ).median (** kwargs ),
1187
+ ** kwargs )
1181
1188
except GroupByError :
1182
1189
raise
1183
1190
except Exception : # pragma: no cover
@@ -1233,7 +1240,10 @@ def var(self, ddof=1, *args, **kwargs):
1233
1240
nv .validate_groupby_func ('var' , args , kwargs )
1234
1241
if ddof == 1 :
1235
1242
try :
1236
- return self ._cython_agg_general ('var' , ** kwargs )
1243
+ return self ._cython_agg_general (
1244
+ 'var' ,
1245
+ alt = lambda x , axis : Series (x ).var (ddof = ddof , ** kwargs ),
1246
+ ** kwargs )
1237
1247
except Exception :
1238
1248
f = lambda x : x .var (ddof = ddof , ** kwargs )
1239
1249
with _group_selection_context (self ):
@@ -1261,7 +1271,6 @@ def sem(self, ddof=1):
1261
1271
Series or DataFrame
1262
1272
Standard error of the mean of values within each group.
1263
1273
"""
1264
-
1265
1274
return self .std (ddof = ddof ) / np .sqrt (self .count ())
1266
1275
1267
1276
@Substitution (name = 'groupby' )
@@ -1318,6 +1327,16 @@ def f(self, **kwargs):
1318
1327
except Exception :
1319
1328
result = self .aggregate (
1320
1329
lambda x : npfunc (x , axis = self .axis ))
1330
+
1331
+ # coerce the columns if we can
1332
+ if isinstance (result , DataFrame ):
1333
+ for col in result .columns :
1334
+ result [col ] = self ._try_cast (
1335
+ result [col ], self .obj [col ])
1336
+ else :
1337
+ result = self ._try_cast (
1338
+ result , self .obj )
1339
+
1321
1340
if _convert :
1322
1341
result = result ._convert (datetime = True )
1323
1342
return result
0 commit comments