Skip to content

Commit daa5942

Browse files
authored
REF: casting in _python_agg_general (#38235)
1 parent bda4bc3 commit daa5942

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

pandas/core/groupby/groupby.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class providing the base-class of operations.
5050
from pandas.errors import AbstractMethodError
5151
from pandas.util._decorators import Appender, Substitution, cache_readonly, doc
5252

53-
from pandas.core.dtypes.cast import maybe_cast_result
53+
from pandas.core.dtypes.cast import maybe_cast_result, maybe_downcast_to_dtype
5454
from pandas.core.dtypes.common import (
5555
ensure_float,
5656
is_bool_dtype,
@@ -1185,22 +1185,24 @@ def _python_agg_general(self, func, *args, **kwargs):
11851185

11861186
assert result is not None
11871187
key = base.OutputKey(label=name, position=idx)
1188-
output[key] = maybe_cast_result(result, obj, numeric_only=True)
11891188

1190-
if not output:
1191-
return self._python_apply_general(f, self._selected_obj)
1189+
if is_numeric_dtype(obj.dtype):
1190+
result = maybe_downcast_to_dtype(result, obj.dtype)
11921191

1193-
if self.grouper._filter_empty_groups:
1194-
1195-
mask = counts.ravel() > 0
1196-
for key, result in output.items():
1192+
if self.grouper._filter_empty_groups:
1193+
mask = counts.ravel() > 0
11971194

11981195
# since we are masking, make sure that we have a float object
11991196
values = result
12001197
if is_numeric_dtype(values.dtype):
12011198
values = ensure_float(values)
12021199

1203-
output[key] = maybe_cast_result(values[mask], result)
1200+
result = maybe_downcast_to_dtype(values[mask], result.dtype)
1201+
1202+
output[key] = result
1203+
1204+
if not output:
1205+
return self._python_apply_general(f, self._selected_obj)
12041206

12051207
return self._wrap_aggregated_output(output, index=self.grouper.result_index)
12061208

pandas/core/groupby/ops.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,7 @@ def _aggregate_series_pure_python(self, obj: Series, func: F):
718718
result[label] = res
719719

720720
result = lib.maybe_convert_objects(result, try_float=0)
721-
# TODO: maybe_cast_to_extension_array?
721+
result = maybe_cast_result(result, obj, numeric_only=True)
722722

723723
return result, counts
724724

0 commit comments

Comments
 (0)