Skip to content

Commit d149f41

Browse files
authored
Create numba helper function for jitting + generating cache key (#33910)
1 parent 592db7d commit d149f41

File tree

3 files changed

+48
-26
lines changed

3 files changed

+48
-26
lines changed

pandas/core/groupby/generic.py

+5-16
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,8 @@
7878
from pandas.core.series import Series
7979
from pandas.core.util.numba_ import (
8080
NUMBA_FUNC_CACHE,
81-
check_kwargs_and_nopython,
82-
get_jit_arguments,
83-
jit_user_function,
81+
generate_numba_func,
8482
split_for_numba,
85-
validate_udf,
8683
)
8784

8885
from pandas.plotting import boxplot_frame_groupby
@@ -493,12 +490,8 @@ def _transform_general(
493490
"""
494491

495492
if engine == "numba":
496-
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
497-
check_kwargs_and_nopython(kwargs, nopython)
498-
validate_udf(func)
499-
cache_key = (func, "groupby_transform")
500-
numba_func = NUMBA_FUNC_CACHE.get(
501-
cache_key, jit_user_function(func, nopython, nogil, parallel)
493+
numba_func, cache_key = generate_numba_func(
494+
func, engine_kwargs, kwargs, "groupby_transform"
502495
)
503496

504497
klass = type(self._selected_obj)
@@ -1377,12 +1370,8 @@ def _transform_general(
13771370
obj = self._obj_with_exclusions
13781371
gen = self.grouper.get_iterator(obj, axis=self.axis)
13791372
if engine == "numba":
1380-
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
1381-
check_kwargs_and_nopython(kwargs, nopython)
1382-
validate_udf(func)
1383-
cache_key = (func, "groupby_transform")
1384-
numba_func = NUMBA_FUNC_CACHE.get(
1385-
cache_key, jit_user_function(func, nopython, nogil, parallel)
1373+
numba_func, cache_key = generate_numba_func(
1374+
func, engine_kwargs, kwargs, "groupby_transform"
13861375
)
13871376
else:
13881377
fast_path, slow_path = self._define_paths(func, *args, **kwargs)

pandas/core/groupby/ops.py

+3-10
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,8 @@
5656
)
5757
from pandas.core.util.numba_ import (
5858
NUMBA_FUNC_CACHE,
59-
check_kwargs_and_nopython,
60-
get_jit_arguments,
61-
jit_user_function,
59+
generate_numba_func,
6260
split_for_numba,
63-
validate_udf,
6461
)
6562

6663

@@ -689,12 +686,8 @@ def _aggregate_series_pure_python(
689686
):
690687

691688
if engine == "numba":
692-
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
693-
check_kwargs_and_nopython(kwargs, nopython)
694-
validate_udf(func)
695-
cache_key = (func, "groupby_agg")
696-
numba_func = NUMBA_FUNC_CACHE.get(
697-
cache_key, jit_user_function(func, nopython, nogil, parallel)
689+
numba_func, cache_key = generate_numba_func(
690+
func, engine_kwargs, kwargs, "groupby_agg"
698691
)
699692

700693
group_index, _, ngroups = self.group_info

pandas/core/util/numba_.py

+40
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,43 @@ def f(values, index, ...):
167167
f"The first {min_number_args} arguments to {func.__name__} must be "
168168
f"{expected_args}"
169169
)
170+
171+
172+
def generate_numba_func(
    func: Callable,
    engine_kwargs: Optional[Dict[str, bool]],
    kwargs: dict,
    cache_key_str: str,
) -> Tuple[Callable, Tuple[Callable, str]]:
    """
    Return a JITed function and cache key for the NUMBA_FUNC_CACHE

    This _may_ be specific to groupby (as it's only used there currently).

    The JITed function is looked up in NUMBA_FUNC_CACHE first; the user
    function is only passed to ``jit_user_function`` when no cached entry
    exists. This function does not write to NUMBA_FUNC_CACHE itself; the
    cache key is returned alongside the function (presumably so the caller
    can store the entry -- verify against the call sites).

    Parameters
    ----------
    func : function
        user defined function
    engine_kwargs : dict or None
        numba.jit arguments
    kwargs : dict
        kwargs for func
    cache_key_str : str
        string representing the second part of the cache key tuple

    Returns
    -------
    (JITed function, cache key)
        The cache key is the tuple ``(func, cache_key_str)``.

    Raises
    ------
    NumbaUtilError
    """
    nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
    # Validation runs on every call, even when the function is already cached.
    check_kwargs_and_nopython(kwargs, nopython)
    validate_udf(func)
    cache_key = (func, cache_key_str)
    # ``dict.get(key, default)`` evaluates ``default`` eagerly, which would
    # re-wrap ``func`` on every call even on a cache hit. Only JIT on a miss.
    numba_func = NUMBA_FUNC_CACHE.get(cache_key)
    if numba_func is None:
        numba_func = jit_user_function(func, nopython, nogil, parallel)
    return numba_func, cache_key

0 commit comments

Comments
 (0)