Skip to content

Commit 0f8a242

Browse files
jbrockmendelJulianWgs
authored andcommitted
REF: share GroupBy.transform (pandas-dev#41308)
1 parent 3d10836 commit 0f8a242

File tree

2 files changed

+67
-68
lines changed

2 files changed

+67
-68
lines changed

pandas/core/groupby/generic.py

+13-65
Original file line numberDiff line numberDiff line change
@@ -526,35 +526,9 @@ def _aggregate_named(self, func, *args, **kwargs):
526526
@Substitution(klass="Series")
527527
@Appender(_transform_template)
528528
def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
529-
530-
if maybe_use_numba(engine):
531-
with group_selection_context(self):
532-
data = self._selected_obj
533-
result = self._transform_with_numba(
534-
data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs
535-
)
536-
return self.obj._constructor(
537-
result.ravel(), index=data.index, name=data.name
538-
)
539-
540-
func = com.get_cython_func(func) or func
541-
542-
if not isinstance(func, str):
543-
return self._transform_general(func, *args, **kwargs)
544-
545-
elif func not in base.transform_kernel_allowlist:
546-
msg = f"'{func}' is not a valid function name for transform(name)"
547-
raise ValueError(msg)
548-
elif func in base.cythonized_kernels or func in base.transformation_kernels:
549-
# cythonized transform or canned "agg+broadcast"
550-
return getattr(self, func)(*args, **kwargs)
551-
# If func is a reduction, we need to broadcast the
552-
# result to the whole group. Compute func result
553-
# and deal with possible broadcasting below.
554-
# Temporarily set observed for dealing with categoricals.
555-
with com.temp_setattr(self, "observed", True):
556-
result = getattr(self, func)(*args, **kwargs)
557-
return self._wrap_transform_fast_result(result)
529+
return self._transform(
530+
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
531+
)
558532

559533
def _transform_general(self, func: Callable, *args, **kwargs) -> Series:
560534
"""
@@ -586,6 +560,9 @@ def _transform_general(self, func: Callable, *args, **kwargs) -> Series:
586560
result.name = self._selected_obj.name
587561
return result
588562

563+
def _can_use_transform_fast(self, result) -> bool:
564+
return True
565+
589566
def _wrap_transform_fast_result(self, result: Series) -> Series:
590567
"""
591568
fast version of transform, only applicable to
@@ -1334,43 +1311,14 @@ def _transform_general(self, func, *args, **kwargs):
13341311
@Substitution(klass="DataFrame")
13351312
@Appender(_transform_template)
13361313
def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
1314+
return self._transform(
1315+
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
1316+
)
13371317

1338-
if maybe_use_numba(engine):
1339-
with group_selection_context(self):
1340-
data = self._selected_obj
1341-
result = self._transform_with_numba(
1342-
data, func, *args, engine_kwargs=engine_kwargs, **kwargs
1343-
)
1344-
return self.obj._constructor(result, index=data.index, columns=data.columns)
1345-
1346-
# optimized transforms
1347-
func = com.get_cython_func(func) or func
1348-
1349-
if not isinstance(func, str):
1350-
return self._transform_general(func, *args, **kwargs)
1351-
1352-
elif func not in base.transform_kernel_allowlist:
1353-
msg = f"'{func}' is not a valid function name for transform(name)"
1354-
raise ValueError(msg)
1355-
elif func in base.cythonized_kernels or func in base.transformation_kernels:
1356-
# cythonized transformation or canned "reduction+broadcast"
1357-
return getattr(self, func)(*args, **kwargs)
1358-
# GH 30918
1359-
# Use _transform_fast only when we know func is an aggregation
1360-
if func in base.reduction_kernels:
1361-
# If func is a reduction, we need to broadcast the
1362-
# result to the whole group. Compute func result
1363-
# and deal with possible broadcasting below.
1364-
# Temporarily set observed for dealing with categoricals.
1365-
with com.temp_setattr(self, "observed", True):
1366-
result = getattr(self, func)(*args, **kwargs)
1367-
1368-
if isinstance(result, DataFrame) and result.columns.equals(
1369-
self._obj_with_exclusions.columns
1370-
):
1371-
return self._wrap_transform_fast_result(result)
1372-
1373-
return self._transform_general(func, *args, **kwargs)
1318+
def _can_use_transform_fast(self, result) -> bool:
1319+
return isinstance(result, DataFrame) and result.columns.equals(
1320+
self._obj_with_exclusions.columns
1321+
)
13741322

13751323
def _wrap_transform_fast_result(self, result: DataFrame) -> DataFrame:
13761324
"""

pandas/core/groupby/groupby.py

+54-3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class providing the base-class of operations.
2929
Sequence,
3030
TypeVar,
3131
Union,
32+
cast,
3233
)
3334

3435
import numpy as np
@@ -104,7 +105,10 @@ class providing the base-class of operations.
104105
from pandas.core.internals.blocks import ensure_block_shape
105106
from pandas.core.series import Series
106107
from pandas.core.sorting import get_group_index_sorter
107-
from pandas.core.util.numba_ import NUMBA_FUNC_CACHE
108+
from pandas.core.util.numba_ import (
109+
NUMBA_FUNC_CACHE,
110+
maybe_use_numba,
111+
)
108112

109113
from pandas.io.formats.format import repr_html_groupby
110114

@@ -1403,8 +1407,55 @@ def _cython_transform(
14031407

14041408
return self._wrap_transformed_output(output)
14051409

1406-
def transform(self, func, *args, **kwargs):
1407-
raise AbstractMethodError(self)
1410+
@final
1411+
def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
1412+
1413+
if maybe_use_numba(engine):
1414+
# TODO: tests with self._selected_obj.ndim == 1 on DataFrameGroupBy
1415+
with group_selection_context(self):
1416+
data = self._selected_obj
1417+
df = data if data.ndim == 2 else data.to_frame()
1418+
result = self._transform_with_numba(
1419+
df, func, *args, engine_kwargs=engine_kwargs, **kwargs
1420+
)
1421+
if self.obj.ndim == 2:
1422+
return cast(DataFrame, self.obj)._constructor(
1423+
result, index=data.index, columns=data.columns
1424+
)
1425+
else:
1426+
return cast(Series, self.obj)._constructor(
1427+
result.ravel(), index=data.index, name=data.name
1428+
)
1429+
1430+
# optimized transforms
1431+
func = com.get_cython_func(func) or func
1432+
1433+
if not isinstance(func, str):
1434+
return self._transform_general(func, *args, **kwargs)
1435+
1436+
elif func not in base.transform_kernel_allowlist:
1437+
msg = f"'{func}' is not a valid function name for transform(name)"
1438+
raise ValueError(msg)
1439+
elif func in base.cythonized_kernels or func in base.transformation_kernels:
1440+
# cythonized transform or canned "agg+broadcast"
1441+
return getattr(self, func)(*args, **kwargs)
1442+
1443+
else:
1444+
# i.e. func in base.reduction_kernels
1445+
1446+
# GH#30918 Use _transform_fast only when we know func is an aggregation
1447+
# If func is a reduction, we need to broadcast the
1448+
# result to the whole group. Compute func result
1449+
# and deal with possible broadcasting below.
1450+
# Temporarily set observed for dealing with categoricals.
1451+
with com.temp_setattr(self, "observed", True):
1452+
result = getattr(self, func)(*args, **kwargs)
1453+
1454+
if self._can_use_transform_fast(result):
1455+
return self._wrap_transform_fast_result(result)
1456+
1457+
# only reached for DataFrameGroupBy
1458+
return self._transform_general(func, *args, **kwargs)
14081459

14091460
# -----------------------------------------------------------------
14101461
# Utilities

0 commit comments

Comments
 (0)