Skip to content

Commit 2d84f49

Browse files
mroeschkerhshadrach
authored andcommitted
ENH: Add numba engine to groupby.aggregate (pandas-dev#33388)
1 parent 8fa0e2e commit 2d84f49

File tree

9 files changed

+341
-20
lines changed

9 files changed

+341
-20
lines changed

asv_bench/benchmarks/groupby.py

+58
Original file line numberDiff line numberDiff line change
@@ -660,4 +660,62 @@ def function(values):
660660
self.grouper.transform(function, engine="cython")
661661

662662

663+
class AggEngine:
664+
def setup(self):
665+
N = 10 ** 3
666+
data = DataFrame(
667+
{0: [str(i) for i in range(100)] * N, 1: list(range(100)) * N},
668+
columns=[0, 1],
669+
)
670+
self.grouper = data.groupby(0)
671+
672+
def time_series_numba(self):
673+
def function(values, index):
674+
total = 0
675+
for i, value in enumerate(values):
676+
if i % 2:
677+
total += value + 5
678+
else:
679+
total += value * 2
680+
return total
681+
682+
self.grouper[1].agg(function, engine="numba")
683+
684+
def time_series_cython(self):
685+
def function(values):
686+
total = 0
687+
for i, value in enumerate(values):
688+
if i % 2:
689+
total += value + 5
690+
else:
691+
total += value * 2
692+
return total
693+
694+
self.grouper[1].agg(function, engine="cython")
695+
696+
def time_dataframe_numba(self):
697+
def function(values, index):
698+
total = 0
699+
for i, value in enumerate(values):
700+
if i % 2:
701+
total += value + 5
702+
else:
703+
total += value * 2
704+
return total
705+
706+
self.grouper.agg(function, engine="numba")
707+
708+
def time_dataframe_cython(self):
709+
def function(values):
710+
total = 0
711+
for i, value in enumerate(values):
712+
if i % 2:
713+
total += value + 5
714+
else:
715+
total += value * 2
716+
return total
717+
718+
self.grouper.agg(function, engine="cython")
719+
720+
663721
from .pandas_vb_common import setup # noqa: F401 isort:skip

doc/source/user_guide/computation.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -382,8 +382,8 @@ and their default values are set to ``False``, ``True`` and ``False`` respective
382382
.. note::
383383

384384
In terms of performance, **the first time a function is run using the Numba engine will be slow**
385-
as Numba will have some function compilation overhead. However, ``rolling`` objects will cache
386-
the function and subsequent calls will be fast. In general, the Numba engine is performant with
385+
as Numba will have some function compilation overhead. However, the compiled functions are cached,
386+
and subsequent calls will be fast. In general, the Numba engine is performant with
387387
a larger amount of data points (e.g. 1+ million).
388388

389389
.. code-block:: ipython

doc/source/user_guide/groupby.rst

+67
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,73 @@ that is itself a series, and possibly upcast the result to a DataFrame:
10211021
the output as well as set the indices.
10221022

10231023

1024+
Numba Accelerated Routines
1025+
--------------------------
1026+
1027+
.. versionadded:: 1.1
1028+
1029+
If `Numba <https://numba.pydata.org/>`__ is installed as an optional dependency, the ``transform`` and
1030+
``aggregate`` methods support ``engine='numba'`` and ``engine_kwargs`` arguments. The ``engine_kwargs``
1031+
argument is a dictionary of keyword arguments that will be passed into the
1032+
`numba.jit decorator <https://numba.pydata.org/numba-doc/latest/reference/jit-compilation.html#numba.jit>`__.
1033+
These keyword arguments will be applied to the passed function. Currently only ``nogil``, ``nopython``,
1034+
and ``parallel`` are supported, and their default values are set to ``False``, ``True`` and ``False`` respectively.
1035+
1036+
The function signature must start with ``values, index`` **exactly** as the data belonging to each group
1037+
will be passed into ``values``, and the group index will be passed into ``index``.
1038+
1039+
.. warning::
1040+
1041+
When using ``engine='numba'``, there will be no "fall back" behavior internally. The group
1042+
data and group index will be passed as numpy arrays to the JITed user defined function, and no
1043+
alternative execution attempts will be tried.
1044+
1045+
.. note::
1046+
1047+
In terms of performance, **the first time a function is run using the Numba engine will be slow**
1048+
as Numba will have some function compilation overhead. However, the compiled functions are cached,
1049+
and subsequent calls will be fast. In general, the Numba engine is performant with
1050+
a larger amount of data points (e.g. 1+ million).
1051+
1052+
.. code-block:: ipython
1053+
1054+
In [1]: N = 10 ** 3
1055+
1056+
In [2]: data = {0: [str(i) for i in range(100)] * N, 1: list(range(100)) * N}
1057+
1058+
In [3]: df = pd.DataFrame(data, columns=[0, 1])
1059+
1060+
In [4]: def f_numba(values, index):
1061+
...: total = 0
1062+
...: for i, value in enumerate(values):
1063+
...: if i % 2:
1064+
...: total += value + 5
1065+
...: else:
1066+
...: total += value * 2
1067+
...: return total
1068+
...:
1069+
1070+
In [5]: def f_cython(values):
1071+
...: total = 0
1072+
...: for i, value in enumerate(values):
1073+
...: if i % 2:
1074+
...: total += value + 5
1075+
...: else:
1076+
...: total += value * 2
1077+
...: return total
1078+
...:
1079+
1080+
In [6]: groupby = df.groupby(0)
1081+
# Run the first time, compilation time will affect performance
1082+
In [7]: %timeit -r 1 -n 1 groupby.aggregate(f_numba, engine='numba') # noqa: E225
1083+
2.14 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
1084+
# Function is cached and performance will improve
1085+
In [8]: %timeit groupby.aggregate(f_numba, engine='numba')
1086+
4.93 ms ± 32.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1087+
1088+
In [9]: %timeit groupby.aggregate(f_cython, engine='cython')
1089+
18.6 ms ± 84.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1090+
10241091
Other useful features
10251092
---------------------
10261093

doc/source/whatsnew/v1.1.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ Other enhancements
9898
This can be used to set a custom compression level, e.g.,
9999
``df.to_csv(path, compression={'method': 'gzip', 'compresslevel': 1}``
100100
(:issue:`33196`)
101-
- :meth:`~pandas.core.groupby.GroupBy.transform` has gained ``engine`` and ``engine_kwargs`` arguments that supports executing functions with ``Numba`` (:issue:`32854`)
101+
- :meth:`~pandas.core.groupby.GroupBy.transform` and :meth:`~pandas.core.groupby.GroupBy.aggregate` has gained ``engine`` and ``engine_kwargs`` arguments that supports executing functions with ``Numba`` (:issue:`32854`, :issue:`33388`)
102102
- :meth:`~pandas.core.resample.Resampler.interpolate` now supports SciPy interpolation method :class:`scipy.interpolate.CubicSpline` as method ``cubicspline`` (:issue:`33670`)
103103
-
104104

pandas/core/groupby/generic.py

+22-5
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
NUMBA_FUNC_CACHE,
8080
check_kwargs_and_nopython,
8181
get_jit_arguments,
82+
is_numba_util_related_error,
8283
jit_user_function,
8384
split_for_numba,
8485
validate_udf,
@@ -244,7 +245,9 @@ def apply(self, func, *args, **kwargs):
244245
axis="",
245246
)
246247
@Appender(_shared_docs["aggregate"])
247-
def aggregate(self, func=None, *args, **kwargs):
248+
def aggregate(
249+
self, func=None, *args, engine="cython", engine_kwargs=None, **kwargs
250+
):
248251

249252
relabeling = func is None
250253
columns = None
@@ -272,11 +275,18 @@ def aggregate(self, func=None, *args, **kwargs):
272275
return getattr(self, cyfunc)()
273276

274277
if self.grouper.nkeys > 1:
275-
return self._python_agg_general(func, *args, **kwargs)
278+
return self._python_agg_general(
279+
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
280+
)
276281

277282
try:
278-
return self._python_agg_general(func, *args, **kwargs)
279-
except (ValueError, KeyError):
283+
return self._python_agg_general(
284+
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
285+
)
286+
except (ValueError, KeyError) as err:
287+
# Do not catch Numba errors here, we want to raise and not fall back.
288+
if is_numba_util_related_error(str(err)):
289+
raise err
280290
# TODO: KeyError is raised in _python_agg_general,
281291
# see see test_groupby.test_basic
282292
result = self._aggregate_named(func, *args, **kwargs)
@@ -941,7 +951,9 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
941951
axis="",
942952
)
943953
@Appender(_shared_docs["aggregate"])
944-
def aggregate(self, func=None, *args, **kwargs):
954+
def aggregate(
955+
self, func=None, *args, engine="cython", engine_kwargs=None, **kwargs
956+
):
945957

946958
relabeling = func is None and is_multi_agg_with_relabel(**kwargs)
947959
if relabeling:
@@ -962,6 +974,11 @@ def aggregate(self, func=None, *args, **kwargs):
962974

963975
func = maybe_mangle_lambdas(func)
964976

977+
if engine == "numba":
978+
return self._python_agg_general(
979+
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
980+
)
981+
965982
result, how = self._aggregate(func, *args, **kwargs)
966983
if how is None:
967984
return result

pandas/core/groupby/groupby.py

+20-7
Original file line numberDiff line numberDiff line change
@@ -910,9 +910,12 @@ def _cython_agg_general(
910910

911911
return self._wrap_aggregated_output(output)
912912

913-
def _python_agg_general(self, func, *args, **kwargs):
913+
def _python_agg_general(
914+
self, func, *args, engine="cython", engine_kwargs=None, **kwargs
915+
):
914916
func = self._is_builtin_func(func)
915-
f = lambda x: func(x, *args, **kwargs)
917+
if engine != "numba":
918+
f = lambda x: func(x, *args, **kwargs)
916919

917920
# iterate through "columns" ex exclusions to populate output dict
918921
output: Dict[base.OutputKey, np.ndarray] = {}
@@ -923,11 +926,21 @@ def _python_agg_general(self, func, *args, **kwargs):
923926
# agg_series below assumes ngroups > 0
924927
continue
925928

926-
try:
927-
# if this function is invalid for this dtype, we will ignore it.
928-
result, counts = self.grouper.agg_series(obj, f)
929-
except TypeError:
930-
continue
929+
if engine == "numba":
930+
result, counts = self.grouper.agg_series(
931+
obj,
932+
func,
933+
*args,
934+
engine=engine,
935+
engine_kwargs=engine_kwargs,
936+
**kwargs,
937+
)
938+
else:
939+
try:
940+
# if this function is invalid for this dtype, we will ignore it.
941+
result, counts = self.grouper.agg_series(obj, f)
942+
except TypeError:
943+
continue
931944

932945
assert result is not None
933946
key = base.OutputKey(label=name, position=idx)

pandas/core/groupby/ops.py

+38-4
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,14 @@
5454
get_group_index_sorter,
5555
get_indexer_dict,
5656
)
57+
from pandas.core.util.numba_ import (
58+
NUMBA_FUNC_CACHE,
59+
check_kwargs_and_nopython,
60+
get_jit_arguments,
61+
jit_user_function,
62+
split_for_numba,
63+
validate_udf,
64+
)
5765

5866

5967
class BaseGrouper:
@@ -608,10 +616,16 @@ def _transform(
608616

609617
return result
610618

611-
def agg_series(self, obj: Series, func):
619+
def agg_series(
620+
self, obj: Series, func, *args, engine="cython", engine_kwargs=None, **kwargs
621+
):
612622
# Caller is responsible for checking ngroups != 0
613623
assert self.ngroups != 0
614624

625+
if engine == "numba":
626+
return self._aggregate_series_pure_python(
627+
obj, func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
628+
)
615629
if len(obj) == 0:
616630
# SeriesGrouper would raise if we were to call _aggregate_series_fast
617631
return self._aggregate_series_pure_python(obj, func)
@@ -656,7 +670,18 @@ def _aggregate_series_fast(self, obj: Series, func):
656670
result, counts = grouper.get_result()
657671
return result, counts
658672

659-
def _aggregate_series_pure_python(self, obj: Series, func):
673+
def _aggregate_series_pure_python(
674+
self, obj: Series, func, *args, engine="cython", engine_kwargs=None, **kwargs
675+
):
676+
677+
if engine == "numba":
678+
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
679+
check_kwargs_and_nopython(kwargs, nopython)
680+
validate_udf(func)
681+
cache_key = (func, "groupby_agg")
682+
numba_func = NUMBA_FUNC_CACHE.get(
683+
cache_key, jit_user_function(func, nopython, nogil, parallel)
684+
)
660685

661686
group_index, _, ngroups = self.group_info
662687

@@ -666,7 +691,14 @@ def _aggregate_series_pure_python(self, obj: Series, func):
666691
splitter = get_splitter(obj, group_index, ngroups, axis=0)
667692

668693
for label, group in splitter:
669-
res = func(group)
694+
if engine == "numba":
695+
values, index = split_for_numba(group)
696+
res = numba_func(values, index, *args)
697+
if cache_key not in NUMBA_FUNC_CACHE:
698+
NUMBA_FUNC_CACHE[cache_key] = numba_func
699+
else:
700+
res = func(group, *args, **kwargs)
701+
670702
if result is None:
671703
if isinstance(res, (Series, Index, np.ndarray)):
672704
if len(res) == 1:
@@ -842,7 +874,9 @@ def groupings(self) -> "List[grouper.Grouping]":
842874
for lvl, name in zip(self.levels, self.names)
843875
]
844876

845-
def agg_series(self, obj: Series, func):
877+
def agg_series(
878+
self, obj: Series, func, *args, engine="cython", engine_kwargs=None, **kwargs
879+
):
846880
# Caller is responsible for checking ngroups != 0
847881
assert self.ngroups != 0
848882
assert len(self.bins) > 0 # otherwise we'd get IndexError in get_result

pandas/core/util/numba_.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,25 @@
1212
NUMBA_FUNC_CACHE: Dict[Tuple[Callable, str], Callable] = dict()
1313

1414

15+
def is_numba_util_related_error(err_message: str) -> bool:
16+
"""
17+
Check if an error was raised from one of the numba utility functions
18+
19+
For cases where a try/except block has mistakenly caught the error
20+
and we want to re-raise
21+
22+
Parameters
23+
----------
24+
err_message : str,
25+
exception error message
26+
27+
Returns
28+
-------
29+
bool
30+
"""
31+
return "The first" in err_message or "numba does not" in err_message
32+
33+
1534
def check_kwargs_and_nopython(
1635
kwargs: Optional[Dict] = None, nopython: Optional[bool] = None
1736
) -> None:
@@ -76,7 +95,6 @@ def jit_user_function(
7695
----------
7796
func : function
7897
user defined function
79-
8098
nopython : bool
8199
nopython parameter for numba.JIT
82100
nogil : bool

0 commit comments

Comments
 (0)