4
4
5
5
import numpy as np
6
6
7
- from pandas ._typing import FrameOrSeries , Scalar
7
+ from pandas ._typing import Scalar
8
8
from pandas .compat ._optional import import_optional_dependency
9
9
10
10
from pandas .core .util .numba_ import (
11
11
NUMBA_FUNC_CACHE ,
12
12
NumbaUtilError ,
13
- check_kwargs_and_nopython ,
14
13
get_jit_arguments ,
15
14
jit_user_function ,
16
15
)
17
16
18
17
19
- def split_for_numba (arg : FrameOrSeries ) -> Tuple [np .ndarray , np .ndarray ]:
20
- """
21
- Split pandas object into its components as numpy arrays for numba functions.
22
-
23
- Parameters
24
- ----------
25
- arg : Series or DataFrame
26
-
27
- Returns
28
- -------
29
- (ndarray, ndarray)
30
- values, index
31
- """
32
- return arg .to_numpy (), arg .index .to_numpy ()
33
-
34
-
35
18
def validate_udf (func : Callable ) -> None :
36
19
"""
37
20
Validate user defined function for ops when using Numba with groupby ops.
@@ -67,46 +50,6 @@ def f(values, index, ...):
67
50
)
68
51
69
52
70
- def generate_numba_func (
71
- func : Callable ,
72
- engine_kwargs : Optional [Dict [str , bool ]],
73
- kwargs : dict ,
74
- cache_key_str : str ,
75
- ) -> Tuple [Callable , Tuple [Callable , str ]]:
76
- """
77
- Return a JITed function and cache key for the NUMBA_FUNC_CACHE
78
-
79
- This _may_ be specific to groupby (as it's only used there currently).
80
-
81
- Parameters
82
- ----------
83
- func : function
84
- user defined function
85
- engine_kwargs : dict or None
86
- numba.jit arguments
87
- kwargs : dict
88
- kwargs for func
89
- cache_key_str : str
90
- string representing the second part of the cache key tuple
91
-
92
- Returns
93
- -------
94
- (JITed function, cache key)
95
-
96
- Raises
97
- ------
98
- NumbaUtilError
99
- """
100
- nopython , nogil , parallel = get_jit_arguments (engine_kwargs )
101
- check_kwargs_and_nopython (kwargs , nopython )
102
- validate_udf (func )
103
- cache_key = (func , cache_key_str )
104
- numba_func = NUMBA_FUNC_CACHE .get (
105
- cache_key , jit_user_function (func , nopython , nogil , parallel )
106
- )
107
- return numba_func , cache_key
108
-
109
-
110
53
def generate_numba_agg_func (
111
54
args : Tuple ,
112
55
kwargs : Dict [str , Any ],
@@ -120,7 +63,7 @@ def generate_numba_agg_func(
120
63
2. Return a groupby agg function with the jitted function inline
121
64
122
65
Configurations specified in engine_kwargs apply to both the user's
123
- function _AND_ the rolling apply function .
66
+ function _AND_ the groupby evaluation loop .
124
67
125
68
Parameters
126
69
----------
@@ -137,16 +80,15 @@ def generate_numba_agg_func(
137
80
-------
138
81
Numba function
139
82
"""
140
- nopython , nogil , parallel = get_jit_arguments (engine_kwargs )
141
-
142
- check_kwargs_and_nopython (kwargs , nopython )
83
+ nopython , nogil , parallel = get_jit_arguments (engine_kwargs , kwargs )
143
84
144
85
validate_udf (func )
86
+ cache_key = (func , "groupby_agg" )
87
+ if cache_key in NUMBA_FUNC_CACHE :
88
+ return NUMBA_FUNC_CACHE [cache_key ]
145
89
146
90
numba_func = jit_user_function (func , nopython , nogil , parallel )
147
-
148
91
numba = import_optional_dependency ("numba" )
149
-
150
92
if parallel :
151
93
loop_range = numba .prange
152
94
else :
@@ -175,17 +117,17 @@ def group_agg(
175
117
def generate_numba_transform_func (
176
118
args : Tuple ,
177
119
kwargs : Dict [str , Any ],
178
- func : Callable [..., Scalar ],
120
+ func : Callable [..., np . ndarray ],
179
121
engine_kwargs : Optional [Dict [str , bool ]],
180
122
) -> Callable [[np .ndarray , np .ndarray , np .ndarray , np .ndarray , int , int ], np .ndarray ]:
181
123
"""
182
124
Generate a numba jitted transform function specified by values from engine_kwargs.
183
125
184
126
1. jit the user's function
185
- 2. Return a groupby agg function with the jitted function inline
127
+ 2. Return a groupby transform function with the jitted function inline
186
128
187
129
Configurations specified in engine_kwargs apply to both the user's
188
- function _AND_ the rolling apply function .
130
+ function _AND_ the groupby evaluation loop .
189
131
190
132
Parameters
191
133
----------
@@ -202,16 +144,15 @@ def generate_numba_transform_func(
202
144
-------
203
145
Numba function
204
146
"""
205
- nopython , nogil , parallel = get_jit_arguments (engine_kwargs )
206
-
207
- check_kwargs_and_nopython (kwargs , nopython )
147
+ nopython , nogil , parallel = get_jit_arguments (engine_kwargs , kwargs )
208
148
209
149
validate_udf (func )
150
+ cache_key = (func , "groupby_transform" )
151
+ if cache_key in NUMBA_FUNC_CACHE :
152
+ return NUMBA_FUNC_CACHE [cache_key ]
210
153
211
154
numba_func = jit_user_function (func , nopython , nogil , parallel )
212
-
213
155
numba = import_optional_dependency ("numba" )
214
-
215
156
if parallel :
216
157
loop_range = numba .prange
217
158
else :
0 commit comments