Skip to content

Commit 068e654

Browse files
mroeschkeMatt Roeschke
and
Matt Roeschke
authored
PERF: Allow jitting of groupby agg loop (#35759)
* Roll back groupby agg changes * Add aggragate_with_numba * Fix cases where operation on Series inputs * Simplify case, handle Series correctly * Ensure function is being cached, validate the udf signature for groupby agg * Move some functionality to groupby/numba_.py * Change ValueError to NotImplementedError * Comment that it's only 1 function that is supported * Add whatsnew * Add issue number and correct typing * Add docstring for _aggregate_with_numba * Lint Co-authored-by: Matt Roeschke <[email protected]>
1 parent 3269f54 commit 068e654

File tree

7 files changed

+274
-180
lines changed

7 files changed

+274
-180
lines changed

doc/source/whatsnew/v1.2.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ Deprecations
154154
Performance improvements
155155
~~~~~~~~~~~~~~~~~~~~~~~~
156156

157-
-
157+
- Performance improvement in :meth:`GroupBy.agg` with the ``numba`` engine (:issue:`35759`)
158158
-
159159

160160
.. ---------------------------------------------------------------------------

pandas/core/groupby/generic.py

+28-17
Original file line numberDiff line numberDiff line change
@@ -70,19 +70,16 @@
7070
GroupBy,
7171
_agg_template,
7272
_apply_docs,
73+
_group_selection_context,
7374
_transform_template,
7475
get_groupby,
7576
)
77+
from pandas.core.groupby.numba_ import generate_numba_func, split_for_numba
7678
from pandas.core.indexes.api import Index, MultiIndex, all_indexes_same
7779
import pandas.core.indexes.base as ibase
7880
from pandas.core.internals import BlockManager, make_block
7981
from pandas.core.series import Series
80-
from pandas.core.util.numba_ import (
81-
NUMBA_FUNC_CACHE,
82-
generate_numba_func,
83-
maybe_use_numba,
84-
split_for_numba,
85-
)
82+
from pandas.core.util.numba_ import NUMBA_FUNC_CACHE, maybe_use_numba
8683

8784
from pandas.plotting import boxplot_frame_groupby
8885

@@ -230,6 +227,18 @@ def apply(self, func, *args, **kwargs):
230227
)
231228
def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
232229

230+
if maybe_use_numba(engine):
231+
if not callable(func):
232+
raise NotImplementedError(
233+
"Numba engine can only be used with a single function."
234+
)
235+
with _group_selection_context(self):
236+
data = self._selected_obj
237+
result, index = self._aggregate_with_numba(
238+
data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs
239+
)
240+
return self.obj._constructor(result.ravel(), index=index, name=data.name)
241+
233242
relabeling = func is None
234243
columns = None
235244
if relabeling:
@@ -252,16 +261,11 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
252261
return getattr(self, cyfunc)()
253262

254263
if self.grouper.nkeys > 1:
255-
return self._python_agg_general(
256-
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
257-
)
264+
return self._python_agg_general(func, *args, **kwargs)
258265

259266
try:
260-
return self._python_agg_general(
261-
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
262-
)
267+
return self._python_agg_general(func, *args, **kwargs)
263268
except (ValueError, KeyError):
264-
# Do not catch Numba errors here, we want to raise and not fall back.
265269
# TODO: KeyError is raised in _python_agg_general,
266270
# see see test_groupby.test_basic
267271
result = self._aggregate_named(func, *args, **kwargs)
@@ -937,12 +941,19 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
937941
)
938942
def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
939943

940-
relabeling, func, columns, order = reconstruct_func(func, **kwargs)
941-
942944
if maybe_use_numba(engine):
943-
return self._python_agg_general(
944-
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
945+
if not callable(func):
946+
raise NotImplementedError(
947+
"Numba engine can only be used with a single function."
948+
)
949+
with _group_selection_context(self):
950+
data = self._selected_obj
951+
result, index = self._aggregate_with_numba(
952+
data, func, *args, engine_kwargs=engine_kwargs, **kwargs
945953
)
954+
return self.obj._constructor(result, index=index, columns=data.columns)
955+
956+
relabeling, func, columns, order = reconstruct_func(func, **kwargs)
946957

947958
result, how = self._aggregate(func, *args, **kwargs)
948959
if how is None:

pandas/core/groupby/groupby.py

+46-24
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class providing the base-class of operations.
3434

3535
from pandas._config.config import option_context
3636

37-
from pandas._libs import Timestamp
37+
from pandas._libs import Timestamp, lib
3838
import pandas._libs.groupby as libgroupby
3939
from pandas._typing import F, FrameOrSeries, FrameOrSeriesUnion, Scalar
4040
from pandas.compat.numpy import function as nv
@@ -61,11 +61,11 @@ class providing the base-class of operations.
6161
import pandas.core.common as com
6262
from pandas.core.frame import DataFrame
6363
from pandas.core.generic import NDFrame
64-
from pandas.core.groupby import base, ops
64+
from pandas.core.groupby import base, numba_, ops
6565
from pandas.core.indexes.api import CategoricalIndex, Index, MultiIndex
6666
from pandas.core.series import Series
6767
from pandas.core.sorting import get_group_index_sorter
68-
from pandas.core.util.numba_ import maybe_use_numba
68+
from pandas.core.util.numba_ import NUMBA_FUNC_CACHE
6969

7070
_common_see_also = """
7171
See Also
@@ -384,7 +384,8 @@ class providing the base-class of operations.
384384
- dict of axis labels -> functions, function names or list of such.
385385
386386
Can also accept a Numba JIT function with
387-
``engine='numba'`` specified.
387+
``engine='numba'`` specified. Only passing a single function is supported
388+
with this engine.
388389
389390
If the ``'numba'`` engine is chosen, the function must be
390391
a user defined function with ``values`` and ``index`` as the
@@ -1053,12 +1054,43 @@ def _cython_agg_general(
10531054

10541055
return self._wrap_aggregated_output(output, index=self.grouper.result_index)
10551056

1056-
def _python_agg_general(
1057-
self, func, *args, engine="cython", engine_kwargs=None, **kwargs
1058-
):
1057+
def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs):
1058+
"""
1059+
Perform groupby aggregation routine with the numba engine.
1060+
1061+
This routine mimics the data splitting routine of the DataSplitter class
1062+
to generate the indices of each group in the sorted data and then passes the
1063+
data and indices into a Numba jitted function.
1064+
"""
1065+
group_keys = self.grouper._get_group_keys()
1066+
labels, _, n_groups = self.grouper.group_info
1067+
sorted_index = get_group_index_sorter(labels, n_groups)
1068+
sorted_labels = algorithms.take_nd(labels, sorted_index, allow_fill=False)
1069+
sorted_data = data.take(sorted_index, axis=self.axis).to_numpy()
1070+
starts, ends = lib.generate_slices(sorted_labels, n_groups)
1071+
cache_key = (func, "groupby_agg")
1072+
if cache_key in NUMBA_FUNC_CACHE:
1073+
# Return an already compiled version of roll_apply if available
1074+
numba_agg_func = NUMBA_FUNC_CACHE[cache_key]
1075+
else:
1076+
numba_agg_func = numba_.generate_numba_agg_func(
1077+
tuple(args), kwargs, func, engine_kwargs
1078+
)
1079+
result = numba_agg_func(
1080+
sorted_data, sorted_index, starts, ends, len(group_keys), len(data.columns),
1081+
)
1082+
if cache_key not in NUMBA_FUNC_CACHE:
1083+
NUMBA_FUNC_CACHE[cache_key] = numba_agg_func
1084+
1085+
if self.grouper.nkeys > 1:
1086+
index = MultiIndex.from_tuples(group_keys, names=self.grouper.names)
1087+
else:
1088+
index = Index(group_keys, name=self.grouper.names[0])
1089+
return result, index
1090+
1091+
def _python_agg_general(self, func, *args, **kwargs):
10591092
func = self._is_builtin_func(func)
1060-
if engine != "numba":
1061-
f = lambda x: func(x, *args, **kwargs)
1093+
f = lambda x: func(x, *args, **kwargs)
10621094

10631095
# iterate through "columns" ex exclusions to populate output dict
10641096
output: Dict[base.OutputKey, np.ndarray] = {}
@@ -1069,21 +1101,11 @@ def _python_agg_general(
10691101
# agg_series below assumes ngroups > 0
10701102
continue
10711103

1072-
if maybe_use_numba(engine):
1073-
result, counts = self.grouper.agg_series(
1074-
obj,
1075-
func,
1076-
*args,
1077-
engine=engine,
1078-
engine_kwargs=engine_kwargs,
1079-
**kwargs,
1080-
)
1081-
else:
1082-
try:
1083-
# if this function is invalid for this dtype, we will ignore it.
1084-
result, counts = self.grouper.agg_series(obj, f)
1085-
except TypeError:
1086-
continue
1104+
try:
1105+
# if this function is invalid for this dtype, we will ignore it.
1106+
result, counts = self.grouper.agg_series(obj, f)
1107+
except TypeError:
1108+
continue
10871109

10881110
assert result is not None
10891111
key = base.OutputKey(label=name, position=idx)

pandas/core/groupby/numba_.py

+172
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
"""Common utilities for Numba operations with groupby ops"""
2+
import inspect
3+
from typing import Any, Callable, Dict, Optional, Tuple
4+
5+
import numpy as np
6+
7+
from pandas._typing import FrameOrSeries, Scalar
8+
from pandas.compat._optional import import_optional_dependency
9+
10+
from pandas.core.util.numba_ import (
11+
NUMBA_FUNC_CACHE,
12+
NumbaUtilError,
13+
check_kwargs_and_nopython,
14+
get_jit_arguments,
15+
jit_user_function,
16+
)
17+
18+
19+
def split_for_numba(arg: FrameOrSeries) -> Tuple[np.ndarray, np.ndarray]:
20+
"""
21+
Split pandas object into its components as numpy arrays for numba functions.
22+
23+
Parameters
24+
----------
25+
arg : Series or DataFrame
26+
27+
Returns
28+
-------
29+
(ndarray, ndarray)
30+
values, index
31+
"""
32+
return arg.to_numpy(), arg.index.to_numpy()
33+
34+
35+
def validate_udf(func: Callable) -> None:
36+
"""
37+
Validate user defined function for ops when using Numba with groupby ops.
38+
39+
The first signature arguments should include:
40+
41+
def f(values, index, ...):
42+
...
43+
44+
Parameters
45+
----------
46+
func : function, default False
47+
user defined function
48+
49+
Returns
50+
-------
51+
None
52+
53+
Raises
54+
------
55+
NumbaUtilError
56+
"""
57+
udf_signature = list(inspect.signature(func).parameters.keys())
58+
expected_args = ["values", "index"]
59+
min_number_args = len(expected_args)
60+
if (
61+
len(udf_signature) < min_number_args
62+
or udf_signature[:min_number_args] != expected_args
63+
):
64+
raise NumbaUtilError(
65+
f"The first {min_number_args} arguments to {func.__name__} must be "
66+
f"{expected_args}"
67+
)
68+
69+
70+
def generate_numba_func(
71+
func: Callable,
72+
engine_kwargs: Optional[Dict[str, bool]],
73+
kwargs: dict,
74+
cache_key_str: str,
75+
) -> Tuple[Callable, Tuple[Callable, str]]:
76+
"""
77+
Return a JITed function and cache key for the NUMBA_FUNC_CACHE
78+
79+
This _may_ be specific to groupby (as it's only used there currently).
80+
81+
Parameters
82+
----------
83+
func : function
84+
user defined function
85+
engine_kwargs : dict or None
86+
numba.jit arguments
87+
kwargs : dict
88+
kwargs for func
89+
cache_key_str : str
90+
string representing the second part of the cache key tuple
91+
92+
Returns
93+
-------
94+
(JITed function, cache key)
95+
96+
Raises
97+
------
98+
NumbaUtilError
99+
"""
100+
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
101+
check_kwargs_and_nopython(kwargs, nopython)
102+
validate_udf(func)
103+
cache_key = (func, cache_key_str)
104+
numba_func = NUMBA_FUNC_CACHE.get(
105+
cache_key, jit_user_function(func, nopython, nogil, parallel)
106+
)
107+
return numba_func, cache_key
108+
109+
110+
def generate_numba_agg_func(
111+
args: Tuple,
112+
kwargs: Dict[str, Any],
113+
func: Callable[..., Scalar],
114+
engine_kwargs: Optional[Dict[str, bool]],
115+
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int], np.ndarray]:
116+
"""
117+
Generate a numba jitted agg function specified by values from engine_kwargs.
118+
119+
1. jit the user's function
120+
2. Return a groupby agg function with the jitted function inline
121+
122+
Configurations specified in engine_kwargs apply to both the user's
123+
function _AND_ the rolling apply function.
124+
125+
Parameters
126+
----------
127+
args : tuple
128+
*args to be passed into the function
129+
kwargs : dict
130+
**kwargs to be passed into the function
131+
func : function
132+
function to be applied to each window and will be JITed
133+
engine_kwargs : dict
134+
dictionary of arguments to be passed into numba.jit
135+
136+
Returns
137+
-------
138+
Numba function
139+
"""
140+
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
141+
142+
check_kwargs_and_nopython(kwargs, nopython)
143+
144+
validate_udf(func)
145+
146+
numba_func = jit_user_function(func, nopython, nogil, parallel)
147+
148+
numba = import_optional_dependency("numba")
149+
150+
if parallel:
151+
loop_range = numba.prange
152+
else:
153+
loop_range = range
154+
155+
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
156+
def group_apply(
157+
values: np.ndarray,
158+
index: np.ndarray,
159+
begin: np.ndarray,
160+
end: np.ndarray,
161+
num_groups: int,
162+
num_columns: int,
163+
) -> np.ndarray:
164+
result = np.empty((num_groups, num_columns))
165+
for i in loop_range(num_groups):
166+
group_index = index[begin[i] : end[i]]
167+
for j in loop_range(num_columns):
168+
group = values[begin[i] : end[i], j]
169+
result[i, j] = numba_func(group, group_index, *args)
170+
return result
171+
172+
return group_apply

0 commit comments

Comments
 (0)