|
4 | 4 | from pandas.errors import NumbaUtilError
|
5 | 5 | import pandas.util._test_decorators as td
|
6 | 6 |
|
7 |
| -from pandas import DataFrame, Series, Timestamp, date_range |
| 7 | +from pandas import DataFrame, Index, MultiIndex, Series, Timestamp, date_range |
8 | 8 | import pandas._testing as tm
|
9 | 9 |
|
10 | 10 |
|
@@ -139,3 +139,28 @@ def test_invalid_kwargs_nopython():
|
139 | 139 | Series(range(1)).rolling(1).apply(
|
140 | 140 | lambda x: x, kwargs={"a": 1}, engine="numba", raw=True
|
141 | 141 | )
|
| 142 | + |
| 143 | + |
| 144 | +@pytest.mark.parametrize("args_kwargs", [[None, {"par": 10}], [(10,), None]]) |
| 145 | +def test_rolling_apply_args_kwargs(args_kwargs): |
| 146 | + # GH 33433 |
| 147 | + def foo(x, par): |
| 148 | + return np.sum(x + par) |
| 149 | + |
| 150 | + df = DataFrame({"gr": [1, 1], "a": [1, 2]}) |
| 151 | + |
| 152 | + idx = Index(["gr", "a"]) |
| 153 | + expected = DataFrame([[11.0, 11.0], [11.0, 12.0]], columns=idx) |
| 154 | + |
| 155 | + result = df.rolling(1).apply(foo, args=args_kwargs[0], kwargs=args_kwargs[1]) |
| 156 | + tm.assert_frame_equal(result, expected) |
| 157 | + |
| 158 | + result = df.rolling(1).apply(foo, args=(10,)) |
| 159 | + |
| 160 | + midx = MultiIndex.from_tuples([(1, 0), (1, 1)], names=["gr", None]) |
| 161 | + expected = Series([11.0, 12.0], index=midx, name="a") |
| 162 | + |
| 163 | + gb_rolling = df.groupby("gr")["a"].rolling(1) |
| 164 | + |
| 165 | + result = gb_rolling.apply(foo, args=args_kwargs[0], kwargs=args_kwargs[1]) |
| 166 | + tm.assert_series_equal(result, expected) |
0 commit comments