Skip to content

Commit 9bc9ab2

Browse files
mproszewskasimonjayhawkins
authored andcommitted
Backport PR pandas-dev#33983 on branch 1.0.x (BUG: Use args and kwargs in Rolling.apply)
1 parent 6033909 commit 9bc9ab2

File tree

3 files changed

+29
-1
lines changed

3 files changed

+29
-1
lines changed

doc/source/whatsnew/v1.0.4.rst

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ Fixed regressions
2525
- Fix to preserve the ability to index with the "nearest" method with xarray's CFTimeIndex, an :class:`Index` subclass (`pydata/xarray#3751 <https://github.com/pydata/xarray/issues/3751>`_, :issue:`32905`).
2626
- Fix regression in :meth:`DataFrame.describe` raising ``TypeError: unhashable type: 'dict'`` (:issue:`32409`)
2727
- Bug in :meth:`DataFrame.replace` casts columns to ``object`` dtype if items in ``to_replace`` not in values (:issue:`32988`)
28+
- Bug in :meth:`GroupBy.rolling.apply` ignores args and kwargs parameters (:issue:`33433`)
2829
-
2930

3031
.. _whatsnew_104.bug_fixes:

pandas/core/window/rolling.py

+2
Original file line numberDiff line numberDiff line change
@@ -1304,6 +1304,8 @@ def apply(
13041304
name=func,
13051305
use_numba_cache=engine == "numba",
13061306
raw=raw,
1307+
args=args,
1308+
kwargs=kwargs,
13071309
)
13081310

13091311
def _generate_cython_apply_func(self, args, kwargs, raw, offset, func):

pandas/tests/window/test_apply.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import pandas.util._test_decorators as td
55

6-
from pandas import DataFrame, Series, Timestamp, date_range
6+
from pandas import DataFrame, Index, MultiIndex, Series, Timestamp, date_range
77
import pandas._testing as tm
88

99

@@ -138,3 +138,28 @@ def test_invalid_kwargs_nopython():
138138
Series(range(1)).rolling(1).apply(
139139
lambda x: x, kwargs={"a": 1}, engine="numba", raw=True
140140
)
141+
142+
143+
@pytest.mark.parametrize("args_kwargs", [[None, {"par": 10}], [(10,), None]])
144+
def test_rolling_apply_args_kwargs(args_kwargs):
145+
# GH 33433
146+
def foo(x, par):
147+
return np.sum(x + par)
148+
149+
df = DataFrame({"gr": [1, 1], "a": [1, 2]})
150+
151+
idx = Index(["gr", "a"])
152+
expected = DataFrame([[11.0, 11.0], [11.0, 12.0]], columns=idx)
153+
154+
result = df.rolling(1).apply(foo, args=args_kwargs[0], kwargs=args_kwargs[1])
155+
tm.assert_frame_equal(result, expected)
156+
157+
result = df.rolling(1).apply(foo, args=(10,))
158+
159+
midx = MultiIndex.from_tuples([(1, 0), (1, 1)], names=["gr", None])
160+
expected = Series([11.0, 12.0], index=midx, name="a")
161+
162+
gb_rolling = df.groupby("gr")["a"].rolling(1)
163+
164+
result = gb_rolling.apply(foo, args=args_kwargs[0], kwargs=args_kwargs[1])
165+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)