diff --git a/doc/source/whatsnew/v1.1.0.rst b/doc/source/whatsnew/v1.1.0.rst index 57dad5a080358..5cb4bf0ef6ef4 100644 --- a/doc/source/whatsnew/v1.1.0.rst +++ b/doc/source/whatsnew/v1.1.0.rst @@ -814,6 +814,8 @@ Groupby/resample/rolling - Bug in :meth:`GroupBy.first` and :meth:`GroupBy.last` where None is not preserved in object dtype (:issue:`32800`) - Bug in :meth:`Rolling.min` and :meth:`Rolling.max`: Growing memory usage after multiple calls when using a fixed window (:issue:`30726`) - Bug in :meth:`GroupBy.agg`, :meth:`GroupBy.transform`, and :meth:`GroupBy.resample` where subclasses are not preserved (:issue:`28330`) +- Bug in :meth:`GroupBy.rolling.apply` ignores args and kwargs parameters (:issue:`33433`) + Reshaping ^^^^^^^^^ diff --git a/pandas/core/window/rolling.py b/pandas/core/window/rolling.py index 166ab13344816..660fca61fd21c 100644 --- a/pandas/core/window/rolling.py +++ b/pandas/core/window/rolling.py @@ -1302,6 +1302,8 @@ def apply( use_numba_cache=engine == "numba", raw=raw, original_func=func, + args=args, + kwargs=kwargs, ) def _generate_cython_apply_func(self, args, kwargs, raw, offset, func): diff --git a/pandas/tests/window/test_apply.py b/pandas/tests/window/test_apply.py index 34cf0a3054889..bc38634da8941 100644 --- a/pandas/tests/window/test_apply.py +++ b/pandas/tests/window/test_apply.py @@ -4,7 +4,7 @@ from pandas.errors import NumbaUtilError import pandas.util._test_decorators as td -from pandas import DataFrame, Series, Timestamp, date_range +from pandas import DataFrame, Index, MultiIndex, Series, Timestamp, date_range import pandas._testing as tm @@ -139,3 +139,28 @@ def test_invalid_kwargs_nopython(): Series(range(1)).rolling(1).apply( lambda x: x, kwargs={"a": 1}, engine="numba", raw=True ) + + +@pytest.mark.parametrize("args_kwargs", [[None, {"par": 10}], [(10,), None]]) +def test_rolling_apply_args_kwargs(args_kwargs): + # GH 33433 + def foo(x, par): + return np.sum(x + par) + + df = DataFrame({"gr": [1, 1], "a": [1, 2]}) + + idx = Index(["gr", "a"]) + expected = DataFrame([[11.0, 11.0], [11.0, 12.0]], columns=idx) + + result = df.rolling(1).apply(foo, args=args_kwargs[0], kwargs=args_kwargs[1]) + tm.assert_frame_equal(result, expected) + + result = df.rolling(1).apply(foo, args=(10,)) + + midx = MultiIndex.from_tuples([(1, 0), (1, 1)], names=["gr", None]) + expected = Series([11.0, 12.0], index=midx, name="a") + + gb_rolling = df.groupby("gr")["a"].rolling(1) + + result = gb_rolling.apply(foo, args=args_kwargs[0], kwargs=args_kwargs[1]) + tm.assert_series_equal(result, expected)