@@ -114,7 +114,7 @@ class providing the base-class of operations.
114
114
from pandas .core .series import Series
115
115
from pandas .core .sorting import get_group_index_sorter
116
116
from pandas .core .util .numba_ import (
117
- NUMBA_FUNC_CACHE ,
117
+ get_jit_arguments ,
118
118
maybe_use_numba ,
119
119
)
120
120
@@ -1247,11 +1247,7 @@ def _resolve_numeric_only(self, numeric_only: bool | lib.NoDefault) -> bool:
1247
1247
# numba
1248
1248
1249
1249
@final
1250
- def _numba_prep (self , func , data ):
1251
- if not callable (func ):
1252
- raise NotImplementedError (
1253
- "Numba engine can only be used with a single function."
1254
- )
1250
+ def _numba_prep (self , data ):
1255
1251
ids , _ , ngroups = self .grouper .group_info
1256
1252
sorted_index = get_group_index_sorter (ids , ngroups )
1257
1253
sorted_ids = algorithms .take_nd (ids , sorted_index , allow_fill = False )
@@ -1271,7 +1267,6 @@ def _numba_agg_general(
1271
1267
self ,
1272
1268
func : Callable ,
1273
1269
engine_kwargs : dict [str , bool ] | None ,
1274
- numba_cache_key_str : str ,
1275
1270
* aggregator_args ,
1276
1271
):
1277
1272
"""
@@ -1288,16 +1283,12 @@ def _numba_agg_general(
1288
1283
with self ._group_selection_context ():
1289
1284
data = self ._selected_obj
1290
1285
df = data if data .ndim == 2 else data .to_frame ()
1291
- starts , ends , sorted_index , sorted_data = self ._numba_prep (func , df )
1286
+ starts , ends , sorted_index , sorted_data = self ._numba_prep (df )
1292
1287
aggregator = executor .generate_shared_aggregator (
1293
- func , engine_kwargs , numba_cache_key_str
1288
+ func , ** get_jit_arguments ( engine_kwargs )
1294
1289
)
1295
1290
result = aggregator (sorted_data , starts , ends , 0 , * aggregator_args )
1296
1291
1297
- cache_key = (func , numba_cache_key_str )
1298
- if cache_key not in NUMBA_FUNC_CACHE :
1299
- NUMBA_FUNC_CACHE [cache_key ] = aggregator
1300
-
1301
1292
index = self .grouper .result_index
1302
1293
if data .ndim == 1 :
1303
1294
result_kwargs = {"name" : data .name }
@@ -1315,10 +1306,10 @@ def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
1315
1306
to generate the indices of each group in the sorted data and then passes the
1316
1307
data and indices into a Numba jitted function.
1317
1308
"""
1318
- starts , ends , sorted_index , sorted_data = self ._numba_prep (func , data )
1319
-
1309
+ starts , ends , sorted_index , sorted_data = self ._numba_prep (data )
1310
+ numba_ . validate_udf ( func )
1320
1311
numba_transform_func = numba_ .generate_numba_transform_func (
1321
- kwargs , func , engine_kwargs
1312
+ func , ** get_jit_arguments ( engine_kwargs , kwargs )
1322
1313
)
1323
1314
result = numba_transform_func (
1324
1315
sorted_data ,
@@ -1328,11 +1319,6 @@ def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
1328
1319
len (data .columns ),
1329
1320
* args ,
1330
1321
)
1331
-
1332
- cache_key = (func , "groupby_transform" )
1333
- if cache_key not in NUMBA_FUNC_CACHE :
1334
- NUMBA_FUNC_CACHE [cache_key ] = numba_transform_func
1335
-
1336
1322
# result values needs to be resorted to their original positions since we
1337
1323
# evaluated the data sorted by group
1338
1324
return result .take (np .argsort (sorted_index ), axis = 0 )
@@ -1346,9 +1332,11 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
1346
1332
to generate the indices of each group in the sorted data and then passes the
1347
1333
data and indices into a Numba jitted function.
1348
1334
"""
1349
- starts , ends , sorted_index , sorted_data = self ._numba_prep (func , data )
1350
-
1351
- numba_agg_func = numba_ .generate_numba_agg_func (kwargs , func , engine_kwargs )
1335
+ starts , ends , sorted_index , sorted_data = self ._numba_prep (data )
1336
+ numba_ .validate_udf (func )
1337
+ numba_agg_func = numba_ .generate_numba_agg_func (
1338
+ func , ** get_jit_arguments (engine_kwargs , kwargs )
1339
+ )
1352
1340
result = numba_agg_func (
1353
1341
sorted_data ,
1354
1342
sorted_index ,
@@ -1357,11 +1345,6 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
1357
1345
len (data .columns ),
1358
1346
* args ,
1359
1347
)
1360
-
1361
- cache_key = (func , "groupby_agg" )
1362
- if cache_key not in NUMBA_FUNC_CACHE :
1363
- NUMBA_FUNC_CACHE [cache_key ] = numba_agg_func
1364
-
1365
1348
return result
1366
1349
1367
1350
# -----------------------------------------------------------------
@@ -1947,7 +1930,7 @@ def mean(
1947
1930
if maybe_use_numba (engine ):
1948
1931
from pandas .core ._numba .kernels import sliding_mean
1949
1932
1950
- return self ._numba_agg_general (sliding_mean , engine_kwargs , "groupby_mean" )
1933
+ return self ._numba_agg_general (sliding_mean , engine_kwargs )
1951
1934
else :
1952
1935
result = self ._cython_agg_general (
1953
1936
"mean" ,
@@ -2029,9 +2012,7 @@ def std(
2029
2012
if maybe_use_numba (engine ):
2030
2013
from pandas .core ._numba .kernels import sliding_var
2031
2014
2032
- return np .sqrt (
2033
- self ._numba_agg_general (sliding_var , engine_kwargs , "groupby_std" , ddof )
2034
- )
2015
+ return np .sqrt (self ._numba_agg_general (sliding_var , engine_kwargs , ddof ))
2035
2016
else :
2036
2017
return self ._get_cythonized_result (
2037
2018
libgroupby .group_var ,
@@ -2085,9 +2066,7 @@ def var(
2085
2066
if maybe_use_numba (engine ):
2086
2067
from pandas .core ._numba .kernels import sliding_var
2087
2068
2088
- return self ._numba_agg_general (
2089
- sliding_var , engine_kwargs , "groupby_var" , ddof
2090
- )
2069
+ return self ._numba_agg_general (sliding_var , engine_kwargs , ddof )
2091
2070
else :
2092
2071
if ddof == 1 :
2093
2072
numeric_only = self ._resolve_numeric_only (lib .no_default )
@@ -2180,7 +2159,6 @@ def sum(
2180
2159
return self ._numba_agg_general (
2181
2160
sliding_sum ,
2182
2161
engine_kwargs ,
2183
- "groupby_sum" ,
2184
2162
)
2185
2163
else :
2186
2164
numeric_only = self ._resolve_numeric_only (numeric_only )
@@ -2221,9 +2199,7 @@ def min(
2221
2199
if maybe_use_numba (engine ):
2222
2200
from pandas .core ._numba .kernels import sliding_min_max
2223
2201
2224
- return self ._numba_agg_general (
2225
- sliding_min_max , engine_kwargs , "groupby_min" , False
2226
- )
2202
+ return self ._numba_agg_general (sliding_min_max , engine_kwargs , False )
2227
2203
else :
2228
2204
return self ._agg_general (
2229
2205
numeric_only = numeric_only ,
@@ -2244,9 +2220,7 @@ def max(
2244
2220
if maybe_use_numba (engine ):
2245
2221
from pandas .core ._numba .kernels import sliding_min_max
2246
2222
2247
- return self ._numba_agg_general (
2248
- sliding_min_max , engine_kwargs , "groupby_max" , True
2249
- )
2223
+ return self ._numba_agg_general (sliding_min_max , engine_kwargs , True )
2250
2224
else :
2251
2225
return self ._agg_general (
2252
2226
numeric_only = numeric_only ,
0 commit comments