diff --git a/doc/source/whatsnew/v1.4.0.rst b/doc/source/whatsnew/v1.4.0.rst index 5b09b62fa9e88..495322dada350 100644 --- a/doc/source/whatsnew/v1.4.0.rst +++ b/doc/source/whatsnew/v1.4.0.rst @@ -159,7 +159,7 @@ Deprecations Performance improvements ~~~~~~~~~~~~~~~~~~~~~~~~ - Performance improvement in :meth:`.GroupBy.sample`, especially when ``weights`` argument provided (:issue:`34483`) -- +- Performance improvement in :meth:`.GroupBy.transform` for user-defined functions (:issue:`41598`) .. --------------------------------------------------------------------------- diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 382cd0e178e15..88d1baae86467 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1308,12 +1308,15 @@ def _transform_general(self, func, *args, **kwargs): gen = self.grouper.get_iterator(obj, axis=self.axis) fast_path, slow_path = self._define_paths(func, *args, **kwargs) - for name, group in gen: - if group.size == 0: - continue + # Determine whether to use slow or fast path by evaluating on the first group. + # Need to handle the case of an empty generator and process the result so that + # it does not need to be computed again. + try: + name, group = next(gen) + except StopIteration: + pass + else: object.__setattr__(group, "name", name) - - # Try slow path and fast path. try: path, res = self._choose_path(fast_path, slow_path, group) except TypeError: @@ -1321,29 +1324,19 @@ def _transform_general(self, func, *args, **kwargs): except ValueError as err: msg = "transform must return a scalar value for each group" raise ValueError(msg) from err - - if isinstance(res, Series): - - # we need to broadcast across the - # other dimension; this will preserve dtypes - # GH14457 - if res.index.is_(obj.index): - r = concat([res] * len(group.columns), axis=1) - r.columns = group.columns - r.index = group.index - else: - r = self.obj._constructor( - np.concatenate([res.values] * len(group.index)).reshape( - group.shape - ), - columns=group.columns, - index=group.index, - ) - - applied.append(r) - else: + if group.size > 0: + res = _wrap_transform_general_frame(self.obj, group, res) applied.append(res) + # Compute and process with the remaining groups + for name, group in gen: + if group.size == 0: + continue + object.__setattr__(group, "name", name) + res = path(group) + res = _wrap_transform_general_frame(self.obj, group, res) + applied.append(res) + concat_index = obj.columns if self.axis == 0 else obj.index other_axis = 1 if self.axis == 0 else 0 # switches between 0 & 1 concatenated = concat(applied, axis=self.axis, verify_integrity=False) @@ -1853,3 +1846,28 @@ def func(df): return self._python_apply_general(func, self._obj_with_exclusions) boxplot = boxplot_frame_groupby + + +def _wrap_transform_general_frame( + obj: DataFrame, group: DataFrame, res: DataFrame | Series +) -> DataFrame: + from pandas import concat + + if isinstance(res, Series): + # we need to broadcast across the + # other dimension; this will preserve dtypes + # GH14457 + if res.index.is_(obj.index): + res_frame = concat([res] * len(group.columns), axis=1) + res_frame.columns = group.columns + res_frame.index = group.index + else: + res_frame = obj._constructor( + np.concatenate([res.values] * len(group.index)).reshape(group.shape), + columns=group.columns, + index=group.index, + ) + assert isinstance(res_frame, DataFrame) + return res_frame + else: + return res