-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
/
Copy pathnumba_.py
114 lines (89 loc) · 3.01 KB
/
numba_.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""Common utilities for Numba operations"""
from __future__ import annotations
import types
from typing import Callable
import numpy as np
from pandas.compat._optional import import_optional_dependency
from pandas.errors import NumbaUtilError
from pandas.util.version import Version
# Process-wide default consulted by ``maybe_use_numba`` when the caller passes
# ``engine=None``; changed only through ``set_use_numba``.
GLOBAL_USE_NUMBA: bool = False
# Mapping from (callable, str) keys to callables. Judging by the name this caches
# JIT-compiled functions; NOTE(review): nothing in this chunk reads or writes it,
# and the meaning of the str part of the key (presumably the calling
# operation's name) should be confirmed against callers.
NUMBA_FUNC_CACHE: dict[tuple[Callable, str], Callable] = {}
def maybe_use_numba(engine: str | None) -> bool:
    """
    Decide whether numba-backed routines should be used.

    Parameters
    ----------
    engine : str or None
        Engine requested by the caller. ``"numba"`` always opts in, ``None``
        defers to the module-level ``GLOBAL_USE_NUMBA`` switch, and any other
        value opts out.

    Returns
    -------
    bool
    """
    if engine == "numba":
        return True
    if engine is None:
        # No explicit request: fall back to the global default.
        return GLOBAL_USE_NUMBA
    return False
def set_use_numba(enable: bool = False) -> None:
    """
    Set the process-wide default for using numba.

    Parameters
    ----------
    enable : bool, default False
        New value for ``GLOBAL_USE_NUMBA``. When truthy, numba is imported
        first via ``import_optional_dependency``; if that call raises, the
        flag is left unchanged.
    """
    global GLOBAL_USE_NUMBA
    if not enable:
        GLOBAL_USE_NUMBA = enable
        return
    # Confirm numba is importable before switching the default on.
    import_optional_dependency("numba")
    GLOBAL_USE_NUMBA = enable
def get_jit_arguments(
    engine_kwargs: dict[str, bool] | None = None, kwargs: dict | None = None
) -> tuple[bool, bool, bool]:
    """
    Resolve the options passed to numba.JIT, filling in pandas' defaults.

    Parameters
    ----------
    engine_kwargs : dict, default None
        User-supplied numba.JIT options. Recognised keys are ``"nopython"``
        (default True), ``"nogil"`` (default False) and ``"parallel"``
        (default False); other keys are ignored.
    kwargs : dict, default None
        User-supplied keyword arguments meant for the JITed function.

    Returns
    -------
    (bool, bool, bool)
        nopython, nogil, parallel

    Raises
    ------
    NumbaUtilError
        If ``kwargs`` is non-empty while nopython mode is in effect.
    """
    options = {} if engine_kwargs is None else engine_kwargs

    nopython = options.get("nopython", True)
    # Keyword arguments cannot be forwarded to a nopython-compiled function.
    if kwargs and nopython:
        raise NumbaUtilError(
            "numba does not support kwargs with nopython=True: "
            "https://github.com/numba/numba/issues/2916"
        )

    return (
        nopython,
        options.get("nogil", False),
        options.get("parallel", False),
    )
def jit_user_function(
func: Callable, nopython: bool, nogil: bool, parallel: bool
) -> Callable:
"""
JIT the user's function given the configurable arguments.
Parameters
----------
func : function
user defined function
nopython : bool
nopython parameter for numba.JIT
nogil : bool
nogil parameter for numba.JIT
parallel : bool
parallel parameter for numba.JIT
Returns
-------
function
Numba JITed function
"""
numba = import_optional_dependency("numba")
if Version(numba.__version__) >= Version("0.49.0"):
is_jitted = numba.extending.is_jitted(func)
else:
is_jitted = isinstance(func, numba.targets.registry.CPUDispatcher)
if is_jitted:
# Don't jit a user passed jitted function
numba_func = func
else:
@numba.generated_jit(nopython=nopython, nogil=nogil, parallel=parallel)
def numba_func(data, *_args):
if getattr(np, func.__name__, False) is func or isinstance(
func, types.BuiltinFunctionType
):
jf = func
else:
jf = numba.jit(func, nopython=nopython, nogil=nogil)
def impl(data, *_args):
return jf(data, *_args)
return impl
return numba_func