@@ -785,6 +785,8 @@ def _try_cast(self, result, obj, numeric_only=False):
785
785
elif is_extension_array_dtype (dtype ):
786
786
# The function can return something of any type, so check
787
787
# if the type is compatible with the calling EA.
788
+
789
+ # return the same type (Series) as our caller
788
790
try :
789
791
result = obj ._values ._from_sequence (result , dtype = dtype )
790
792
except Exception :
@@ -1156,7 +1158,8 @@ def mean(self, *args, **kwargs):
1156
1158
"""
1157
1159
nv .validate_groupby_func ('mean' , args , kwargs , ['numeric_only' ])
1158
1160
try :
1159
- return self ._cython_agg_general ('mean' , ** kwargs )
1161
+ return self ._cython_agg_general (
1162
+ 'mean' , alt = lambda x , axis : Series (x ).mean (** kwargs ), ** kwargs )
1160
1163
except GroupByError :
1161
1164
raise
1162
1165
except Exception : # pragma: no cover
@@ -1178,7 +1181,11 @@ def median(self, **kwargs):
1178
1181
Median of values within each group.
1179
1182
"""
1180
1183
try :
1181
- return self ._cython_agg_general ('median' , ** kwargs )
1184
+ return self ._cython_agg_general (
1185
+ 'median' ,
1186
+ alt = lambda x ,
1187
+ axis : Series (x ).median (** kwargs ),
1188
+ ** kwargs )
1182
1189
except GroupByError :
1183
1190
raise
1184
1191
except Exception : # pragma: no cover
@@ -1234,7 +1241,10 @@ def var(self, ddof=1, *args, **kwargs):
1234
1241
nv .validate_groupby_func ('var' , args , kwargs )
1235
1242
if ddof == 1 :
1236
1243
try :
1237
- return self ._cython_agg_general ('var' , ** kwargs )
1244
+ return self ._cython_agg_general (
1245
+ 'var' ,
1246
+ alt = lambda x , axis : Series (x ).var (ddof = ddof , ** kwargs ),
1247
+ ** kwargs )
1238
1248
except Exception :
1239
1249
f = lambda x : x .var (ddof = ddof , ** kwargs )
1240
1250
with _group_selection_context (self ):
@@ -1262,7 +1272,6 @@ def sem(self, ddof=1):
1262
1272
Series or DataFrame
1263
1273
Standard error of the mean of values within each group.
1264
1274
"""
1265
-
1266
1275
return self .std (ddof = ddof ) / np .sqrt (self .count ())
1267
1276
1268
1277
@Substitution (name = 'groupby' )
@@ -1319,6 +1328,16 @@ def f(self, **kwargs):
1319
1328
except Exception :
1320
1329
result = self .aggregate (
1321
1330
lambda x : npfunc (x , axis = self .axis ))
1331
+
1332
+ # coerce the columns if we can
1333
+ if isinstance (result , DataFrame ):
1334
+ for col in result .columns :
1335
+ result [col ] = self ._try_cast (
1336
+ result [col ], self .obj [col ])
1337
+ else :
1338
+ result = self ._try_cast (
1339
+ result , self .obj )
1340
+
1322
1341
if _convert :
1323
1342
result = result ._convert (datetime = True )
1324
1343
return result
0 commit comments