diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index 6a6abcf2d48fe..23bce184f6948 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -28,6 +28,7 @@ enhancement2 Other enhancements ^^^^^^^^^^^^^^^^^^ +- :class:`pandas.NamedAgg` now forwards any ``*args`` and ``**kwargs`` to calls of ``aggfunc`` (:issue:`58283`) - :class:`pandas.api.typing.FrozenList` is available for typing the outputs of :attr:`MultiIndex.names`, :attr:`MultiIndex.codes` and :attr:`MultiIndex.levels` (:issue:`58237`) - :class:`pandas.api.typing.SASReader` is available for typing the output of :func:`read_sas` (:issue:`55689`) - :func:`DataFrame.to_excel` now raises an ``UserWarning`` when the character count in a cell exceeds Excel's limitation of 32767 characters (:issue:`56954`) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index a20577e8d3df9..eeb29028cc100 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -108,7 +108,12 @@ ScalarResult = TypeVar("ScalarResult") -class NamedAgg(NamedTuple): +class _BaseNamedAgg(NamedTuple): + column: Hashable + aggfunc: AggScalar + + +class NamedAgg(_BaseNamedAgg): """ Helper for column specific aggregation with control over output column names. @@ -121,6 +126,10 @@ class NamedAgg(NamedTuple): aggfunc : function or str Function to apply to the provided column. If string, the name of a built-in pandas function. + *args : tuple, optional + Args passed to aggfunc. + **kwargs : dict, optional + Kwargs passed to aggfunc. Examples -------- @@ -134,8 +143,20 @@ class NamedAgg(NamedTuple): 2 1 12.0 """ - column: Hashable - aggfunc: AggScalar + def __new__(cls, column, aggfunc, *args, **kwargs): + if not isinstance(aggfunc, str): + aggfunc = cls._get_wrapped_aggfunc(aggfunc, *args, **kwargs) + self = _BaseNamedAgg.__new__(cls, column, aggfunc) + return self + + @staticmethod + def _get_wrapped_aggfunc(function, *initial_args, **initial_kwargs): + def wrapped_aggfunc(*new_args, **new_kwargs): + final_args = new_args + initial_args + final_kwargs = {**initial_kwargs, **new_kwargs} + return function(*final_args, **final_kwargs) + + return wrapped_aggfunc class SeriesGroupBy(GroupBy[Series]): diff --git a/pandas/tests/groupby/aggregate/test_aggregate.py b/pandas/tests/groupby/aggregate/test_aggregate.py index 3362d6209af6d..39357fa2b9547 100644 --- a/pandas/tests/groupby/aggregate/test_aggregate.py +++ b/pandas/tests/groupby/aggregate/test_aggregate.py @@ -827,6 +827,34 @@ def test_agg_namedtuple(self): expected = df.groupby("A").agg(b=("B", "sum"), c=("B", "count")) tm.assert_frame_equal(result, expected) + def test_single_named_agg_with_args_and_kwargs(self): + df = DataFrame({"A": [0, 1, 2, 3], "B": [1, 2, 3, 4]}) + + def n_between(ser, low, high): + return ser.between(low, high).sum() + + result = df.groupby("A").agg(n_between=pd.NamedAgg("B", n_between, 0, high=2)) + expected = df.groupby("A").agg(n_between=("B", lambda x: x.between(0, 2).sum())) + tm.assert_frame_equal(result, expected) + + def test_multiple_named_agg_with_args_and_kwargs(self): + df = DataFrame({"A": [0, 1, 2, 3], "B": [1, 2, 3, 4]}) + + def n_between(ser, low, high): + return ser.between(low, high).sum() + + result = df.groupby("A").agg( + n_between01=pd.NamedAgg("B", n_between, 0, 1), + n_between13=pd.NamedAgg("B", n_between, 1, 3), + n_between02=pd.NamedAgg("B", n_between, 0, 2), + ) + expected = df.groupby("A").agg( + n_between01=("B", lambda x: x.between(0, 1).sum()), + n_between13=("B", lambda x: x.between(0, 3).sum()), + n_between02=("B", lambda x: x.between(0, 2).sum()), + ) + tm.assert_frame_equal(result, expected) + def test_mangled(self): df = DataFrame({"A": [0, 1], "B": [1, 2], "C": [3, 4]}) result = df.groupby("A").agg(b=("B", lambda x: 0), c=("C", lambda x: 1))