Skip to content

Commit 1b6879a

Browse files
authored
CLN: Numba internal routines (#36376)
1 parent 11d5fc9 commit 1b6879a

File tree

6 files changed

+47
-147
lines changed

6 files changed

+47
-147
lines changed

pandas/core/groupby/generic.py

+3-17
Original file line number | Diff line number | Diff line change
@@ -74,12 +74,11 @@
7474
get_groupby,
7575
group_selection_context,
7676
)
77-
from pandas.core.groupby.numba_ import generate_numba_func, split_for_numba
7877
from pandas.core.indexes.api import Index, MultiIndex, all_indexes_same
7978
import pandas.core.indexes.base as ibase
8079
from pandas.core.internals import BlockManager
8180
from pandas.core.series import Series
82-
from pandas.core.util.numba_ import NUMBA_FUNC_CACHE, maybe_use_numba
81+
from pandas.core.util.numba_ import maybe_use_numba
8382

8483
from pandas.plotting import boxplot_frame_groupby
8584

@@ -518,29 +517,16 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
518517
result = getattr(self, func)(*args, **kwargs)
519518
return self._transform_fast(result)
520519

521-
def _transform_general(
522-
self, func, *args, engine="cython", engine_kwargs=None, **kwargs
523-
):
520+
def _transform_general(self, func, *args, **kwargs):
524521
"""
525522
Transform with a non-str `func`.
526523
"""
527-
if maybe_use_numba(engine):
528-
numba_func, cache_key = generate_numba_func(
529-
func, engine_kwargs, kwargs, "groupby_transform"
530-
)
531-
532524
klass = type(self._selected_obj)
533525

534526
results = []
535527
for name, group in self:
536528
object.__setattr__(group, "name", name)
537-
if maybe_use_numba(engine):
538-
values, index = split_for_numba(group)
539-
res = numba_func(values, index, *args)
540-
if cache_key not in NUMBA_FUNC_CACHE:
541-
NUMBA_FUNC_CACHE[cache_key] = numba_func
542-
else:
543-
res = func(group, *args, **kwargs)
529+
res = func(group, *args, **kwargs)
544530

545531
if isinstance(res, (ABCDataFrame, ABCSeries)):
546532
res = res._values

pandas/core/groupby/groupby.py

+12-14
Original file line number | Diff line number | Diff line change
@@ -1071,16 +1071,15 @@ def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
10711071
sorted_labels = algorithms.take_nd(labels, sorted_index, allow_fill=False)
10721072
sorted_data = data.take(sorted_index, axis=self.axis).to_numpy()
10731073
starts, ends = lib.generate_slices(sorted_labels, n_groups)
1074-
cache_key = (func, "groupby_transform")
1075-
if cache_key in NUMBA_FUNC_CACHE:
1076-
numba_transform_func = NUMBA_FUNC_CACHE[cache_key]
1077-
else:
1078-
numba_transform_func = numba_.generate_numba_transform_func(
1079-
tuple(args), kwargs, func, engine_kwargs
1080-
)
1074+
1075+
numba_transform_func = numba_.generate_numba_transform_func(
1076+
tuple(args), kwargs, func, engine_kwargs
1077+
)
10811078
result = numba_transform_func(
10821079
sorted_data, sorted_index, starts, ends, len(group_keys), len(data.columns)
10831080
)
1081+
1082+
cache_key = (func, "groupby_transform")
10841083
if cache_key not in NUMBA_FUNC_CACHE:
10851084
NUMBA_FUNC_CACHE[cache_key] = numba_transform_func
10861085

@@ -1106,16 +1105,15 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
11061105
sorted_labels = algorithms.take_nd(labels, sorted_index, allow_fill=False)
11071106
sorted_data = data.take(sorted_index, axis=self.axis).to_numpy()
11081107
starts, ends = lib.generate_slices(sorted_labels, n_groups)
1109-
cache_key = (func, "groupby_agg")
1110-
if cache_key in NUMBA_FUNC_CACHE:
1111-
numba_agg_func = NUMBA_FUNC_CACHE[cache_key]
1112-
else:
1113-
numba_agg_func = numba_.generate_numba_agg_func(
1114-
tuple(args), kwargs, func, engine_kwargs
1115-
)
1108+
1109+
numba_agg_func = numba_.generate_numba_agg_func(
1110+
tuple(args), kwargs, func, engine_kwargs
1111+
)
11161112
result = numba_agg_func(
11171113
sorted_data, sorted_index, starts, ends, len(group_keys), len(data.columns)
11181114
)
1115+
1116+
cache_key = (func, "groupby_agg")
11191117
if cache_key not in NUMBA_FUNC_CACHE:
11201118
NUMBA_FUNC_CACHE[cache_key] = numba_agg_func
11211119

pandas/core/groupby/numba_.py

+13-72
Original file line number | Diff line number | Diff line change
@@ -4,34 +4,17 @@
44

55
import numpy as np
66

7-
from pandas._typing import FrameOrSeries, Scalar
7+
from pandas._typing import Scalar
88
from pandas.compat._optional import import_optional_dependency
99

1010
from pandas.core.util.numba_ import (
1111
NUMBA_FUNC_CACHE,
1212
NumbaUtilError,
13-
check_kwargs_and_nopython,
1413
get_jit_arguments,
1514
jit_user_function,
1615
)
1716

1817

19-
def split_for_numba(arg: FrameOrSeries) -> Tuple[np.ndarray, np.ndarray]:
20-
"""
21-
Split pandas object into its components as numpy arrays for numba functions.
22-
23-
Parameters
24-
----------
25-
arg : Series or DataFrame
26-
27-
Returns
28-
-------
29-
(ndarray, ndarray)
30-
values, index
31-
"""
32-
return arg.to_numpy(), arg.index.to_numpy()
33-
34-
3518
def validate_udf(func: Callable) -> None:
3619
"""
3720
Validate user defined function for ops when using Numba with groupby ops.
@@ -67,46 +50,6 @@ def f(values, index, ...):
6750
)
6851

6952

70-
def generate_numba_func(
71-
func: Callable,
72-
engine_kwargs: Optional[Dict[str, bool]],
73-
kwargs: dict,
74-
cache_key_str: str,
75-
) -> Tuple[Callable, Tuple[Callable, str]]:
76-
"""
77-
Return a JITed function and cache key for the NUMBA_FUNC_CACHE
78-
79-
This _may_ be specific to groupby (as it's only used there currently).
80-
81-
Parameters
82-
----------
83-
func : function
84-
user defined function
85-
engine_kwargs : dict or None
86-
numba.jit arguments
87-
kwargs : dict
88-
kwargs for func
89-
cache_key_str : str
90-
string representing the second part of the cache key tuple
91-
92-
Returns
93-
-------
94-
(JITed function, cache key)
95-
96-
Raises
97-
------
98-
NumbaUtilError
99-
"""
100-
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
101-
check_kwargs_and_nopython(kwargs, nopython)
102-
validate_udf(func)
103-
cache_key = (func, cache_key_str)
104-
numba_func = NUMBA_FUNC_CACHE.get(
105-
cache_key, jit_user_function(func, nopython, nogil, parallel)
106-
)
107-
return numba_func, cache_key
108-
109-
11053
def generate_numba_agg_func(
11154
args: Tuple,
11255
kwargs: Dict[str, Any],
@@ -120,7 +63,7 @@ def generate_numba_agg_func(
12063
2. Return a groupby agg function with the jitted function inline
12164
12265
Configurations specified in engine_kwargs apply to both the user's
123-
function _AND_ the rolling apply function.
66+
function _AND_ the groupby evaluation loop.
12467
12568
Parameters
12669
----------
@@ -137,16 +80,15 @@ def generate_numba_agg_func(
13780
-------
13881
Numba function
13982
"""
140-
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
141-
142-
check_kwargs_and_nopython(kwargs, nopython)
83+
nopython, nogil, parallel = get_jit_arguments(engine_kwargs, kwargs)
14384

14485
validate_udf(func)
86+
cache_key = (func, "groupby_agg")
87+
if cache_key in NUMBA_FUNC_CACHE:
88+
return NUMBA_FUNC_CACHE[cache_key]
14589

14690
numba_func = jit_user_function(func, nopython, nogil, parallel)
147-
14891
numba = import_optional_dependency("numba")
149-
15092
if parallel:
15193
loop_range = numba.prange
15294
else:
@@ -175,17 +117,17 @@ def group_agg(
175117
def generate_numba_transform_func(
176118
args: Tuple,
177119
kwargs: Dict[str, Any],
178-
func: Callable[..., Scalar],
120+
func: Callable[..., np.ndarray],
179121
engine_kwargs: Optional[Dict[str, bool]],
180122
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int], np.ndarray]:
181123
"""
182124
Generate a numba jitted transform function specified by values from engine_kwargs.
183125
184126
1. jit the user's function
185-
2. Return a groupby agg function with the jitted function inline
127+
2. Return a groupby transform function with the jitted function inline
186128
187129
Configurations specified in engine_kwargs apply to both the user's
188-
function _AND_ the rolling apply function.
130+
function _AND_ the groupby evaluation loop.
189131
190132
Parameters
191133
----------
@@ -202,16 +144,15 @@ def generate_numba_transform_func(
202144
-------
203145
Numba function
204146
"""
205-
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
206-
207-
check_kwargs_and_nopython(kwargs, nopython)
147+
nopython, nogil, parallel = get_jit_arguments(engine_kwargs, kwargs)
208148

209149
validate_udf(func)
150+
cache_key = (func, "groupby_transform")
151+
if cache_key in NUMBA_FUNC_CACHE:
152+
return NUMBA_FUNC_CACHE[cache_key]
210153

211154
numba_func = jit_user_function(func, nopython, nogil, parallel)
212-
213155
numba = import_optional_dependency("numba")
214-
215156
if parallel:
216157
loop_range = numba.prange
217158
else:

pandas/core/util/numba_.py

+12-30
Original file line number | Diff line number | Diff line change
@@ -24,37 +24,8 @@ def set_use_numba(enable: bool = False) -> None:
2424
GLOBAL_USE_NUMBA = enable
2525

2626

27-
def check_kwargs_and_nopython(
28-
kwargs: Optional[Dict] = None, nopython: Optional[bool] = None
29-
) -> None:
30-
"""
31-
Validate that **kwargs and nopython=True was passed
32-
https://github.com/numba/numba/issues/2916
33-
34-
Parameters
35-
----------
36-
kwargs : dict, default None
37-
user passed keyword arguments to pass into the JITed function
38-
nopython : bool, default None
39-
nopython parameter
40-
41-
Returns
42-
-------
43-
None
44-
45-
Raises
46-
------
47-
NumbaUtilError
48-
"""
49-
if kwargs and nopython:
50-
raise NumbaUtilError(
51-
"numba does not support kwargs with nopython=True: "
52-
"https://github.com/numba/numba/issues/2916"
53-
)
54-
55-
5627
def get_jit_arguments(
57-
engine_kwargs: Optional[Dict[str, bool]] = None
28+
engine_kwargs: Optional[Dict[str, bool]] = None, kwargs: Optional[Dict] = None,
5829
) -> Tuple[bool, bool, bool]:
5930
"""
6031
Return arguments to pass to numba.JIT, falling back on pandas default JIT settings.
@@ -63,16 +34,27 @@ def get_jit_arguments(
6334
----------
6435
engine_kwargs : dict, default None
6536
user passed keyword arguments for numba.JIT
37+
kwargs : dict, default None
38+
user passed keyword arguments to pass into the JITed function
6639
6740
Returns
6841
-------
6942
(bool, bool, bool)
7043
nopython, nogil, parallel
44+
45+
Raises
46+
------
47+
NumbaUtilError
7148
"""
7249
if engine_kwargs is None:
7350
engine_kwargs = {}
7451

7552
nopython = engine_kwargs.get("nopython", True)
53+
if kwargs and nopython:
54+
raise NumbaUtilError(
55+
"numba does not support kwargs with nopython=True: "
56+
"https://github.com/numba/numba/issues/2916"
57+
)
7658
nogil = engine_kwargs.get("nogil", False)
7759
parallel = engine_kwargs.get("parallel", False)
7860
return nopython, nogil, parallel

pandas/core/window/numba_.py

+5-5
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
from pandas.compat._optional import import_optional_dependency
77

88
from pandas.core.util.numba_ import (
9-
check_kwargs_and_nopython,
9+
NUMBA_FUNC_CACHE,
1010
get_jit_arguments,
1111
jit_user_function,
1212
)
@@ -42,14 +42,14 @@ def generate_numba_apply_func(
4242
-------
4343
Numba function
4444
"""
45-
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
45+
nopython, nogil, parallel = get_jit_arguments(engine_kwargs, kwargs)
4646

47-
check_kwargs_and_nopython(kwargs, nopython)
47+
cache_key = (func, "rolling_apply")
48+
if cache_key in NUMBA_FUNC_CACHE:
49+
return NUMBA_FUNC_CACHE[cache_key]
4850

4951
numba_func = jit_user_function(func, nopython, nogil, parallel)
50-
5152
numba = import_optional_dependency("numba")
52-
5353
if parallel:
5454
loop_range = numba.prange
5555
else:

pandas/core/window/rolling.py

+2-9
Original file line number | Diff line number | Diff line change
@@ -1374,14 +1374,7 @@ def apply(
13741374
if maybe_use_numba(engine):
13751375
if raw is False:
13761376
raise ValueError("raw must be `True` when using the numba engine")
1377-
cache_key = (func, "rolling_apply")
1378-
if cache_key in NUMBA_FUNC_CACHE:
1379-
# Return an already compiled version of roll_apply if available
1380-
apply_func = NUMBA_FUNC_CACHE[cache_key]
1381-
else:
1382-
apply_func = generate_numba_apply_func(
1383-
args, kwargs, func, engine_kwargs
1384-
)
1377+
apply_func = generate_numba_apply_func(args, kwargs, func, engine_kwargs)
13851378
center = self.center
13861379
elif engine in ("cython", None):
13871380
if engine_kwargs is not None:
@@ -1403,7 +1396,7 @@ def apply(
14031396
center=center,
14041397
floor=0,
14051398
name=func,
1406-
use_numba_cache=engine == "numba",
1399+
use_numba_cache=maybe_use_numba(engine),
14071400
raw=raw,
14081401
original_func=func,
14091402
args=args,

0 commit comments

Comments
 (0)