forked from pandas-dev/pandas
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathnumba_.py
111 lines (88 loc) · 3.02 KB
/
numba_.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""Common utilities for Numba operations"""
from distutils.version import LooseVersion
import types
from typing import Callable, Dict, Optional, Tuple
import numpy as np
from pandas.compat._optional import import_optional_dependency
from pandas.errors import NumbaUtilError
# Process-wide default: read by ``maybe_use_numba`` when no engine is passed
# explicitly, and written by ``set_use_numba``.
GLOBAL_USE_NUMBA: bool = False
# Cache of JIT-compiled callables keyed by a (callable, str) pair.
# NOTE(review): the string component is presumably an identifier of the
# calling routine — confirm against the code that fills this cache.
NUMBA_FUNC_CACHE: Dict[Tuple[Callable, str], Callable] = {}
def maybe_use_numba(engine: Optional[str]) -> bool:
    """
    Decide whether the numba-backed code path should be taken.

    An explicitly given ``engine`` always wins, and only the exact string
    ``"numba"`` selects numba. When ``engine`` is None, the module-level
    ``GLOBAL_USE_NUMBA`` flag decides.
    """
    if engine is None:
        return GLOBAL_USE_NUMBA
    return engine == "numba"
def set_use_numba(enable: bool = False) -> None:
    """
    Set the module-wide default that ``maybe_use_numba`` falls back on.

    Parameters
    ----------
    enable : bool, default False
        Value stored in ``GLOBAL_USE_NUMBA``. When truthy, numba is first
        requested through ``import_optional_dependency``. If that call raises
        (presumably when numba is not installed — confirm), the flag is left
        unchanged.
    """
    global GLOBAL_USE_NUMBA
    needs_numba = bool(enable)
    if needs_numba:
        # Only the enabling path requires numba to be importable.
        import_optional_dependency("numba")
    GLOBAL_USE_NUMBA = enable
def get_jit_arguments(
    engine_kwargs: Optional[Dict[str, bool]] = None, kwargs: Optional[Dict] = None,
) -> Tuple[bool, bool, bool]:
    """
    Resolve the options for numba.JIT, filling gaps with pandas defaults.

    Parameters
    ----------
    engine_kwargs : dict, default None
        Options the user supplied for numba.JIT. Only ``"nopython"``
        (default True), ``"nogil"`` (default False) and ``"parallel"``
        (default False) are read; any other key is ignored.
    kwargs : dict, default None
        Keyword arguments the user wants forwarded into the JITed function.

    Returns
    -------
    (bool, bool, bool)
        nopython, nogil, parallel

    Raises
    ------
    NumbaUtilError
        If ``kwargs`` is non-empty while ``nopython`` resolves to truthy,
        because numba cannot forward keyword arguments in nopython mode.
    """
    options = engine_kwargs if engine_kwargs is not None else {}
    nopython = options.get("nopython", True)
    # See the linked numba issue: **kwargs are unsupported with nopython=True.
    if kwargs and nopython:
        raise NumbaUtilError(
            "numba does not support kwargs with nopython=True: "
            "https://github.com/numba/numba/issues/2916"
        )
    return (
        nopython,
        options.get("nogil", False),
        options.get("parallel", False),
    )
def jit_user_function(
    func: Callable, nopython: bool, nogil: bool, parallel: bool
) -> Callable:
    """
    JIT the user's function given the configurable arguments.

    Parameters
    ----------
    func : function
        user defined function
    nopython : bool
        nopython parameter for numba.JIT
    nogil : bool
        nogil parameter for numba.JIT
    parallel : bool
        parallel parameter for numba.JIT; applied to the generated wrapper
        only (the inner ``numba.jit`` of ``func`` receives nopython/nogil)

    Returns
    -------
    function
        ``func`` itself if it is already jitted; otherwise a
        ``numba.generated_jit`` wrapper that forwards ``(data, *args)`` to
        ``func`` (NumPy functions and builtins) or to ``numba.jit(func)``.
    """
    numba = import_optional_dependency("numba")

    if LooseVersion(numba.__version__) >= LooseVersion("0.49.0"):
        is_jitted = numba.extending.is_jitted(func)
    else:
        is_jitted = isinstance(func, numba.targets.registry.CPUDispatcher)

    if is_jitted:
        # Don't jit a user passed jitted function
        return func

    # How ``func`` is invoked does not depend on argument types, so decide it
    # once here instead of inside the generator below. The generator runs per
    # type specialization, so deciding here shares one inner dispatcher
    # across all of them.
    # ``__name__`` is looked up defensively: callables such as
    # functools.partial objects do not have one.
    func_name = getattr(func, "__name__", None)
    is_numpy_func = func_name is not None and getattr(np, func_name, None) is func
    if is_numpy_func or isinstance(func, types.BuiltinFunctionType):
        # NumPy functions and builtins are called directly, not re-jitted.
        jf = func
    else:
        jf = numba.jit(func, nopython=nopython, nogil=nogil)

    @numba.generated_jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def numba_func(data, *_args):
        # generated_jit invokes this with the argument *types*; the returned
        # ``impl`` is what gets compiled for that type signature.
        def impl(data, *_args):
            return jf(data, *_args)

        return impl

    return numba_func