diff --git a/doc/source/whatsnew/v1.1.0.rst b/doc/source/whatsnew/v1.1.0.rst index c8e811ce82b1f..ad58232c81b23 100644 --- a/doc/source/whatsnew/v1.1.0.rst +++ b/doc/source/whatsnew/v1.1.0.rst @@ -169,6 +169,7 @@ Groupby/resample/rolling ^^^^^^^^^^^^^^^^^^^^^^^^ - Bug in :meth:`GroupBy.apply` raises ``ValueError`` when the ``by`` axis is not sorted and has duplicates and the applied ``func`` does not mutate passed in objects (:issue:`30667`) +- Bug in :meth:`DataFrameGroupby.transform` produces incorrect result with transformation functions (:issue:`30918`) Reshaping ^^^^^^^^^ diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 98cdcd0f2b6ee..27dd6e953c219 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1416,22 +1416,20 @@ def transform(self, func, *args, **kwargs): # cythonized transformation or canned "reduction+broadcast" return getattr(self, func)(*args, **kwargs) - # If func is a reduction, we need to broadcast the - # result to the whole group. Compute func result - # and deal with possible broadcasting below. - result = getattr(self, func)(*args, **kwargs) - - # a reduction transform - if not isinstance(result, DataFrame): - return self._transform_general(func, *args, **kwargs) - - obj = self._obj_with_exclusions - - # nuisance columns - if not result.columns.equals(obj.columns): - return self._transform_general(func, *args, **kwargs) - - return self._transform_fast(result, func) + # GH 30918 + # Use _transform_fast only when we know func is an aggregation + if func in base.reduction_kernels: + # If func is a reduction, we need to broadcast the + # result to the whole group. Compute func result + # and deal with possible broadcasting below. + result = getattr(self, func)(*args, **kwargs) + + if isinstance(result, DataFrame) and result.columns.equals( + self._obj_with_exclusions.columns + ): + return self._transform_fast(result, func) + + return self._transform_general(func, *args, **kwargs) def _transform_fast(self, result: DataFrame, func_nm: str) -> DataFrame: """ diff --git a/pandas/tests/groupby/test_transform.py b/pandas/tests/groupby/test_transform.py index 6c05c4038a829..8967ef06f50fb 100644 --- a/pandas/tests/groupby/test_transform.py +++ b/pandas/tests/groupby/test_transform.py @@ -317,6 +317,32 @@ def test_dispatch_transform(tsframe): tm.assert_frame_equal(filled, expected) +def test_transform_transformation_func(transformation_func): + # GH 30918 + df = DataFrame( + { + "A": ["foo", "foo", "foo", "foo", "bar", "bar", "baz"], + "B": [1, 2, np.nan, 3, 3, np.nan, 4], + } + ) + + if transformation_func in ["pad", "backfill", "tshift", "corrwith", "cumcount"]: + # These transformation functions are not yet covered in this test + pytest.xfail("See GH 31269 and GH 31270") + elif transformation_func == "fillna": + test_op = lambda x: x.transform("fillna", value=0) + mock_op = lambda x: x.fillna(value=0) + else: + test_op = lambda x: x.transform(transformation_func) + mock_op = lambda x: getattr(x, transformation_func)() + + result = test_op(df.groupby("A")) + groups = [df[["B"]].iloc[:4], df[["B"]].iloc[4:6], df[["B"]].iloc[6:]] + expected = concat([mock_op(g) for g in groups]) + + tm.assert_frame_equal(result, expected) + + def test_transform_select_columns(df): f = lambda x: x.mean() result = df.groupby("A")[["C", "D"]].transform(f)