diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index c007d4920cbe7..504de404b2509 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -76,6 +76,7 @@ from pandas.core.internals import BlockManager, make_block from pandas.core.series import Series from pandas.core.util.numba_ import ( + NUMBA_FUNC_CACHE, check_kwargs_and_nopython, get_jit_arguments, jit_user_function, @@ -161,8 +162,6 @@ def pinner(cls): class SeriesGroupBy(GroupBy[Series]): _apply_whitelist = base.series_apply_whitelist - _numba_func_cache: Dict[Callable, Callable] = {} - def _iterate_slices(self) -> Iterable[Series]: yield self._selected_obj @@ -504,8 +503,9 @@ def _transform_general( nopython, nogil, parallel = get_jit_arguments(engine_kwargs) check_kwargs_and_nopython(kwargs, nopython) validate_udf(func) - numba_func = self._numba_func_cache.get( - func, jit_user_function(func, nopython, nogil, parallel) + cache_key = (func, "groupby_transform") + numba_func = NUMBA_FUNC_CACHE.get( + cache_key, jit_user_function(func, nopython, nogil, parallel) ) klass = type(self._selected_obj) @@ -516,8 +516,8 @@ def _transform_general( if engine == "numba": values, index = split_for_numba(group) res = numba_func(values, index, *args) - if func not in self._numba_func_cache: - self._numba_func_cache[func] = numba_func + if cache_key not in NUMBA_FUNC_CACHE: + NUMBA_FUNC_CACHE[cache_key] = numba_func else: res = func(group, *args, **kwargs) @@ -847,8 +847,6 @@ class DataFrameGroupBy(GroupBy[DataFrame]): _apply_whitelist = base.dataframe_apply_whitelist - _numba_func_cache: Dict[Callable, Callable] = {} - _agg_see_also_doc = dedent( """ See Also @@ -1397,8 +1395,9 @@ def _transform_general( nopython, nogil, parallel = get_jit_arguments(engine_kwargs) check_kwargs_and_nopython(kwargs, nopython) validate_udf(func) - numba_func = self._numba_func_cache.get( - func, jit_user_function(func, nopython, nogil, parallel) + cache_key = (func, "groupby_transform") + numba_func = NUMBA_FUNC_CACHE.get( + cache_key, jit_user_function(func, nopython, nogil, parallel) ) else: fast_path, slow_path = self._define_paths(func, *args, **kwargs) @@ -1409,8 +1408,8 @@ def _transform_general( if engine == "numba": values, index = split_for_numba(group) res = numba_func(values, index, *args) - if func not in self._numba_func_cache: - self._numba_func_cache[func] = numba_func + if cache_key not in NUMBA_FUNC_CACHE: + NUMBA_FUNC_CACHE[cache_key] = numba_func # Return the result as a DataFrame for concatenation later res = DataFrame(res, index=group.index, columns=group.columns) else: diff --git a/pandas/core/util/numba_.py b/pandas/core/util/numba_.py index c5b27b937a05b..af24189adbc27 100644 --- a/pandas/core/util/numba_.py +++ b/pandas/core/util/numba_.py @@ -8,6 +8,8 @@ from pandas._typing import FrameOrSeries from pandas.compat._optional import import_optional_dependency +NUMBA_FUNC_CACHE: Dict[Tuple[Callable, str], Callable] = dict() + def check_kwargs_and_nopython( kwargs: Optional[Dict] = None, nopython: Optional[bool] = None diff --git a/pandas/core/window/common.py b/pandas/core/window/common.py index 40f17126fa163..ebc67d0a0e819 100644 --- a/pandas/core/window/common.py +++ b/pandas/core/window/common.py @@ -78,6 +78,7 @@ def _apply( performing the original function call on the grouped object. """ kwargs.pop("floor", None) + kwargs.pop("original_func", None) # TODO: can we de-duplicate with _dispatch? def f(x, name=name, *args): diff --git a/pandas/core/window/rolling.py b/pandas/core/window/rolling.py index 3fdf81c4bb570..7dfc210eab901 100644 --- a/pandas/core/window/rolling.py +++ b/pandas/core/window/rolling.py @@ -38,6 +38,7 @@ from pandas.core.base import DataError, PandasObject, SelectionMixin, ShallowMixin import pandas.core.common as com from pandas.core.indexes.api import Index, ensure_index +from pandas.core.util.numba_ import NUMBA_FUNC_CACHE from pandas.core.window.common import ( WindowGroupByMixin, _doc_template, @@ -93,7 +94,6 @@ def __init__( self.win_freq = None self.axis = obj._get_axis_number(axis) if axis is not None else None self.validate() - self._numba_func_cache: Dict[Optional[str], Callable] = dict() @property def _constructor(self): @@ -505,7 +505,7 @@ def calc(x): result = np.asarray(result) if use_numba_cache: - self._numba_func_cache[name] = func + NUMBA_FUNC_CACHE[(kwargs["original_func"], "rolling_apply")] = func if center: result = self._center_window(result, window) @@ -1278,9 +1278,10 @@ def apply( elif engine == "numba": if raw is False: raise ValueError("raw must be `True` when using the numba engine") - if func in self._numba_func_cache: + cache_key = (func, "rolling_apply") + if cache_key in NUMBA_FUNC_CACHE: # Return an already compiled version of roll_apply if available - apply_func = self._numba_func_cache[func] + apply_func = NUMBA_FUNC_CACHE[cache_key] else: apply_func = generate_numba_apply_func( args, kwargs, func, engine_kwargs @@ -1297,6 +1298,7 @@ def apply( name=func, use_numba_cache=engine == "numba", raw=raw, + original_func=func, ) def _generate_cython_apply_func(self, args, kwargs, raw, offset, func): diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py index 96078d0aa3662..28904b669ae56 100644 --- a/pandas/tests/groupby/transform/test_numba.py +++ b/pandas/tests/groupby/transform/test_numba.py @@ -4,6 +4,7 @@ from pandas import DataFrame import pandas._testing as tm +from pandas.core.util.numba_ import NUMBA_FUNC_CACHE @td.skip_if_no("numba", "0.46.0") @@ -98,13 +99,13 @@ def func_2(values, index): expected = grouped.transform(lambda x: x + 1, engine="cython") tm.assert_equal(result, expected) # func_1 should be in the cache now - assert func_1 in grouped._numba_func_cache + assert (func_1, "groupby_transform") in NUMBA_FUNC_CACHE # Add func_2 to the cache result = grouped.transform(func_2, engine="numba", engine_kwargs=engine_kwargs) expected = grouped.transform(lambda x: x * 5, engine="cython") tm.assert_equal(result, expected) - assert func_2 in grouped._numba_func_cache + assert (func_2, "groupby_transform") in NUMBA_FUNC_CACHE # Retest func_1 which should use the cache result = grouped.transform(func_1, engine="numba", engine_kwargs=engine_kwargs) diff --git a/pandas/tests/window/test_numba.py b/pandas/tests/window/test_numba.py index cc8aef1779b46..8ecf64b171df4 100644 --- a/pandas/tests/window/test_numba.py +++ b/pandas/tests/window/test_numba.py @@ -5,6 +5,7 @@ from pandas import Series import pandas._testing as tm +from pandas.core.util.numba_ import NUMBA_FUNC_CACHE @td.skip_if_no("numba", "0.46.0") @@ -59,7 +60,7 @@ def func_2(x): tm.assert_series_equal(result, expected) # func_1 should be in the cache now - assert func_1 in roll._numba_func_cache + assert (func_1, "rolling_apply") in NUMBA_FUNC_CACHE result = roll.apply( func_2, engine="numba", engine_kwargs=engine_kwargs, raw=True