Skip to content

Commit e66e5c3

Browse files
authored
BUG: Use args and kwargs in Rolling.apply (#33983)
1 parent 2b67c65 commit e66e5c3

File tree

3 files changed

+30
-1
lines changed

3 files changed

+30
-1
lines changed

doc/source/whatsnew/v1.1.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,8 @@ Groupby/resample/rolling
816816
- Bug in :meth:`GroupBy.first` and :meth:`GroupBy.last` where None is not preserved in object dtype (:issue:`32800`)
817817
- Bug in :meth:`Rolling.min` and :meth:`Rolling.max`: Growing memory usage after multiple calls when using a fixed window (:issue:`30726`)
818818
- Bug in :meth:`GroupBy.agg`, :meth:`GroupBy.transform`, and :meth:`GroupBy.resample` where subclasses are not preserved (:issue:`28330`)
819+
- Bug in :meth:`GroupBy.rolling.apply` ignores args and kwargs parameters (:issue:`33433`)
820+
819821

820822
Reshaping
821823
^^^^^^^^^

pandas/core/window/rolling.py

+2
Original file line numberDiff line numberDiff line change
@@ -1302,6 +1302,8 @@ def apply(
13021302
use_numba_cache=engine == "numba",
13031303
raw=raw,
13041304
original_func=func,
1305+
args=args,
1306+
kwargs=kwargs,
13051307
)
13061308

13071309
def _generate_cython_apply_func(self, args, kwargs, raw, offset, func):

pandas/tests/window/test_apply.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from pandas.errors import NumbaUtilError
55
import pandas.util._test_decorators as td
66

7-
from pandas import DataFrame, Series, Timestamp, date_range
7+
from pandas import DataFrame, Index, MultiIndex, Series, Timestamp, date_range
88
import pandas._testing as tm
99

1010

@@ -139,3 +139,28 @@ def test_invalid_kwargs_nopython():
139139
Series(range(1)).rolling(1).apply(
140140
lambda x: x, kwargs={"a": 1}, engine="numba", raw=True
141141
)
142+
143+
144+
@pytest.mark.parametrize("args_kwargs", [[None, {"par": 10}], [(10,), None]])
145+
def test_rolling_apply_args_kwargs(args_kwargs):
146+
# GH 33433
147+
def foo(x, par):
148+
return np.sum(x + par)
149+
150+
df = DataFrame({"gr": [1, 1], "a": [1, 2]})
151+
152+
idx = Index(["gr", "a"])
153+
expected = DataFrame([[11.0, 11.0], [11.0, 12.0]], columns=idx)
154+
155+
result = df.rolling(1).apply(foo, args=args_kwargs[0], kwargs=args_kwargs[1])
156+
tm.assert_frame_equal(result, expected)
157+
158+
result = df.rolling(1).apply(foo, args=(10,))
159+
160+
midx = MultiIndex.from_tuples([(1, 0), (1, 1)], names=["gr", None])
161+
expected = Series([11.0, 12.0], index=midx, name="a")
162+
163+
gb_rolling = df.groupby("gr")["a"].rolling(1)
164+
165+
result = gb_rolling.apply(foo, args=args_kwargs[0], kwargs=args_kwargs[1])
166+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)