Skip to content

Commit cfbd7f6

Browse files
authored
CLN: Consolidate numba facilities (#32770)
1 parent 91512a8 commit cfbd7f6

File tree

2 files changed

+87
-81
lines changed

2 files changed

+87
-81
lines changed

pandas/core/util/numba_.py

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""Common utilities for Numba operations"""
2+
import types
3+
from typing import Callable, Dict, Optional
4+
5+
import numpy as np
6+
7+
from pandas.compat._optional import import_optional_dependency
8+
9+
10+
def check_kwargs_and_nopython(
11+
kwargs: Optional[Dict] = None, nopython: Optional[bool] = None
12+
):
13+
if kwargs and nopython:
14+
raise ValueError(
15+
"numba does not support kwargs with nopython=True: "
16+
"https://github.com/numba/numba/issues/2916"
17+
)
18+
19+
20+
def get_jit_arguments(engine_kwargs: Optional[Dict[str, bool]] = None):
21+
"""
22+
Return arguments to pass to numba.JIT, falling back on pandas default JIT settings.
23+
"""
24+
if engine_kwargs is None:
25+
engine_kwargs = {}
26+
27+
nopython = engine_kwargs.get("nopython", True)
28+
nogil = engine_kwargs.get("nogil", False)
29+
parallel = engine_kwargs.get("parallel", False)
30+
return nopython, nogil, parallel
31+
32+
33+
def jit_user_function(func: Callable, nopython: bool, nogil: bool, parallel: bool):
34+
"""
35+
JIT the user's function given the configurable arguments.
36+
"""
37+
numba = import_optional_dependency("numba")
38+
39+
if isinstance(func, numba.targets.registry.CPUDispatcher):
40+
# Don't jit a user passed jitted function
41+
numba_func = func
42+
else:
43+
44+
@numba.generated_jit(nopython=nopython, nogil=nogil, parallel=parallel)
45+
def numba_func(data, *_args):
46+
if getattr(np, func.__name__, False) is func or isinstance(
47+
func, types.BuiltinFunctionType
48+
):
49+
jf = func
50+
else:
51+
jf = numba.jit(func, nopython=nopython, nogil=nogil)
52+
53+
def impl(data, *_args):
54+
return jf(data, *_args)
55+
56+
return impl
57+
58+
return numba_func

pandas/core/window/numba_.py

+29-81
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,60 @@
1-
import types
21
from typing import Any, Callable, Dict, Optional, Tuple
32

43
import numpy as np
54

65
from pandas._typing import Scalar
76
from pandas.compat._optional import import_optional_dependency
87

8+
from pandas.core.util.numba_ import (
9+
check_kwargs_and_nopython,
10+
get_jit_arguments,
11+
jit_user_function,
12+
)
913

10-
def make_rolling_apply(
11-
func: Callable[..., Scalar],
14+
15+
def generate_numba_apply_func(
1216
args: Tuple,
13-
nogil: bool,
14-
parallel: bool,
15-
nopython: bool,
17+
kwargs: Dict[str, Any],
18+
func: Callable[..., Scalar],
19+
engine_kwargs: Optional[Dict[str, bool]],
1620
):
1721
"""
18-
Creates a JITted rolling apply function with a JITted version of
19-
the user's function.
22+
Generate a numba jitted apply function specified by values from engine_kwargs.
23+
24+
1. jit the user's function
25+
2. Return a rolling apply function with the jitted function inline
26+
27+
Configurations specified in engine_kwargs apply to both the user's
28+
function _AND_ the rolling apply function.
2029
2130
Parameters
2231
----------
23-
func : function
24-
function to be applied to each window and will be JITed
2532
args : tuple
2633
*args to be passed into the function
27-
nogil : bool
28-
nogil parameter from engine_kwargs for numba.jit
29-
parallel : bool
30-
parallel parameter from engine_kwargs for numba.jit
31-
nopython : bool
32-
nopython parameter from engine_kwargs for numba.jit
34+
kwargs : dict
35+
**kwargs to be passed into the function
36+
func : function
37+
function to be applied to each window and will be JITed
38+
engine_kwargs : dict
39+
dictionary of arguments to be passed into numba.jit
3340
3441
Returns
3542
-------
3643
Numba function
3744
"""
45+
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
46+
47+
check_kwargs_and_nopython(kwargs, nopython)
48+
49+
numba_func = jit_user_function(func, nopython, nogil, parallel)
50+
3851
numba = import_optional_dependency("numba")
3952

4053
if parallel:
4154
loop_range = numba.prange
4255
else:
4356
loop_range = range
4457

45-
if isinstance(func, numba.targets.registry.CPUDispatcher):
46-
# Don't jit a user passed jitted function
47-
numba_func = func
48-
else:
49-
50-
@numba.generated_jit(nopython=nopython, nogil=nogil, parallel=parallel)
51-
def numba_func(window, *_args):
52-
if getattr(np, func.__name__, False) is func or isinstance(
53-
func, types.BuiltinFunctionType
54-
):
55-
jf = func
56-
else:
57-
jf = numba.jit(func, nopython=nopython, nogil=nogil)
58-
59-
def impl(window, *_args):
60-
return jf(window, *_args)
61-
62-
return impl
63-
6458
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
6559
def roll_apply(
6660
values: np.ndarray, begin: np.ndarray, end: np.ndarray, minimum_periods: int,
@@ -78,49 +72,3 @@ def roll_apply(
7872
return result
7973

8074
return roll_apply
81-
82-
83-
def generate_numba_apply_func(
84-
args: Tuple,
85-
kwargs: Dict[str, Any],
86-
func: Callable[..., Scalar],
87-
engine_kwargs: Optional[Dict[str, bool]],
88-
):
89-
"""
90-
Generate a numba jitted apply function specified by values from engine_kwargs.
91-
92-
1. jit the user's function
93-
2. Return a rolling apply function with the jitted function inline
94-
95-
Configurations specified in engine_kwargs apply to both the user's
96-
function _AND_ the rolling apply function.
97-
98-
Parameters
99-
----------
100-
args : tuple
101-
*args to be passed into the function
102-
kwargs : dict
103-
**kwargs to be passed into the function
104-
func : function
105-
function to be applied to each window and will be JITed
106-
engine_kwargs : dict
107-
dictionary of arguments to be passed into numba.jit
108-
109-
Returns
110-
-------
111-
Numba function
112-
"""
113-
if engine_kwargs is None:
114-
engine_kwargs = {}
115-
116-
nopython = engine_kwargs.get("nopython", True)
117-
nogil = engine_kwargs.get("nogil", False)
118-
parallel = engine_kwargs.get("parallel", False)
119-
120-
if kwargs and nopython:
121-
raise ValueError(
122-
"numba does not support kwargs with nopython=True: "
123-
"https://github.com/numba/numba/issues/2916"
124-
)
125-
126-
return make_rolling_apply(func, args, nogil, parallel, nopython)

0 commit comments

Comments
 (0)