From 95966c93f066965b6f1f2a16c1e1d6837e4d84dd Mon Sep 17 00:00:00 2001
From: Nicholas Tan
Date: Sat, 6 Apr 2024 14:53:52 +1100
Subject: [PATCH] BUG: fix aggregation when using udf with kwarg

---
Review notes (free-form commentary placed between the "---" separator and
the diffstat; not part of the commit message or of any hunk):

* pandas/core/groupby/generic.py, aggregate(): the `args or kwargs` branch
  that calls `_aggregate_frame` now additionally requires
  `len(self._obj_with_exclusions.columns) == 1`. When args/kwargs are given
  but that condition is false, control reaches the `else` branch instead.
* That `else` branch ("try to treat as if we are passing a list") now
  forwards the caller's `args` and `kwargs` to `GroupByApply` rather than
  an empty tuple and an empty dict.
* pandas/tests/groupby/aggregate/test_aggregate.py: adds
  `test_aggregation_of_UDF_with_kwargs`, which groups a frame with two value
  columns ("A", "B") by "groupby1", passes `addition=5` through `.agg`, and
  expects each column's group mean plus 5.
* NOTE(review): this copy of the patch had its line breaks and indentation
  collapsed. The layout below (including the leading whitespace of context
  lines) was reconstructed -- confirm against the upstream commit before
  applying.

 pandas/core/groupby/generic.py                |  4 ++--
 .../tests/groupby/aggregate/test_aggregate.py | 21 +++++++++++++++++++
 2 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py
index 0a048d11d0b4d..61956ffcbb0a2 100644
--- a/pandas/core/groupby/generic.py
+++ b/pandas/core/groupby/generic.py
@@ -1546,14 +1546,14 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
             if self._grouper.nkeys > 1:
                 # test_groupby_as_index_series_scalar gets here with 'not self.as_index'
                 return self._python_agg_general(func, *args, **kwargs)
-            elif args or kwargs:
+            elif (args or kwargs) and (len(self._obj_with_exclusions.columns) == 1):
                 # test_pass_args_kwargs gets here (with and without as_index)
                 # can't return early
                 result = self._aggregate_frame(func, *args, **kwargs)
 
             else:
                 # try to treat as if we are passing a list
-                gba = GroupByApply(self, [func], args=(), kwargs={})
+                gba = GroupByApply(self, [func], args=args, kwargs=kwargs)
                 try:
                     result = gba.agg()
 
diff --git a/pandas/tests/groupby/aggregate/test_aggregate.py b/pandas/tests/groupby/aggregate/test_aggregate.py
index 2b9df1b7079da..42a13f6a769da 100644
--- a/pandas/tests/groupby/aggregate/test_aggregate.py
+++ b/pandas/tests/groupby/aggregate/test_aggregate.py
@@ -1663,3 +1663,24 @@ def func(x):
     msg = "length must not be 0"
     with pytest.raises(ValueError, match=msg):
         df.groupby("A", observed=False).agg(func)
+
+
+def test_aggregation_of_UDF_with_kwargs():
+    df = DataFrame(
+        {
+            "A": [1, 2, 3, 4, 5],
+            "B": [10, 20, 30, 40, 50],
+            "groupby1": ["diamond", "diamond", "spade", "spade", "spade"],
+        }
+    )
+
+    def user_defined_aggfunc(df, addition=0):
+        return np.mean(df) + addition
+
+    out_df = df.groupby(by="groupby1").agg(func=user_defined_aggfunc, addition=5)
+    expected_df = DataFrame(
+        {"A": [6.5, 9.0], "B": [20.0, 45.0]},
+        index=Index(["diamond", "spade"], name="groupby1"),
+    )
+
+    tm.assert_frame_equal(out_df, expected_df)