Skip to content

Commit 57bb165

Browse files
authored
PERF: Try fast/slow paths only once in DataFrameGroupby.transform (#42195)
1 parent 4653b6a commit 57bb165

File tree

2 files changed: +45 additions, -27 deletions (lines changed)

2 files changed

+45
-27
lines changed

doc/source/whatsnew/v1.4.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ Deprecations
160160
Performance improvements
161161
~~~~~~~~~~~~~~~~~~~~~~~~
162162
- Performance improvement in :meth:`.GroupBy.sample`, especially when ``weights`` argument provided (:issue:`34483`)
163-
-
163+
- Performance improvement in :meth:`.GroupBy.transform` for user-defined functions (:issue:`41598`)
164164

165165
.. ---------------------------------------------------------------------------
166166

pandas/core/groupby/generic.py

+44-26
Original file line numberDiff line numberDiff line change
@@ -1308,42 +1308,35 @@ def _transform_general(self, func, *args, **kwargs):
13081308
gen = self.grouper.get_iterator(obj, axis=self.axis)
13091309
fast_path, slow_path = self._define_paths(func, *args, **kwargs)
13101310

1311-
for name, group in gen:
1312-
if group.size == 0:
1313-
continue
1311+
# Determine whether to use slow or fast path by evaluating on the first group.
1312+
# Need to handle the case of an empty generator and process the result so that
1313+
# it does not need to be computed again.
1314+
try:
1315+
name, group = next(gen)
1316+
except StopIteration:
1317+
pass
1318+
else:
13141319
object.__setattr__(group, "name", name)
1315-
1316-
# Try slow path and fast path.
13171320
try:
13181321
path, res = self._choose_path(fast_path, slow_path, group)
13191322
except TypeError:
13201323
return self._transform_item_by_item(obj, fast_path)
13211324
except ValueError as err:
13221325
msg = "transform must return a scalar value for each group"
13231326
raise ValueError(msg) from err
1324-
1325-
if isinstance(res, Series):
1326-
1327-
# we need to broadcast across the
1328-
# other dimension; this will preserve dtypes
1329-
# GH14457
1330-
if res.index.is_(obj.index):
1331-
r = concat([res] * len(group.columns), axis=1)
1332-
r.columns = group.columns
1333-
r.index = group.index
1334-
else:
1335-
r = self.obj._constructor(
1336-
np.concatenate([res.values] * len(group.index)).reshape(
1337-
group.shape
1338-
),
1339-
columns=group.columns,
1340-
index=group.index,
1341-
)
1342-
1343-
applied.append(r)
1344-
else:
1327+
if group.size > 0:
1328+
res = _wrap_transform_general_frame(self.obj, group, res)
13451329
applied.append(res)
13461330

1331+
# Compute and process with the remaining groups
1332+
for name, group in gen:
1333+
if group.size == 0:
1334+
continue
1335+
object.__setattr__(group, "name", name)
1336+
res = path(group)
1337+
res = _wrap_transform_general_frame(self.obj, group, res)
1338+
applied.append(res)
1339+
13471340
concat_index = obj.columns if self.axis == 0 else obj.index
13481341
other_axis = 1 if self.axis == 0 else 0 # switches between 0 & 1
13491342
concatenated = concat(applied, axis=self.axis, verify_integrity=False)
@@ -1853,3 +1846,28 @@ def func(df):
18531846
return self._python_apply_general(func, self._obj_with_exclusions)
18541847

18551848
boxplot = boxplot_frame_groupby
1849+
1850+
1851+
def _wrap_transform_general_frame(
1852+
obj: DataFrame, group: DataFrame, res: DataFrame | Series
1853+
) -> DataFrame:
1854+
from pandas import concat
1855+
1856+
if isinstance(res, Series):
1857+
# we need to broadcast across the
1858+
# other dimension; this will preserve dtypes
1859+
# GH14457
1860+
if res.index.is_(obj.index):
1861+
res_frame = concat([res] * len(group.columns), axis=1)
1862+
res_frame.columns = group.columns
1863+
res_frame.index = group.index
1864+
else:
1865+
res_frame = obj._constructor(
1866+
np.concatenate([res.values] * len(group.index)).reshape(group.shape),
1867+
columns=group.columns,
1868+
index=group.index,
1869+
)
1870+
assert isinstance(res_frame, DataFrame)
1871+
return res_frame
1872+
else:
1873+
return res

0 commit comments

Comments
 (0)