76
76
from pandas .core .internals import BlockManager , make_block
77
77
from pandas .core .series import Series
78
78
from pandas .core .util .numba_ import (
79
+ NUMBA_FUNC_CACHE ,
79
80
check_kwargs_and_nopython ,
80
81
get_jit_arguments ,
81
82
jit_user_function ,
@@ -161,8 +162,6 @@ def pinner(cls):
161
162
class SeriesGroupBy (GroupBy [Series ]):
162
163
_apply_whitelist = base .series_apply_whitelist
163
164
164
- _numba_func_cache : Dict [Callable , Callable ] = {}
165
-
166
165
def _iterate_slices (self ) -> Iterable [Series ]:
167
166
yield self ._selected_obj
168
167
@@ -504,8 +503,9 @@ def _transform_general(
504
503
nopython , nogil , parallel = get_jit_arguments (engine_kwargs )
505
504
check_kwargs_and_nopython (kwargs , nopython )
506
505
validate_udf (func )
507
- numba_func = self ._numba_func_cache .get (
508
- func , jit_user_function (func , nopython , nogil , parallel )
506
+ cache_key = (func , "groupby_transform" )
507
+ numba_func = NUMBA_FUNC_CACHE .get (
508
+ cache_key , jit_user_function (func , nopython , nogil , parallel )
509
509
)
510
510
511
511
klass = type (self ._selected_obj )
@@ -516,8 +516,8 @@ def _transform_general(
516
516
if engine == "numba" :
517
517
values , index = split_for_numba (group )
518
518
res = numba_func (values , index , * args )
519
- if func not in self . _numba_func_cache :
520
- self . _numba_func_cache [ func ] = numba_func
519
+ if cache_key not in NUMBA_FUNC_CACHE :
520
+ NUMBA_FUNC_CACHE [ cache_key ] = numba_func
521
521
else :
522
522
res = func (group , * args , ** kwargs )
523
523
@@ -847,8 +847,6 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
847
847
848
848
_apply_whitelist = base .dataframe_apply_whitelist
849
849
850
- _numba_func_cache : Dict [Callable , Callable ] = {}
851
-
852
850
_agg_see_also_doc = dedent (
853
851
"""
854
852
See Also
@@ -1397,8 +1395,9 @@ def _transform_general(
1397
1395
nopython , nogil , parallel = get_jit_arguments (engine_kwargs )
1398
1396
check_kwargs_and_nopython (kwargs , nopython )
1399
1397
validate_udf (func )
1400
- numba_func = self ._numba_func_cache .get (
1401
- func , jit_user_function (func , nopython , nogil , parallel )
1398
+ cache_key = (func , "groupby_transform" )
1399
+ numba_func = NUMBA_FUNC_CACHE .get (
1400
+ cache_key , jit_user_function (func , nopython , nogil , parallel )
1402
1401
)
1403
1402
else :
1404
1403
fast_path , slow_path = self ._define_paths (func , * args , ** kwargs )
@@ -1409,8 +1408,8 @@ def _transform_general(
1409
1408
if engine == "numba" :
1410
1409
values , index = split_for_numba (group )
1411
1410
res = numba_func (values , index , * args )
1412
- if func not in self . _numba_func_cache :
1413
- self . _numba_func_cache [ func ] = numba_func
1411
+ if cache_key not in NUMBA_FUNC_CACHE :
1412
+ NUMBA_FUNC_CACHE [ cache_key ] = numba_func
1414
1413
# Return the result as a DataFrame for concatenation later
1415
1414
res = DataFrame (res , index = group .index , columns = group .columns )
1416
1415
else :
0 commit comments