Skip to content

BUG: Don't cache args during rolling/expanding.apply with numba engine #42350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v1.4.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ Plotting

Groupby/resample/rolling
^^^^^^^^^^^^^^^^^^^^^^^^
-
- Bug in :meth:`Series.rolling.apply`, :meth:`DataFrame.rolling.apply`, :meth:`Series.expanding.apply` and :meth:`DataFrame.expanding.apply` with ``engine="numba"`` where ``*args`` were being cached with the user-passed function (:issue:`42287`)
-

Reshaping
Expand Down
18 changes: 10 additions & 8 deletions pandas/core/window/numba_.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@


def generate_numba_apply_func(
args: tuple,
kwargs: dict[str, Any],
func: Callable[..., Scalar],
engine_kwargs: dict[str, bool] | None,
Expand All @@ -36,8 +35,6 @@ def generate_numba_apply_func(

Parameters
----------
args : tuple
*args to be passed into the function
kwargs : dict
**kwargs to be passed into the function
func : function
Expand All @@ -62,7 +59,11 @@ def generate_numba_apply_func(

@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
def roll_apply(
values: np.ndarray, begin: np.ndarray, end: np.ndarray, minimum_periods: int
values: np.ndarray,
begin: np.ndarray,
end: np.ndarray,
minimum_periods: int,
*args: Any,
) -> np.ndarray:
result = np.empty(len(begin))
for i in numba.prange(len(result)):
Expand Down Expand Up @@ -169,7 +170,6 @@ def ewma(


def generate_numba_table_func(
args: tuple,
kwargs: dict[str, Any],
func: Callable[..., np.ndarray],
engine_kwargs: dict[str, bool] | None,
Expand All @@ -187,8 +187,6 @@ def generate_numba_table_func(

Parameters
----------
args : tuple
*args to be passed into the function
kwargs : dict
**kwargs to be passed into the function
func : function
Expand All @@ -213,7 +211,11 @@ def generate_numba_table_func(

@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
def roll_table(
values: np.ndarray, begin: np.ndarray, end: np.ndarray, minimum_periods: int
values: np.ndarray,
begin: np.ndarray,
end: np.ndarray,
minimum_periods: int,
*args: Any,
):
result = np.empty(values.shape)
min_periods_mask = np.empty(values.shape)
Expand Down
17 changes: 14 additions & 3 deletions pandas/core/window/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,7 @@ def _apply(
func: Callable[..., Any],
name: str | None = None,
numba_cache_key: tuple[Callable, str] | None = None,
numba_args: tuple[Any, ...] = (),
**kwargs,
):
"""
Expand All @@ -495,6 +496,8 @@ def _apply(
name : str, optional
numba_cache_key : tuple
caching key to be used to store a compiled numba func
numba_args : tuple
args to be passed when func is a numba func
**kwargs
additional arguments for rolling function and window function

Expand Down Expand Up @@ -522,7 +525,7 @@ def calc(x):
center=self.center,
closed=self.closed,
)
return func(x, start, end, min_periods)
return func(x, start, end, min_periods, *numba_args)

with np.errstate(all="ignore"):
if values.ndim > 1 and self.method == "single":
Expand Down Expand Up @@ -583,12 +586,14 @@ def _apply(
func: Callable[..., Any],
name: str | None = None,
numba_cache_key: tuple[Callable, str] | None = None,
numba_args: tuple[Any, ...] = (),
**kwargs,
) -> FrameOrSeries:
result = super()._apply(
func,
name,
numba_cache_key,
numba_args,
**kwargs,
)
# Reconstruct the resulting MultiIndex
Expand Down Expand Up @@ -969,6 +974,7 @@ def _apply(
func: Callable[[np.ndarray, int, int], np.ndarray],
name: str | None = None,
numba_cache_key: tuple[Callable, str] | None = None,
numba_args: tuple[Any, ...] = (),
**kwargs,
):
"""
Expand All @@ -982,6 +988,8 @@ def _apply(
name : str,
use_numba_cache : tuple
unused
numba_args : tuple
unused
**kwargs
additional arguments for scipy windows if necessary

Expand Down Expand Up @@ -1159,18 +1167,20 @@ def apply(
raise ValueError("raw parameter must be `True` or `False`")

numba_cache_key = None
numba_args: tuple[Any, ...] = ()
if maybe_use_numba(engine):
if raw is False:
raise ValueError("raw must be `True` when using the numba engine")
caller_name = type(self).__name__
numba_args = args
if self.method == "single":
apply_func = generate_numba_apply_func(
args, kwargs, func, engine_kwargs, caller_name
kwargs, func, engine_kwargs, caller_name
)
numba_cache_key = (func, f"{caller_name}_apply_single")
else:
apply_func = generate_numba_table_func(
args, kwargs, func, engine_kwargs, f"{caller_name}_apply"
kwargs, func, engine_kwargs, f"{caller_name}_apply"
)
numba_cache_key = (func, f"{caller_name}_apply_table")
elif engine in ("cython", None):
Expand All @@ -1183,6 +1193,7 @@ def apply(
return self._apply(
apply_func,
numba_cache_key=numba_cache_key,
numba_args=numba_args,
)

def _generate_cython_apply_func(
Expand Down
6 changes: 6 additions & 0 deletions pandas/tests/window/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ def min_periods(request):
return request.param


@pytest.fixture(params=["single", "table"])
def method(request):
    """
    Value for the ``method`` keyword of the rolling/expanding/ewm constructor.

    Parametrized over "single" and "table".
    """
    selected_method = request.param
    return selected_method


@pytest.fixture(params=[True, False])
def parallel(request):
"""parallel keyword argument for numba.jit"""
Expand Down
28 changes: 28 additions & 0 deletions pandas/tests/window/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,34 @@ def func_2(x):
expected = roll.apply(func_1, engine="cython", raw=True)
tm.assert_series_equal(result, expected)

@pytest.mark.parametrize(
    "window,window_kwargs",
    [
        ["rolling", {"window": 3, "min_periods": 0}],
        ["expanding", {}],
    ],
)
def test_dont_cache_args(
    self, window, window_kwargs, nogil, parallel, nopython, method
):
    """
    ``args`` passed to ``apply`` with ``engine="numba"`` must not be cached.

    The same user function is applied twice with a different positional
    argument each time; the second result has to reflect the new argument
    rather than the one used on the first call.
    """
    # GH 42287

    def add(values, x):
        return np.sum(values) + x

    # Forward the fixtures to the window constructor and to numba so that
    # every parametrized combination (including method="table") is really
    # exercised instead of silently running the defaults each time.
    engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
    df = DataFrame({"value": [0, 0, 0]})

    result = getattr(df, window)(method=method, **window_kwargs).apply(
        add, raw=True, engine="numba", engine_kwargs=engine_kwargs, args=(1,)
    )
    expected = DataFrame({"value": [1.0, 1.0, 1.0]})
    tm.assert_frame_equal(result, expected)

    # Same function, different args: the result must change accordingly.
    result = getattr(df, window)(method=method, **window_kwargs).apply(
        add, raw=True, engine="numba", engine_kwargs=engine_kwargs, args=(2,)
    )
    expected = DataFrame({"value": [2.0, 2.0, 2.0]})
    tm.assert_frame_equal(result, expected)


@td.skip_if_no("numba", "0.46.0")
class TestEWMMean:
Expand Down