-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
/
Copy pathnumba_.py
179 lines (144 loc) · 5.1 KB
/
numba_.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
"""Common utilities for Numba operations with groupby ops"""
from __future__ import annotations
import inspect
from typing import (
Any,
Callable,
)
import numpy as np
from pandas._typing import Scalar
from pandas.compat._optional import import_optional_dependency
from pandas.core.util.numba_ import (
NUMBA_FUNC_CACHE,
NumbaUtilError,
get_jit_arguments,
jit_user_function,
)
def validate_udf(func: Callable) -> None:
    """
    Validate user defined function for ops when using Numba with groupby ops.

    The first signature arguments should include:

    def f(values, index, ...):
        ...

    Parameters
    ----------
    func : function
        user defined function

    Returns
    -------
    None

    Raises
    ------
    NumbaUtilError
        If ``func`` declares fewer than two parameters, or its first two
        parameters are not named ``values`` and ``index`` (in that order).
    """
    expected_args = ["values", "index"]
    min_number_args = len(expected_args)
    # Slicing a too-short parameter list yields a shorter list, which already
    # compares unequal to ``expected_args``; no separate length check needed.
    leading_args = list(inspect.signature(func).parameters)[:min_number_args]
    if leading_args != expected_args:
        # Not every callable has ``__name__`` (e.g. functools.partial objects);
        # fall back to repr so the documented NumbaUtilError is raised rather
        # than an AttributeError while building the message.
        func_name = getattr(func, "__name__", repr(func))
        raise NumbaUtilError(
            f"The first {min_number_args} arguments to {func_name} must be "
            f"{expected_args}"
        )
def generate_numba_agg_func(
    kwargs: dict[str, Any],
    func: Callable[..., Scalar],
    engine_kwargs: dict[str, bool] | None,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
    """
    Build the numba-jitted groupby aggregation loop around a user function.

    The user's function is jitted first, then wrapped in a jitted loop that
    calls it once per (group, column) pair. The jit options derived from
    ``engine_kwargs`` are applied to both the user's function _AND_ the
    surrounding groupby loop.

    Parameters
    ----------
    kwargs : dict
        **kwargs intended for the user function; forwarded to
        ``get_jit_arguments`` together with ``engine_kwargs``
    func : function
        function to be applied to each group and will be JITed; its first two
        parameters must be named ``values`` and ``index``
    engine_kwargs : dict
        dictionary of arguments to be passed into numba.jit

    Returns
    -------
    Numba function
        The entry stored in ``NUMBA_FUNC_CACHE`` under
        ``(func, "groupby_agg")`` when one exists, otherwise a newly jitted
        ``group_agg``.
    """
    use_nopython, use_nogil, use_parallel = get_jit_arguments(engine_kwargs, kwargs)

    validate_udf(func)

    # NOTE(review): the key does not include the jit options, so a cached
    # entry is returned regardless of ``engine_kwargs``. This function only
    # reads the cache; it never stores into it.
    key = (func, "groupby_agg")
    if key in NUMBA_FUNC_CACHE:
        return NUMBA_FUNC_CACHE[key]

    jitted_udf = jit_user_function(func, use_nopython, use_nogil, use_parallel)
    numba = import_optional_dependency("numba")

    # error: Untyped decorator makes function "group_agg" untyped
    @numba.jit(  # type: ignore[misc]
        nopython=use_nopython, nogil=use_nogil, parallel=use_parallel
    )
    def group_agg(
        values: np.ndarray,
        index: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        num_columns: int,
        *args: Any,
    ) -> np.ndarray:
        # ``begin[k]`` / ``end[k]`` are the slice bounds of group ``k``.
        assert len(begin) == len(end)
        n_groups = len(begin)

        # One output cell per (group, column).
        out = np.empty((n_groups, num_columns))
        for grp in numba.prange(n_groups):
            start = begin[grp]
            stop = end[grp]
            index_slice = index[start:stop]
            for col in numba.prange(num_columns):
                column_slice = values[start:stop, col]
                out[grp, col] = jitted_udf(column_slice, index_slice, *args)
        return out

    return group_agg
def generate_numba_transform_func(
    kwargs: dict[str, Any],
    func: Callable[..., np.ndarray],
    engine_kwargs: dict[str, bool] | None,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
    """
    Build the numba-jitted groupby transform loop around a user function.

    The user's function is jitted first, then wrapped in a jitted loop that
    calls it once per (group, column) pair and writes what it returns back
    into the rows belonging to that group. The jit options derived from
    ``engine_kwargs`` are applied to both the user's function _AND_ the
    surrounding groupby loop.

    Parameters
    ----------
    kwargs : dict
        **kwargs intended for the user function; forwarded to
        ``get_jit_arguments`` together with ``engine_kwargs``
    func : function
        function to be applied to each group and will be JITed; its first two
        parameters must be named ``values`` and ``index``
    engine_kwargs : dict
        dictionary of arguments to be passed into numba.jit

    Returns
    -------
    Numba function
        The entry stored in ``NUMBA_FUNC_CACHE`` under
        ``(func, "groupby_transform")`` when one exists, otherwise a newly
        jitted ``group_transform``.
    """
    use_nopython, use_nogil, use_parallel = get_jit_arguments(engine_kwargs, kwargs)

    validate_udf(func)

    # NOTE(review): the key does not include the jit options, so a cached
    # entry is returned regardless of ``engine_kwargs``. This function only
    # reads the cache; it never stores into it.
    key = (func, "groupby_transform")
    if key in NUMBA_FUNC_CACHE:
        return NUMBA_FUNC_CACHE[key]

    jitted_udf = jit_user_function(func, use_nopython, use_nogil, use_parallel)
    numba = import_optional_dependency("numba")

    # error: Untyped decorator makes function "group_transform" untyped
    @numba.jit(  # type: ignore[misc]
        nopython=use_nopython, nogil=use_nogil, parallel=use_parallel
    )
    def group_transform(
        values: np.ndarray,
        index: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        num_columns: int,
        *args: Any,
    ) -> np.ndarray:
        # ``begin[k]`` / ``end[k]`` are the slice bounds of group ``k``.
        assert len(begin) == len(end)
        n_groups = len(begin)

        # Output has one row per input row, unlike the aggregation variant.
        out = np.empty((len(values), num_columns))
        for grp in numba.prange(n_groups):
            start = begin[grp]
            stop = end[grp]
            index_slice = index[start:stop]
            for col in numba.prange(num_columns):
                column_slice = values[start:stop, col]
                out[start:stop, col] = jitted_udf(column_slice, index_slice, *args)
        return out

    return group_transform