Skip to content

Commit 696a8e9

Browse files
authored
BUG/REF: Use lru_cache instead of NUMBA_FUNC_CACHE (#46086)
1 parent 66a5de3 commit 696a8e9

File tree

12 files changed

+230
-260
lines changed

12 files changed

+230
-260
lines changed

doc/source/whatsnew/v1.5.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,8 @@ Groupby/resample/rolling
383383
- Bug in :meth:`DataFrame.resample` ignoring ``closed="right"`` on :class:`TimedeltaIndex` (:issue:`45414`)
384384
- Bug in :meth:`.DataFrameGroupBy.transform` fails when ``func="size"`` and the input DataFrame has multiple columns (:issue:`27469`)
385385
- Bug in :meth:`.DataFrameGroupBy.size` and :meth:`.DataFrameGroupBy.transform` with ``func="size"`` produced incorrect results when ``axis=1`` (:issue:`45715`)
386+
- Bug in :meth:`.ExponentialMovingWindow.mean` with ``axis=1`` and ``engine='numba'`` when the :class:`DataFrame` has more columns than rows (:issue:`46086`)
387+
- Bug when using ``engine="numba"`` would return the same jitted function when modifying ``engine_kwargs`` (:issue:`46086`)
386388
- Bug in :meth:`.DataFrameGroupby.transform` fails when ``axis=1`` and ``func`` is ``"first"`` or ``"last"`` (:issue:`45986`)
387389

388390
Reshaping

pandas/core/_numba/executor.py

+11-18
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import functools
34
from typing import (
45
TYPE_CHECKING,
56
Callable,
@@ -10,16 +11,13 @@
1011
from pandas._typing import Scalar
1112
from pandas.compat._optional import import_optional_dependency
1213

13-
from pandas.core.util.numba_ import (
14-
NUMBA_FUNC_CACHE,
15-
get_jit_arguments,
16-
)
17-
1814

15+
@functools.lru_cache(maxsize=None)
1916
def generate_shared_aggregator(
2017
func: Callable[..., Scalar],
21-
engine_kwargs: dict[str, bool] | None,
22-
cache_key_str: str,
18+
nopython: bool,
19+
nogil: bool,
20+
parallel: bool,
2321
):
2422
"""
2523
Generate a Numba function that loops over the columns 2D object and applies
@@ -29,22 +27,17 @@ def generate_shared_aggregator(
2927
----------
3028
func : function
3129
aggregation function to be applied to each column
32-
engine_kwargs : dict
33-
dictionary of arguments to be passed into numba.jit
34-
cache_key_str: str
35-
string to access the compiled function of the form
36-
<caller_type>_<aggregation_type> e.g. rolling_mean, groupby_mean
30+
nopython : bool
31+
nopython to be passed into numba.jit
32+
nogil : bool
33+
nogil to be passed into numba.jit
34+
parallel : bool
35+
parallel to be passed into numba.jit
3736
3837
Returns
3938
-------
4039
Numba function
4140
"""
42-
nopython, nogil, parallel = get_jit_arguments(engine_kwargs, None)
43-
44-
cache_key = (func, cache_key_str)
45-
if cache_key in NUMBA_FUNC_CACHE:
46-
return NUMBA_FUNC_CACHE[cache_key]
47-
4841
if TYPE_CHECKING:
4942
import numba
5043
else:

pandas/core/groupby/groupby.py

+17-43
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ class providing the base-class of operations.
114114
from pandas.core.series import Series
115115
from pandas.core.sorting import get_group_index_sorter
116116
from pandas.core.util.numba_ import (
117-
NUMBA_FUNC_CACHE,
117+
get_jit_arguments,
118118
maybe_use_numba,
119119
)
120120

@@ -1247,11 +1247,7 @@ def _resolve_numeric_only(self, numeric_only: bool | lib.NoDefault) -> bool:
12471247
# numba
12481248

12491249
@final
1250-
def _numba_prep(self, func, data):
1251-
if not callable(func):
1252-
raise NotImplementedError(
1253-
"Numba engine can only be used with a single function."
1254-
)
1250+
def _numba_prep(self, data):
12551251
ids, _, ngroups = self.grouper.group_info
12561252
sorted_index = get_group_index_sorter(ids, ngroups)
12571253
sorted_ids = algorithms.take_nd(ids, sorted_index, allow_fill=False)
@@ -1271,7 +1267,6 @@ def _numba_agg_general(
12711267
self,
12721268
func: Callable,
12731269
engine_kwargs: dict[str, bool] | None,
1274-
numba_cache_key_str: str,
12751270
*aggregator_args,
12761271
):
12771272
"""
@@ -1288,16 +1283,12 @@ def _numba_agg_general(
12881283
with self._group_selection_context():
12891284
data = self._selected_obj
12901285
df = data if data.ndim == 2 else data.to_frame()
1291-
starts, ends, sorted_index, sorted_data = self._numba_prep(func, df)
1286+
starts, ends, sorted_index, sorted_data = self._numba_prep(df)
12921287
aggregator = executor.generate_shared_aggregator(
1293-
func, engine_kwargs, numba_cache_key_str
1288+
func, **get_jit_arguments(engine_kwargs)
12941289
)
12951290
result = aggregator(sorted_data, starts, ends, 0, *aggregator_args)
12961291

1297-
cache_key = (func, numba_cache_key_str)
1298-
if cache_key not in NUMBA_FUNC_CACHE:
1299-
NUMBA_FUNC_CACHE[cache_key] = aggregator
1300-
13011292
index = self.grouper.result_index
13021293
if data.ndim == 1:
13031294
result_kwargs = {"name": data.name}
@@ -1315,10 +1306,10 @@ def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
13151306
to generate the indices of each group in the sorted data and then passes the
13161307
data and indices into a Numba jitted function.
13171308
"""
1318-
starts, ends, sorted_index, sorted_data = self._numba_prep(func, data)
1319-
1309+
starts, ends, sorted_index, sorted_data = self._numba_prep(data)
1310+
numba_.validate_udf(func)
13201311
numba_transform_func = numba_.generate_numba_transform_func(
1321-
kwargs, func, engine_kwargs
1312+
func, **get_jit_arguments(engine_kwargs, kwargs)
13221313
)
13231314
result = numba_transform_func(
13241315
sorted_data,
@@ -1328,11 +1319,6 @@ def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
13281319
len(data.columns),
13291320
*args,
13301321
)
1331-
1332-
cache_key = (func, "groupby_transform")
1333-
if cache_key not in NUMBA_FUNC_CACHE:
1334-
NUMBA_FUNC_CACHE[cache_key] = numba_transform_func
1335-
13361322
# result values needs to be resorted to their original positions since we
13371323
# evaluated the data sorted by group
13381324
return result.take(np.argsort(sorted_index), axis=0)
@@ -1346,9 +1332,11 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
13461332
to generate the indices of each group in the sorted data and then passes the
13471333
data and indices into a Numba jitted function.
13481334
"""
1349-
starts, ends, sorted_index, sorted_data = self._numba_prep(func, data)
1350-
1351-
numba_agg_func = numba_.generate_numba_agg_func(kwargs, func, engine_kwargs)
1335+
starts, ends, sorted_index, sorted_data = self._numba_prep(data)
1336+
numba_.validate_udf(func)
1337+
numba_agg_func = numba_.generate_numba_agg_func(
1338+
func, **get_jit_arguments(engine_kwargs, kwargs)
1339+
)
13521340
result = numba_agg_func(
13531341
sorted_data,
13541342
sorted_index,
@@ -1357,11 +1345,6 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
13571345
len(data.columns),
13581346
*args,
13591347
)
1360-
1361-
cache_key = (func, "groupby_agg")
1362-
if cache_key not in NUMBA_FUNC_CACHE:
1363-
NUMBA_FUNC_CACHE[cache_key] = numba_agg_func
1364-
13651348
return result
13661349

13671350
# -----------------------------------------------------------------
@@ -1947,7 +1930,7 @@ def mean(
19471930
if maybe_use_numba(engine):
19481931
from pandas.core._numba.kernels import sliding_mean
19491932

1950-
return self._numba_agg_general(sliding_mean, engine_kwargs, "groupby_mean")
1933+
return self._numba_agg_general(sliding_mean, engine_kwargs)
19511934
else:
19521935
result = self._cython_agg_general(
19531936
"mean",
@@ -2029,9 +2012,7 @@ def std(
20292012
if maybe_use_numba(engine):
20302013
from pandas.core._numba.kernels import sliding_var
20312014

2032-
return np.sqrt(
2033-
self._numba_agg_general(sliding_var, engine_kwargs, "groupby_std", ddof)
2034-
)
2015+
return np.sqrt(self._numba_agg_general(sliding_var, engine_kwargs, ddof))
20352016
else:
20362017
return self._get_cythonized_result(
20372018
libgroupby.group_var,
@@ -2085,9 +2066,7 @@ def var(
20852066
if maybe_use_numba(engine):
20862067
from pandas.core._numba.kernels import sliding_var
20872068

2088-
return self._numba_agg_general(
2089-
sliding_var, engine_kwargs, "groupby_var", ddof
2090-
)
2069+
return self._numba_agg_general(sliding_var, engine_kwargs, ddof)
20912070
else:
20922071
if ddof == 1:
20932072
numeric_only = self._resolve_numeric_only(lib.no_default)
@@ -2180,7 +2159,6 @@ def sum(
21802159
return self._numba_agg_general(
21812160
sliding_sum,
21822161
engine_kwargs,
2183-
"groupby_sum",
21842162
)
21852163
else:
21862164
numeric_only = self._resolve_numeric_only(numeric_only)
@@ -2221,9 +2199,7 @@ def min(
22212199
if maybe_use_numba(engine):
22222200
from pandas.core._numba.kernels import sliding_min_max
22232201

2224-
return self._numba_agg_general(
2225-
sliding_min_max, engine_kwargs, "groupby_min", False
2226-
)
2202+
return self._numba_agg_general(sliding_min_max, engine_kwargs, False)
22272203
else:
22282204
return self._agg_general(
22292205
numeric_only=numeric_only,
@@ -2244,9 +2220,7 @@ def max(
22442220
if maybe_use_numba(engine):
22452221
from pandas.core._numba.kernels import sliding_min_max
22462222

2247-
return self._numba_agg_general(
2248-
sliding_min_max, engine_kwargs, "groupby_max", True
2249-
)
2223+
return self._numba_agg_general(sliding_min_max, engine_kwargs, True)
22502224
else:
22512225
return self._agg_general(
22522226
numeric_only=numeric_only,

pandas/core/groupby/numba_.py

+26-29
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Common utilities for Numba operations with groupby ops"""
22
from __future__ import annotations
33

4+
import functools
45
import inspect
56
from typing import (
67
TYPE_CHECKING,
@@ -14,9 +15,7 @@
1415
from pandas.compat._optional import import_optional_dependency
1516

1617
from pandas.core.util.numba_ import (
17-
NUMBA_FUNC_CACHE,
1818
NumbaUtilError,
19-
get_jit_arguments,
2019
jit_user_function,
2120
)
2221

@@ -43,6 +42,10 @@ def f(values, index, ...):
4342
------
4443
NumbaUtilError
4544
"""
45+
if not callable(func):
46+
raise NotImplementedError(
47+
"Numba engine can only be used with a single function."
48+
)
4649
udf_signature = list(inspect.signature(func).parameters.keys())
4750
expected_args = ["values", "index"]
4851
min_number_args = len(expected_args)
@@ -56,10 +59,12 @@ def f(values, index, ...):
5659
)
5760

5861

62+
@functools.lru_cache(maxsize=None)
5963
def generate_numba_agg_func(
60-
kwargs: dict[str, Any],
6164
func: Callable[..., Scalar],
62-
engine_kwargs: dict[str, bool] | None,
65+
nopython: bool,
66+
nogil: bool,
67+
parallel: bool,
6368
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
6469
"""
6570
Generate a numba jitted agg function specified by values from engine_kwargs.
@@ -72,24 +77,19 @@ def generate_numba_agg_func(
7277
7378
Parameters
7479
----------
75-
kwargs : dict
76-
**kwargs to be passed into the function
7780
func : function
78-
function to be applied to each window and will be JITed
79-
engine_kwargs : dict
80-
dictionary of arguments to be passed into numba.jit
81+
function to be applied to each group and will be JITed
82+
nopython : bool
83+
nopython to be passed into numba.jit
84+
nogil : bool
85+
nogil to be passed into numba.jit
86+
parallel : bool
87+
parallel to be passed into numba.jit
8188
8289
Returns
8390
-------
8491
Numba function
8592
"""
86-
nopython, nogil, parallel = get_jit_arguments(engine_kwargs, kwargs)
87-
88-
validate_udf(func)
89-
cache_key = (func, "groupby_agg")
90-
if cache_key in NUMBA_FUNC_CACHE:
91-
return NUMBA_FUNC_CACHE[cache_key]
92-
9393
numba_func = jit_user_function(func, nopython, nogil, parallel)
9494
if TYPE_CHECKING:
9595
import numba
@@ -120,10 +120,12 @@ def group_agg(
120120
return group_agg
121121

122122

123+
@functools.lru_cache(maxsize=None)
123124
def generate_numba_transform_func(
124-
kwargs: dict[str, Any],
125125
func: Callable[..., np.ndarray],
126-
engine_kwargs: dict[str, bool] | None,
126+
nopython: bool,
127+
nogil: bool,
128+
parallel: bool,
127129
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
128130
"""
129131
Generate a numba jitted transform function specified by values from engine_kwargs.
@@ -136,24 +138,19 @@ def generate_numba_transform_func(
136138
137139
Parameters
138140
----------
139-
kwargs : dict
140-
**kwargs to be passed into the function
141141
func : function
142142
function to be applied to each window and will be JITed
143-
engine_kwargs : dict
144-
dictionary of arguments to be passed into numba.jit
143+
nopython : bool
144+
nopython to be passed into numba.jit
145+
nogil : bool
146+
nogil to be passed into numba.jit
147+
parallel : bool
148+
parallel to be passed into numba.jit
145149
146150
Returns
147151
-------
148152
Numba function
149153
"""
150-
nopython, nogil, parallel = get_jit_arguments(engine_kwargs, kwargs)
151-
152-
validate_udf(func)
153-
cache_key = (func, "groupby_transform")
154-
if cache_key in NUMBA_FUNC_CACHE:
155-
return NUMBA_FUNC_CACHE[cache_key]
156-
157154
numba_func = jit_user_function(func, nopython, nogil, parallel)
158155
if TYPE_CHECKING:
159156
import numba

pandas/core/util/numba_.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from pandas.errors import NumbaUtilError
1414

1515
GLOBAL_USE_NUMBA: bool = False
16-
NUMBA_FUNC_CACHE: dict[tuple[Callable, str], Callable] = {}
1716

1817

1918
def maybe_use_numba(engine: str | None) -> bool:
@@ -30,7 +29,7 @@ def set_use_numba(enable: bool = False) -> None:
3029

3130
def get_jit_arguments(
3231
engine_kwargs: dict[str, bool] | None = None, kwargs: dict | None = None
33-
) -> tuple[bool, bool, bool]:
32+
) -> dict[str, bool]:
3433
"""
3534
Return arguments to pass to numba.JIT, falling back on pandas default JIT settings.
3635
@@ -43,7 +42,7 @@ def get_jit_arguments(
4342
4443
Returns
4544
-------
46-
(bool, bool, bool)
45+
dict[str, bool]
4746
nopython, nogil, parallel
4847
4948
Raises
@@ -61,7 +60,7 @@ def get_jit_arguments(
6160
)
6261
nogil = engine_kwargs.get("nogil", False)
6362
parallel = engine_kwargs.get("parallel", False)
64-
return nopython, nogil, parallel
63+
return {"nopython": nopython, "nogil": nogil, "parallel": parallel}
6564

6665

6766
def jit_user_function(

0 commit comments

Comments
 (0)