|
3 | 3 |
|
4 | 4 | import pandas.util._test_decorators as td
|
5 | 5 |
|
6 |
| -from pandas import DataFrame, Series, Timestamp, date_range |
| 6 | +from pandas import DataFrame, Index, MultiIndex, Series, Timestamp, date_range |
7 | 7 | import pandas._testing as tm
|
8 | 8 |
|
9 | 9 |
|
@@ -138,3 +138,28 @@ def test_invalid_kwargs_nopython():
|
138 | 138 | Series(range(1)).rolling(1).apply(
|
139 | 139 | lambda x: x, kwargs={"a": 1}, engine="numba", raw=True
|
140 | 140 | )
|
| 141 | + |
| 142 | + |
| 143 | +@pytest.mark.parametrize("args_kwargs", [[None, {"par": 10}], [(10,), None]]) |
| 144 | +def test_rolling_apply_args_kwargs(args_kwargs): |
| 145 | + # GH 33433 |
| 146 | + def foo(x, par): |
| 147 | + return np.sum(x + par) |
| 148 | + |
| 149 | + df = DataFrame({"gr": [1, 1], "a": [1, 2]}) |
| 150 | + |
| 151 | + idx = Index(["gr", "a"]) |
| 152 | + expected = DataFrame([[11.0, 11.0], [11.0, 12.0]], columns=idx) |
| 153 | + |
| 154 | + result = df.rolling(1).apply(foo, args=args_kwargs[0], kwargs=args_kwargs[1]) |
| 155 | + tm.assert_frame_equal(result, expected) |
| 156 | + |
| 157 | + result = df.rolling(1).apply(foo, args=(10,)) |
| 158 | + |
| 159 | + midx = MultiIndex.from_tuples([(1, 0), (1, 1)], names=["gr", None]) |
| 160 | + expected = Series([11.0, 12.0], index=midx, name="a") |
| 161 | + |
| 162 | + gb_rolling = df.groupby("gr")["a"].rolling(1) |
| 163 | + |
| 164 | + result = gb_rolling.apply(foo, args=args_kwargs[0], kwargs=args_kwargs[1]) |
| 165 | + tm.assert_series_equal(result, expected) |
0 commit comments