@@ -17,6 +17,7 @@ def generate_numba_apply_func(
17
17
kwargs : Dict [str , Any ],
18
18
func : Callable [..., Scalar ],
19
19
engine_kwargs : Optional [Dict [str , bool ]],
20
+ name : str ,
20
21
):
21
22
"""
22
23
Generate a numba jitted apply function specified by values from engine_kwargs.
@@ -37,14 +38,16 @@ def generate_numba_apply_func(
37
38
function to be applied to each window and will be JITed
38
39
engine_kwargs : dict
39
40
dictionary of arguments to be passed into numba.jit
41
+ name: str
42
+ name of the caller (Rolling/Expanding)
40
43
41
44
Returns
42
45
-------
43
46
Numba function
44
47
"""
45
48
nopython , nogil , parallel = get_jit_arguments (engine_kwargs , kwargs )
46
49
47
- cache_key = (func , "rolling_apply " )
50
+ cache_key = (func , f" { name } _apply_single " )
48
51
if cache_key in NUMBA_FUNC_CACHE :
49
52
return NUMBA_FUNC_CACHE [cache_key ]
50
53
@@ -153,3 +156,67 @@ def groupby_ewma(
153
156
return result
154
157
155
158
return groupby_ewma
159
+
160
+
161
+ def generate_numba_table_func (
162
+ args : Tuple ,
163
+ kwargs : Dict [str , Any ],
164
+ func : Callable [..., np .ndarray ],
165
+ engine_kwargs : Optional [Dict [str , bool ]],
166
+ name : str ,
167
+ ):
168
+ """
169
+ Generate a numba jitted function to apply window calculations table-wise.
170
+
171
+ Func will be passed a M window size x N number of columns array, and
172
+ must return a 1 x N number of columns array. Func is intended to operate
173
+ row-wise, but the result will be transposed for axis=1.
174
+
175
+ 1. jit the user's function
176
+ 2. Return a rolling apply function with the jitted function inline
177
+
178
+ Parameters
179
+ ----------
180
+ args : tuple
181
+ *args to be passed into the function
182
+ kwargs : dict
183
+ **kwargs to be passed into the function
184
+ func : function
185
+ function to be applied to each window and will be JITed
186
+ engine_kwargs : dict
187
+ dictionary of arguments to be passed into numba.jit
188
+ name : str
189
+ caller (Rolling/Expanding) and original method name for numba cache key
190
+
191
+ Returns
192
+ -------
193
+ Numba function
194
+ """
195
+ nopython , nogil , parallel = get_jit_arguments (engine_kwargs , kwargs )
196
+
197
+ cache_key = (func , f"{ name } _table" )
198
+ if cache_key in NUMBA_FUNC_CACHE :
199
+ return NUMBA_FUNC_CACHE [cache_key ]
200
+
201
+ numba_func = jit_user_function (func , nopython , nogil , parallel )
202
+ numba = import_optional_dependency ("numba" )
203
+
204
+ @numba .jit (nopython = nopython , nogil = nogil , parallel = parallel )
205
+ def roll_table (
206
+ values : np .ndarray , begin : np .ndarray , end : np .ndarray , minimum_periods : int
207
+ ):
208
+ result = np .empty (values .shape )
209
+ min_periods_mask = np .empty (values .shape )
210
+ for i in numba .prange (len (result )):
211
+ start = begin [i ]
212
+ stop = end [i ]
213
+ window = values [start :stop ]
214
+ count_nan = np .sum (np .isnan (window ), axis = 0 )
215
+ sub_result = numba_func (window , * args )
216
+ nan_mask = len (window ) - count_nan >= minimum_periods
217
+ min_periods_mask [i , :] = nan_mask
218
+ result [i , :] = sub_result
219
+ result = np .where (min_periods_mask , result , np .nan )
220
+ return result
221
+
222
+ return roll_table
0 commit comments