Skip to content

Commit 67e06ea

Browse files
authored
REF: Make numba function cache globally accessible (#33621)
1 parent 37022be commit 67e06ea

File tree

6 files changed

+25
-19
lines changed

6 files changed

+25
-19
lines changed

pandas/core/groupby/generic.py

+11 -12
Original file line number | Diff line number | Diff line change
@@ -76,6 +76,7 @@
7676
from pandas.core.internals import BlockManager, make_block
7777
from pandas.core.series import Series
7878
from pandas.core.util.numba_ import (
79+
NUMBA_FUNC_CACHE,
7980
check_kwargs_and_nopython,
8081
get_jit_arguments,
8182
jit_user_function,
@@ -161,8 +162,6 @@ def pinner(cls):
161162
class SeriesGroupBy(GroupBy[Series]):
162163
_apply_whitelist = base.series_apply_whitelist
163164

164-
_numba_func_cache: Dict[Callable, Callable] = {}
165-
166165
def _iterate_slices(self) -> Iterable[Series]:
167166
yield self._selected_obj
168167

@@ -504,8 +503,9 @@ def _transform_general(
504503
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
505504
check_kwargs_and_nopython(kwargs, nopython)
506505
validate_udf(func)
507-
numba_func = self._numba_func_cache.get(
508-
func, jit_user_function(func, nopython, nogil, parallel)
506+
cache_key = (func, "groupby_transform")
507+
numba_func = NUMBA_FUNC_CACHE.get(
508+
cache_key, jit_user_function(func, nopython, nogil, parallel)
509509
)
510510

511511
klass = type(self._selected_obj)
@@ -516,8 +516,8 @@ def _transform_general(
516516
if engine == "numba":
517517
values, index = split_for_numba(group)
518518
res = numba_func(values, index, *args)
519-
if func not in self._numba_func_cache:
520-
self._numba_func_cache[func] = numba_func
519+
if cache_key not in NUMBA_FUNC_CACHE:
520+
NUMBA_FUNC_CACHE[cache_key] = numba_func
521521
else:
522522
res = func(group, *args, **kwargs)
523523

@@ -847,8 +847,6 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
847847

848848
_apply_whitelist = base.dataframe_apply_whitelist
849849

850-
_numba_func_cache: Dict[Callable, Callable] = {}
851-
852850
_agg_see_also_doc = dedent(
853851
"""
854852
See Also
@@ -1397,8 +1395,9 @@ def _transform_general(
13971395
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
13981396
check_kwargs_and_nopython(kwargs, nopython)
13991397
validate_udf(func)
1400-
numba_func = self._numba_func_cache.get(
1401-
func, jit_user_function(func, nopython, nogil, parallel)
1398+
cache_key = (func, "groupby_transform")
1399+
numba_func = NUMBA_FUNC_CACHE.get(
1400+
cache_key, jit_user_function(func, nopython, nogil, parallel)
14021401
)
14031402
else:
14041403
fast_path, slow_path = self._define_paths(func, *args, **kwargs)
@@ -1409,8 +1408,8 @@ def _transform_general(
14091408
if engine == "numba":
14101409
values, index = split_for_numba(group)
14111410
res = numba_func(values, index, *args)
1412-
if func not in self._numba_func_cache:
1413-
self._numba_func_cache[func] = numba_func
1411+
if cache_key not in NUMBA_FUNC_CACHE:
1412+
NUMBA_FUNC_CACHE[cache_key] = numba_func
14141413
# Return the result as a DataFrame for concatenation later
14151414
res = DataFrame(res, index=group.index, columns=group.columns)
14161415
else:

pandas/core/util/numba_.py

+2
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,8 @@
88
from pandas._typing import FrameOrSeries
99
from pandas.compat._optional import import_optional_dependency
1010

11+
NUMBA_FUNC_CACHE: Dict[Tuple[Callable, str], Callable] = dict()
12+
1113

1214
def check_kwargs_and_nopython(
1315
kwargs: Optional[Dict] = None, nopython: Optional[bool] = None

pandas/core/window/common.py

+1
Original file line number | Diff line number | Diff line change
@@ -78,6 +78,7 @@ def _apply(
7878
performing the original function call on the grouped object.
7979
"""
8080
kwargs.pop("floor", None)
81+
kwargs.pop("original_func", None)
8182

8283
# TODO: can we de-duplicate with _dispatch?
8384
def f(x, name=name, *args):

pandas/core/window/rolling.py

+6 -4
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,7 @@
3838
from pandas.core.base import DataError, PandasObject, SelectionMixin, ShallowMixin
3939
import pandas.core.common as com
4040
from pandas.core.indexes.api import Index, ensure_index
41+
from pandas.core.util.numba_ import NUMBA_FUNC_CACHE
4142
from pandas.core.window.common import (
4243
WindowGroupByMixin,
4344
_doc_template,
@@ -93,7 +94,6 @@ def __init__(
9394
self.win_freq = None
9495
self.axis = obj._get_axis_number(axis) if axis is not None else None
9596
self.validate()
96-
self._numba_func_cache: Dict[Optional[str], Callable] = dict()
9797

9898
@property
9999
def _constructor(self):
@@ -505,7 +505,7 @@ def calc(x):
505505
result = np.asarray(result)
506506

507507
if use_numba_cache:
508-
self._numba_func_cache[name] = func
508+
NUMBA_FUNC_CACHE[(kwargs["original_func"], "rolling_apply")] = func
509509

510510
if center:
511511
result = self._center_window(result, window)
@@ -1279,9 +1279,10 @@ def apply(
12791279
elif engine == "numba":
12801280
if raw is False:
12811281
raise ValueError("raw must be `True` when using the numba engine")
1282-
if func in self._numba_func_cache:
1282+
cache_key = (func, "rolling_apply")
1283+
if cache_key in NUMBA_FUNC_CACHE:
12831284
# Return an already compiled version of roll_apply if available
1284-
apply_func = self._numba_func_cache[func]
1285+
apply_func = NUMBA_FUNC_CACHE[cache_key]
12851286
else:
12861287
apply_func = generate_numba_apply_func(
12871288
args, kwargs, func, engine_kwargs
@@ -1298,6 +1299,7 @@ def apply(
12981299
name=func,
12991300
use_numba_cache=engine == "numba",
13001301
raw=raw,
1302+
original_func=func,
13011303
)
13021304

13031305
def _generate_cython_apply_func(self, args, kwargs, raw, offset, func):

pandas/tests/groupby/transform/test_numba.py

+3 -2
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44

55
from pandas import DataFrame
66
import pandas._testing as tm
7+
from pandas.core.util.numba_ import NUMBA_FUNC_CACHE
78

89

910
@td.skip_if_no("numba", "0.46.0")
@@ -98,13 +99,13 @@ def func_2(values, index):
9899
expected = grouped.transform(lambda x: x + 1, engine="cython")
99100
tm.assert_equal(result, expected)
100101
# func_1 should be in the cache now
101-
assert func_1 in grouped._numba_func_cache
102+
assert (func_1, "groupby_transform") in NUMBA_FUNC_CACHE
102103

103104
# Add func_2 to the cache
104105
result = grouped.transform(func_2, engine="numba", engine_kwargs=engine_kwargs)
105106
expected = grouped.transform(lambda x: x * 5, engine="cython")
106107
tm.assert_equal(result, expected)
107-
assert func_2 in grouped._numba_func_cache
108+
assert (func_2, "groupby_transform") in NUMBA_FUNC_CACHE
108109

109110
# Retest func_1 which should use the cache
110111
result = grouped.transform(func_1, engine="numba", engine_kwargs=engine_kwargs)

pandas/tests/window/test_numba.py

+2 -1
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55

66
from pandas import Series
77
import pandas._testing as tm
8+
from pandas.core.util.numba_ import NUMBA_FUNC_CACHE
89

910

1011
@td.skip_if_no("numba", "0.46.0")
@@ -59,7 +60,7 @@ def func_2(x):
5960
tm.assert_series_equal(result, expected)
6061

6162
# func_1 should be in the cache now
62-
assert func_1 in roll._numba_func_cache
63+
assert (func_1, "rolling_apply") in NUMBA_FUNC_CACHE
6364

6465
result = roll.apply(
6566
func_2, engine="numba", engine_kwargs=engine_kwargs, raw=True

0 commit comments

Comments
 (0)