Skip to content

Commit a997bab

Browse files
authored
REF: share GroupBy.transform (#41308)
1 parent 88ce933 commit a997bab

File tree

2 files changed

+67
-68
lines changed

2 files changed

+67
-68
lines changed

pandas/core/groupby/generic.py

+13 −65
Original file line number | Diff line number | Diff line change
@@ -526,35 +526,9 @@ def _aggregate_named(self, func, *args, **kwargs):
526526
@Substitution(klass="Series")
527527
@Appender(_transform_template)
528528
def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
529-
530-
if maybe_use_numba(engine):
531-
with group_selection_context(self):
532-
data = self._selected_obj
533-
result = self._transform_with_numba(
534-
data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs
535-
)
536-
return self.obj._constructor(
537-
result.ravel(), index=data.index, name=data.name
538-
)
539-
540-
func = com.get_cython_func(func) or func
541-
542-
if not isinstance(func, str):
543-
return self._transform_general(func, *args, **kwargs)
544-
545-
elif func not in base.transform_kernel_allowlist:
546-
msg = f"'{func}' is not a valid function name for transform(name)"
547-
raise ValueError(msg)
548-
elif func in base.cythonized_kernels or func in base.transformation_kernels:
549-
# cythonized transform or canned "agg+broadcast"
550-
return getattr(self, func)(*args, **kwargs)
551-
# If func is a reduction, we need to broadcast the
552-
# result to the whole group. Compute func result
553-
# and deal with possible broadcasting below.
554-
# Temporarily set observed for dealing with categoricals.
555-
with com.temp_setattr(self, "observed", True):
556-
result = getattr(self, func)(*args, **kwargs)
557-
return self._wrap_transform_fast_result(result)
529+
return self._transform(
530+
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
531+
)
558532

559533
def _transform_general(self, func: Callable, *args, **kwargs) -> Series:
560534
"""
@@ -586,6 +560,9 @@ def _transform_general(self, func: Callable, *args, **kwargs) -> Series:
586560
result.name = self._selected_obj.name
587561
return result
588562

563+
def _can_use_transform_fast(self, result) -> bool:
564+
return True
565+
589566
def _wrap_transform_fast_result(self, result: Series) -> Series:
590567
"""
591568
fast version of transform, only applicable to
@@ -1334,43 +1311,14 @@ def _transform_general(self, func, *args, **kwargs):
13341311
@Substitution(klass="DataFrame")
13351312
@Appender(_transform_template)
13361313
def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
1314+
return self._transform(
1315+
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
1316+
)
13371317

1338-
if maybe_use_numba(engine):
1339-
with group_selection_context(self):
1340-
data = self._selected_obj
1341-
result = self._transform_with_numba(
1342-
data, func, *args, engine_kwargs=engine_kwargs, **kwargs
1343-
)
1344-
return self.obj._constructor(result, index=data.index, columns=data.columns)
1345-
1346-
# optimized transforms
1347-
func = com.get_cython_func(func) or func
1348-
1349-
if not isinstance(func, str):
1350-
return self._transform_general(func, *args, **kwargs)
1351-
1352-
elif func not in base.transform_kernel_allowlist:
1353-
msg = f"'{func}' is not a valid function name for transform(name)"
1354-
raise ValueError(msg)
1355-
elif func in base.cythonized_kernels or func in base.transformation_kernels:
1356-
# cythonized transformation or canned "reduction+broadcast"
1357-
return getattr(self, func)(*args, **kwargs)
1358-
# GH 30918
1359-
# Use _transform_fast only when we know func is an aggregation
1360-
if func in base.reduction_kernels:
1361-
# If func is a reduction, we need to broadcast the
1362-
# result to the whole group. Compute func result
1363-
# and deal with possible broadcasting below.
1364-
# Temporarily set observed for dealing with categoricals.
1365-
with com.temp_setattr(self, "observed", True):
1366-
result = getattr(self, func)(*args, **kwargs)
1367-
1368-
if isinstance(result, DataFrame) and result.columns.equals(
1369-
self._obj_with_exclusions.columns
1370-
):
1371-
return self._wrap_transform_fast_result(result)
1372-
1373-
return self._transform_general(func, *args, **kwargs)
1318+
def _can_use_transform_fast(self, result) -> bool:
1319+
return isinstance(result, DataFrame) and result.columns.equals(
1320+
self._obj_with_exclusions.columns
1321+
)
13741322

13751323
def _wrap_transform_fast_result(self, result: DataFrame) -> DataFrame:
13761324
"""

pandas/core/groupby/groupby.py

+54 −3
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@ class providing the base-class of operations.
2929
Sequence,
3030
TypeVar,
3131
Union,
32+
cast,
3233
)
3334

3435
import numpy as np
@@ -104,7 +105,10 @@ class providing the base-class of operations.
104105
from pandas.core.internals.blocks import ensure_block_shape
105106
from pandas.core.series import Series
106107
from pandas.core.sorting import get_group_index_sorter
107-
from pandas.core.util.numba_ import NUMBA_FUNC_CACHE
108+
from pandas.core.util.numba_ import (
109+
NUMBA_FUNC_CACHE,
110+
maybe_use_numba,
111+
)
108112

109113
if TYPE_CHECKING:
110114
from typing import Literal
@@ -1398,8 +1402,55 @@ def _cython_transform(
13981402

13991403
return self._wrap_transformed_output(output)
14001404

1401-
def transform(self, func, *args, **kwargs):
1402-
raise AbstractMethodError(self)
1405+
@final
1406+
def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
1407+
1408+
if maybe_use_numba(engine):
1409+
# TODO: tests with self._selected_obj.ndim == 1 on DataFrameGroupBy
1410+
with group_selection_context(self):
1411+
data = self._selected_obj
1412+
df = data if data.ndim == 2 else data.to_frame()
1413+
result = self._transform_with_numba(
1414+
df, func, *args, engine_kwargs=engine_kwargs, **kwargs
1415+
)
1416+
if self.obj.ndim == 2:
1417+
return cast(DataFrame, self.obj)._constructor(
1418+
result, index=data.index, columns=data.columns
1419+
)
1420+
else:
1421+
return cast(Series, self.obj)._constructor(
1422+
result.ravel(), index=data.index, name=data.name
1423+
)
1424+
1425+
# optimized transforms
1426+
func = com.get_cython_func(func) or func
1427+
1428+
if not isinstance(func, str):
1429+
return self._transform_general(func, *args, **kwargs)
1430+
1431+
elif func not in base.transform_kernel_allowlist:
1432+
msg = f"'{func}' is not a valid function name for transform(name)"
1433+
raise ValueError(msg)
1434+
elif func in base.cythonized_kernels or func in base.transformation_kernels:
1435+
# cythonized transform or canned "agg+broadcast"
1436+
return getattr(self, func)(*args, **kwargs)
1437+
1438+
else:
1439+
# i.e. func in base.reduction_kernels
1440+
1441+
# GH#30918 Use _transform_fast only when we know func is an aggregation
1442+
# If func is a reduction, we need to broadcast the
1443+
# result to the whole group. Compute func result
1444+
# and deal with possible broadcasting below.
1445+
# Temporarily set observed for dealing with categoricals.
1446+
with com.temp_setattr(self, "observed", True):
1447+
result = getattr(self, func)(*args, **kwargs)
1448+
1449+
if self._can_use_transform_fast(result):
1450+
return self._wrap_transform_fast_result(result)
1451+
1452+
# only reached for DataFrameGroupBy
1453+
return self._transform_general(func, *args, **kwargs)
14031454

14041455
# -----------------------------------------------------------------
14051456
# Utilities

0 commit comments

Comments
 (0)