Skip to content

Commit 90f1901

Browse files
authored
BUG: Don't cache args during rolling/expanding.apply with numba engine (#42350)
1 parent 18d46b1 commit 90f1901

File tree

5 files changed

+59
-12
lines changed

5 files changed

+59
-12
lines changed

doc/source/whatsnew/v1.4.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ Plotting
189189

190190
Groupby/resample/rolling
191191
^^^^^^^^^^^^^^^^^^^^^^^^
192-
-
192+
- Bug in :meth:`Series.rolling.apply`, :meth:`DataFrame.rolling.apply`, :meth:`Series.expanding.apply` and :meth:`DataFrame.expanding.apply` with ``engine="numba"`` where ``*args`` were being cached with the user passed function (:issue:`42287`)
193193
-
194194

195195
Reshaping

pandas/core/window/numba_.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020

2121
def generate_numba_apply_func(
22-
args: tuple,
2322
kwargs: dict[str, Any],
2423
func: Callable[..., Scalar],
2524
engine_kwargs: dict[str, bool] | None,
@@ -36,8 +35,6 @@ def generate_numba_apply_func(
3635
3736
Parameters
3837
----------
39-
args : tuple
40-
*args to be passed into the function
4138
kwargs : dict
4239
**kwargs to be passed into the function
4340
func : function
@@ -62,7 +59,11 @@ def generate_numba_apply_func(
6259

6360
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
6461
def roll_apply(
65-
values: np.ndarray, begin: np.ndarray, end: np.ndarray, minimum_periods: int
62+
values: np.ndarray,
63+
begin: np.ndarray,
64+
end: np.ndarray,
65+
minimum_periods: int,
66+
*args: Any,
6667
) -> np.ndarray:
6768
result = np.empty(len(begin))
6869
for i in numba.prange(len(result)):
@@ -169,7 +170,6 @@ def ewma(
169170

170171

171172
def generate_numba_table_func(
172-
args: tuple,
173173
kwargs: dict[str, Any],
174174
func: Callable[..., np.ndarray],
175175
engine_kwargs: dict[str, bool] | None,
@@ -187,8 +187,6 @@ def generate_numba_table_func(
187187
188188
Parameters
189189
----------
190-
args : tuple
191-
*args to be passed into the function
192190
kwargs : dict
193191
**kwargs to be passed into the function
194192
func : function
@@ -213,7 +211,11 @@ def generate_numba_table_func(
213211

214212
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
215213
def roll_table(
216-
values: np.ndarray, begin: np.ndarray, end: np.ndarray, minimum_periods: int
214+
values: np.ndarray,
215+
begin: np.ndarray,
216+
end: np.ndarray,
217+
minimum_periods: int,
218+
*args: Any,
217219
):
218220
result = np.empty(values.shape)
219221
min_periods_mask = np.empty(values.shape)

pandas/core/window/rolling.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,7 @@ def _apply(
481481
func: Callable[..., Any],
482482
name: str | None = None,
483483
numba_cache_key: tuple[Callable, str] | None = None,
484+
numba_args: tuple[Any, ...] = (),
484485
**kwargs,
485486
):
486487
"""
@@ -494,6 +495,8 @@ def _apply(
494495
name : str,
495496
numba_cache_key : tuple
496497
caching key to be used to store a compiled numba func
498+
numba_args : tuple
499+
args to be passed when func is a numba func
497500
**kwargs
498501
additional arguments for rolling function and window function
499502
@@ -521,7 +524,7 @@ def calc(x):
521524
center=self.center,
522525
closed=self.closed,
523526
)
524-
return func(x, start, end, min_periods)
527+
return func(x, start, end, min_periods, *numba_args)
525528

526529
with np.errstate(all="ignore"):
527530
if values.ndim > 1 and self.method == "single":
@@ -582,12 +585,14 @@ def _apply(
582585
func: Callable[..., Any],
583586
name: str | None = None,
584587
numba_cache_key: tuple[Callable, str] | None = None,
588+
numba_args: tuple[Any, ...] = (),
585589
**kwargs,
586590
) -> FrameOrSeries:
587591
result = super()._apply(
588592
func,
589593
name,
590594
numba_cache_key,
595+
numba_args,
591596
**kwargs,
592597
)
593598
# Reconstruct the resulting MultiIndex
@@ -968,6 +973,7 @@ def _apply(
968973
func: Callable[[np.ndarray, int, int], np.ndarray],
969974
name: str | None = None,
970975
numba_cache_key: tuple[Callable, str] | None = None,
976+
numba_args: tuple[Any, ...] = (),
971977
**kwargs,
972978
):
973979
"""
@@ -981,6 +987,8 @@ def _apply(
981987
name : str,
982988
use_numba_cache : tuple
983989
unused
990+
numba_args : tuple
991+
unused
984992
**kwargs
985993
additional arguments for scipy windows if necessary
986994
@@ -1158,18 +1166,20 @@ def apply(
11581166
raise ValueError("raw parameter must be `True` or `False`")
11591167

11601168
numba_cache_key = None
1169+
numba_args: tuple[Any, ...] = ()
11611170
if maybe_use_numba(engine):
11621171
if raw is False:
11631172
raise ValueError("raw must be `True` when using the numba engine")
11641173
caller_name = type(self).__name__
1174+
numba_args = args
11651175
if self.method == "single":
11661176
apply_func = generate_numba_apply_func(
1167-
args, kwargs, func, engine_kwargs, caller_name
1177+
kwargs, func, engine_kwargs, caller_name
11681178
)
11691179
numba_cache_key = (func, f"{caller_name}_apply_single")
11701180
else:
11711181
apply_func = generate_numba_table_func(
1172-
args, kwargs, func, engine_kwargs, f"{caller_name}_apply"
1182+
kwargs, func, engine_kwargs, f"{caller_name}_apply"
11731183
)
11741184
numba_cache_key = (func, f"{caller_name}_apply_table")
11751185
elif engine in ("cython", None):
@@ -1182,6 +1192,7 @@ def apply(
11821192
return self._apply(
11831193
apply_func,
11841194
numba_cache_key=numba_cache_key,
1195+
numba_args=numba_args,
11851196
)
11861197

11871198
def _generate_cython_apply_func(

pandas/tests/window/conftest.py

+6
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,12 @@ def min_periods(request):
8484
return request.param
8585

8686

87+
@pytest.fixture(params=["single", "table"])
88+
def method(request):
89+
"""method keyword in rolling/expanding/ewm constructor"""
90+
return request.param
91+
92+
8793
@pytest.fixture(params=[True, False])
8894
def parallel(request):
8995
"""parallel keyword argument for numba.jit"""

pandas/tests/window/test_numba.py

+28
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,34 @@ def func_2(x):
121121
expected = roll.apply(func_1, engine="cython", raw=True)
122122
tm.assert_series_equal(result, expected)
123123

124+
@pytest.mark.parametrize(
125+
"window,window_kwargs",
126+
[
127+
["rolling", {"window": 3, "min_periods": 0}],
128+
["expanding", {}],
129+
],
130+
)
131+
def test_dont_cache_args(
132+
self, window, window_kwargs, nogil, parallel, nopython, method
133+
):
134+
# GH 42287
135+
136+
def add(values, x):
137+
return np.sum(values) + x
138+
139+
df = DataFrame({"value": [0, 0, 0]})
140+
result = getattr(df, window)(**window_kwargs).apply(
141+
add, raw=True, engine="numba", args=(1,)
142+
)
143+
expected = DataFrame({"value": [1.0, 1.0, 1.0]})
144+
tm.assert_frame_equal(result, expected)
145+
146+
result = getattr(df, window)(**window_kwargs).apply(
147+
add, raw=True, engine="numba", args=(2,)
148+
)
149+
expected = DataFrame({"value": [2.0, 2.0, 2.0]})
150+
tm.assert_frame_equal(result, expected)
151+
124152

125153
@td.skip_if_no("numba", "0.46.0")
126154
class TestEWMMean:

0 commit comments

Comments
 (0)