forked from pandas-dev/pandas
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathnumba_.py
100 lines (79 loc) · 2.88 KB
/
numba_.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import types
from typing import Any, Callable, Dict, Optional, Tuple
import numpy as np
from pandas._typing import Scalar
from pandas.compat._optional import import_optional_dependency
def make_rolling_apply(func, args, nogil, parallel, nopython):
    """
    Build a numba jitted function that applies ``func`` to each window.

    Parameters
    ----------
    func : function
        function to be applied to each window; it is jitted here unless it
        is already a numba jitted function
    args : tuple
        *args to be passed into ``func`` after the window
    nogil : bool
        ``nogil`` option passed into numba.jit
    parallel : bool
        ``parallel`` option passed into numba.jit; when True the loop over
        windows uses ``numba.prange`` instead of ``range``
    nopython : bool
        ``nopython`` option passed into numba.jit

    Returns
    -------
    Numba function
        ``roll_apply(values, begin, end, minimum_periods)`` which returns an
        array with one entry per ``begin``/``end`` pair
    """
    numba = import_optional_dependency("numba")

    loop_range = numba.prange if parallel else range

    # Newer numba releases (0.49+) relocated ``numba.targets.registry`` and
    # provide ``numba.extending.is_jitted``; prefer that helper and only fall
    # back to the legacy dispatcher class when it is unavailable.
    is_jitted = getattr(getattr(numba, "extending", None), "is_jitted", None)
    if is_jitted is not None:
        already_jitted = is_jitted(func)
    else:
        already_jitted = isinstance(func, numba.targets.registry.CPUDispatcher)

    if already_jitted:
        # Don't jit a user passed jitted function
        numba_func = func
    else:
        # NOTE(review): numba.generated_jit is deprecated in recent numba
        # releases - confirm the supported numba version range.
        @numba.generated_jit(nopython=nopython, nogil=nogil, parallel=parallel)
        def numba_func(window, *_args):
            # numpy functions and builtins are used as-is instead of being
            # passed through numba.jit
            if getattr(np, func.__name__, False) is func or isinstance(
                func, types.BuiltinFunctionType
            ):
                jf = func
            else:
                # Forward nogil as well so the user's function is compiled
                # with the same options as the rolling loop.
                jf = numba.jit(func, nopython=nopython, nogil=nogil)

            def impl(window, *_args):
                return jf(window, *_args)

            return impl

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def roll_apply(
        values: np.ndarray, begin: np.ndarray, end: np.ndarray, minimum_periods: int,
    ) -> np.ndarray:
        result = np.empty(len(begin))
        for i in loop_range(len(result)):
            start = begin[i]
            stop = end[i]
            window = values[start:stop]
            count_nan = np.sum(np.isnan(window))
            # Only evaluate the function when the window holds enough
            # non-NaN observations; otherwise the result is NaN.
            if len(window) - count_nan >= minimum_periods:
                result[i] = numba_func(window, *args)
            else:
                result[i] = np.nan
        return result

    return roll_apply
def generate_numba_apply_func(
    args: Tuple,
    kwargs: Dict[str, Any],
    func: Callable[..., Scalar],
    engine_kwargs: Optional[Dict[str, bool]],
) -> Callable[..., np.ndarray]:
    """
    Generate a numba jitted apply function specified by values from engine_kwargs.

    1. jit the user's function
    2. Return a rolling apply function with the jitted function inline

    Configurations specified in engine_kwargs apply to both the user's
    function _AND_ the rolling apply function.

    Parameters
    ----------
    args : tuple
        *args to be passed into the function
    kwargs : dict
        **kwargs to be passed into the function
    func : function
        function to be applied to each window and will be JITed
    engine_kwargs : dict or None
        dictionary of arguments to be passed into numba.jit; the keys read
        are ``nopython`` (default True), ``nogil`` (default False) and
        ``parallel`` (default False). None is treated as an empty dict.

    Returns
    -------
    Numba function

    Raises
    ------
    ValueError
        If ``kwargs`` is non-empty while ``nopython`` is True.
    """
    if engine_kwargs is None:
        engine_kwargs = {}

    nopython = engine_kwargs.get("nopython", True)
    nogil = engine_kwargs.get("nogil", False)
    parallel = engine_kwargs.get("parallel", False)

    # Fail early: keyword arguments cannot be forwarded in nopython mode.
    if kwargs and nopython:
        raise ValueError(
            "numba does not support kwargs with nopython=True: "
            "https://github.com/numba/numba/issues/2916"
        )

    return make_rolling_apply(func, args, nogil, parallel, nopython)