diff --git a/doc/source/whatsnew/v1.4.0.rst b/doc/source/whatsnew/v1.4.0.rst index 24f307f23f435..06f89ba62c3e4 100644 --- a/doc/source/whatsnew/v1.4.0.rst +++ b/doc/source/whatsnew/v1.4.0.rst @@ -189,7 +189,7 @@ Plotting Groupby/resample/rolling ^^^^^^^^^^^^^^^^^^^^^^^^ -- +- Bug in :meth:`Series.rolling.apply`, :meth:`DataFrame.rolling.apply`, :meth:`Series.expanding.apply` and :meth:`DataFrame.expanding.apply` with ``engine="numba"`` where ``*args`` were being cached with the user passed function (:issue:`42287`) - Reshaping diff --git a/pandas/core/window/numba_.py b/pandas/core/window/numba_.py index d00be0ea840a8..9d9376e8ba43d 100644 --- a/pandas/core/window/numba_.py +++ b/pandas/core/window/numba_.py @@ -19,7 +19,6 @@ def generate_numba_apply_func( - args: tuple, kwargs: dict[str, Any], func: Callable[..., Scalar], engine_kwargs: dict[str, bool] | None, @@ -36,8 +35,6 @@ def generate_numba_apply_func( Parameters ---------- - args : tuple - *args to be passed into the function kwargs : dict **kwargs to be passed into the function func : function @@ -62,7 +59,11 @@ def generate_numba_apply_func( @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) def roll_apply( - values: np.ndarray, begin: np.ndarray, end: np.ndarray, minimum_periods: int + values: np.ndarray, + begin: np.ndarray, + end: np.ndarray, + minimum_periods: int, + *args: Any, ) -> np.ndarray: result = np.empty(len(begin)) for i in numba.prange(len(result)): @@ -169,7 +170,6 @@ def ewma( def generate_numba_table_func( - args: tuple, kwargs: dict[str, Any], func: Callable[..., np.ndarray], engine_kwargs: dict[str, bool] | None, @@ -187,8 +187,6 @@ def generate_numba_table_func( Parameters ---------- - args : tuple - *args to be passed into the function kwargs : dict **kwargs to be passed into the function func : function @@ -213,7 +211,11 @@ def generate_numba_table_func( @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) def roll_table( - values: np.ndarray, begin: np.ndarray, end: np.ndarray, minimum_periods: int + values: np.ndarray, + begin: np.ndarray, + end: np.ndarray, + minimum_periods: int, + *args: Any, ): result = np.empty(values.shape) min_periods_mask = np.empty(values.shape) diff --git a/pandas/core/window/rolling.py b/pandas/core/window/rolling.py index 2d5f148a6437a..44784fbc95acd 100644 --- a/pandas/core/window/rolling.py +++ b/pandas/core/window/rolling.py @@ -482,6 +482,7 @@ def _apply( func: Callable[..., Any], name: str | None = None, numba_cache_key: tuple[Callable, str] | None = None, + numba_args: tuple[Any, ...] = (), **kwargs, ): """ @@ -495,6 +496,8 @@ def _apply( name : str, numba_cache_key : tuple caching key to be used to store a compiled numba func + numba_args : tuple + args to be passed when func is a numba func **kwargs additional arguments for rolling function and window function @@ -522,7 +525,7 @@ def calc(x): center=self.center, closed=self.closed, ) - return func(x, start, end, min_periods) + return func(x, start, end, min_periods, *numba_args) with np.errstate(all="ignore"): if values.ndim > 1 and self.method == "single": @@ -583,12 +586,14 @@ def _apply( func: Callable[..., Any], name: str | None = None, numba_cache_key: tuple[Callable, str] | None = None, + numba_args: tuple[Any, ...] = (), **kwargs, ) -> FrameOrSeries: result = super()._apply( func, name, numba_cache_key, + numba_args, **kwargs, ) # Reconstruct the resulting MultiIndex @@ -969,6 +974,7 @@ def _apply( func: Callable[[np.ndarray, int, int], np.ndarray], name: str | None = None, numba_cache_key: tuple[Callable, str] | None = None, + numba_args: tuple[Any, ...] = (), **kwargs, ): """ @@ -982,6 +988,8 @@ def _apply( name : str, use_numba_cache : tuple unused + numba_args : tuple + unused **kwargs additional arguments for scipy windows if necessary @@ -1159,18 +1167,20 @@ def apply( raise ValueError("raw parameter must be `True` or `False`") numba_cache_key = None + numba_args: tuple[Any, ...] = () if maybe_use_numba(engine): if raw is False: raise ValueError("raw must be `True` when using the numba engine") caller_name = type(self).__name__ + numba_args = args if self.method == "single": apply_func = generate_numba_apply_func( - args, kwargs, func, engine_kwargs, caller_name + kwargs, func, engine_kwargs, caller_name ) numba_cache_key = (func, f"{caller_name}_apply_single") else: apply_func = generate_numba_table_func( - args, kwargs, func, engine_kwargs, f"{caller_name}_apply" + kwargs, func, engine_kwargs, f"{caller_name}_apply" ) numba_cache_key = (func, f"{caller_name}_apply_table") elif engine in ("cython", None): @@ -1183,6 +1193,7 @@ def apply( return self._apply( apply_func, numba_cache_key=numba_cache_key, + numba_args=numba_args, ) def _generate_cython_apply_func( diff --git a/pandas/tests/window/conftest.py b/pandas/tests/window/conftest.py index 5382f5f9202c0..30073bd55531f 100644 --- a/pandas/tests/window/conftest.py +++ b/pandas/tests/window/conftest.py @@ -84,6 +84,12 @@ def min_periods(request): return request.param +@pytest.fixture(params=["single", "table"]) +def method(request): + """method keyword in rolling/expanding/ewm constructor""" + return request.param + + @pytest.fixture(params=[True, False]) def parallel(request): """parallel keyword argument for numba.jit""" diff --git a/pandas/tests/window/test_numba.py b/pandas/tests/window/test_numba.py index b79c367d482ae..5bc27436fd1d7 100644 --- a/pandas/tests/window/test_numba.py +++ b/pandas/tests/window/test_numba.py @@ -121,6 +121,34 @@ def func_2(x): expected = roll.apply(func_1, engine="cython", raw=True) tm.assert_series_equal(result, expected) + @pytest.mark.parametrize( + "window,window_kwargs", + [ + ["rolling", {"window": 3, "min_periods": 0}], + ["expanding", {}], + ], + ) + def test_dont_cache_args( + self, window, window_kwargs, nogil, parallel, nopython, method + ): + # GH 42287 + + def add(values, x): + return np.sum(values) + x + + df = DataFrame({"value": [0, 0, 0]}) + result = getattr(df, window)(**window_kwargs).apply( + add, raw=True, engine="numba", args=(1,) + ) + expected = DataFrame({"value": [1.0, 1.0, 1.0]}) + tm.assert_frame_equal(result, expected) + + result = getattr(df, window)(**window_kwargs).apply( + add, raw=True, engine="numba", args=(2,) + ) + expected = DataFrame({"value": [2.0, 2.0, 2.0]}) + tm.assert_frame_equal(result, expected) + @td.skip_if_no("numba", "0.46.0") class TestEWMMean: