Skip to content

Commit 37ed7e4

Browse files
authored
REF: consolidate result-casting in _cython_operation (#38237)
1 parent daa5942 commit 37ed7e4

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

pandas/core/groupby/groupby.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -1000,9 +1000,6 @@ def _cython_transform(
10001000
except NotImplementedError:
10011001
continue
10021002

1003-
if self._transform_should_cast(how):
1004-
result = maybe_cast_result(result, obj, how=how)
1005-
10061003
key = base.OutputKey(label=name, position=idx)
10071004
output[key] = result
10081005

@@ -1081,12 +1078,12 @@ def _cython_agg_general(
10811078
assert len(agg_names) == result.shape[1]
10821079
for result_column, result_name in zip(result.T, agg_names):
10831080
key = base.OutputKey(label=result_name, position=idx)
1084-
output[key] = maybe_cast_result(result_column, obj, how=how)
1081+
output[key] = result_column
10851082
idx += 1
10861083
else:
10871084
assert result.ndim == 1
10881085
key = base.OutputKey(label=name, position=idx)
1089-
output[key] = maybe_cast_result(result, obj, how=how)
1086+
output[key] = result
10901087
idx += 1
10911088

10921089
if not output:

pandas/core/groupby/ops.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,11 @@
2828
from pandas.errors import AbstractMethodError
2929
from pandas.util._decorators import cache_readonly
3030

31-
from pandas.core.dtypes.cast import maybe_cast_result
31+
from pandas.core.dtypes.cast import (
32+
maybe_cast_result,
33+
maybe_cast_result_dtype,
34+
maybe_downcast_to_dtype,
35+
)
3236
from pandas.core.dtypes.common import (
3337
ensure_float,
3438
ensure_float64,
@@ -620,8 +624,11 @@ def _cython_operation(
620624
if swapped:
621625
result = result.swapaxes(0, axis)
622626

623-
if is_datetimelike and kind == "aggregate":
624-
result = result.astype(orig_values.dtype)
627+
if how not in base.cython_cast_blocklist:
628+
# e.g. if we are int64 and need to restore to datetime64/timedelta64
629+
# "rank" is the only member of cython_cast_blocklist we get here
630+
dtype = maybe_cast_result_dtype(orig_values.dtype, how)
631+
result = maybe_downcast_to_dtype(result, dtype)
625632

626633
return result, names
627634

0 commit comments

Comments
 (0)