diff --git a/asv_bench/benchmarks/groupby.py b/asv_bench/benchmarks/groupby.py
index eb637c78806c0..c9ac275cc4ea7 100644
--- a/asv_bench/benchmarks/groupby.py
+++ b/asv_bench/benchmarks/groupby.py
@@ -660,4 +660,62 @@ def function(values):
self.grouper.transform(function, engine="cython")
+class AggEngine:
+ def setup(self):
+ N = 10 ** 3
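+        # 100 string-keyed groups with N rows per group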
+ data = DataFrame(
+ {0: [str(i) for i in range(100)] * N, 1: list(range(100)) * N},
+ columns=[0, 1],
+ )
+ self.grouper = data.groupby(0)
+
+ def time_series_numba(self):
+ def function(values, index):
+ total = 0
+ for i, value in enumerate(values):
+ if i % 2:
+ total += value + 5
+ else:
+ total += value * 2
+ return total
+
+ self.grouper[1].agg(function, engine="numba")
+
+ def time_series_cython(self):
+ def function(values):
+ total = 0
+ for i, value in enumerate(values):
+ if i % 2:
+ total += value + 5
+ else:
+ total += value * 2
+ return total
+
+ self.grouper[1].agg(function, engine="cython")
+
+ def time_dataframe_numba(self):
+ def function(values, index):
+ total = 0
+ for i, value in enumerate(values):
+ if i % 2:
+ total += value + 5
+ else:
+ total += value * 2
+ return total
+
+ self.grouper.agg(function, engine="numba")
+
+ def time_dataframe_cython(self):
+ def function(values):
+ total = 0
+ for i, value in enumerate(values):
+ if i % 2:
+ total += value + 5
+ else:
+ total += value * 2
+ return total
+
+ self.grouper.agg(function, engine="cython")
+
+
from .pandas_vb_common import setup # noqa: F401 isort:skip
diff --git a/doc/source/user_guide/computation.rst b/doc/source/user_guide/computation.rst
index d7d025981f2f4..37ec7ca9c98d6 100644
--- a/doc/source/user_guide/computation.rst
+++ b/doc/source/user_guide/computation.rst
@@ -380,8 +380,8 @@ and their default values are set to ``False``, ``True`` and ``False`` respective
.. note::
In terms of performance, **the first time a function is run using the Numba engine will be slow**
- as Numba will have some function compilation overhead. However, ``rolling`` objects will cache
- the function and subsequent calls will be fast. In general, the Numba engine is performant with
+ as Numba will have some function compilation overhead. However, the compiled functions are cached,
+ and subsequent calls will be fast. In general, the Numba engine is performant with
a larger amount of data points (e.g. 1+ million).
.. code-block:: ipython
diff --git a/doc/source/user_guide/groupby.rst b/doc/source/user_guide/groupby.rst
index 5927f1a4175ee..c5f58425139ee 100644
--- a/doc/source/user_guide/groupby.rst
+++ b/doc/source/user_guide/groupby.rst
@@ -1021,6 +1021,73 @@ that is itself a series, and possibly upcast the result to a DataFrame:
the output as well as set the indices.
+Numba Accelerated Routines
+--------------------------
+
+.. versionadded:: 1.1
+
+If `Numba <http://numba.pydata.org/>`__ is installed as an optional dependency, the ``transform`` and
+``aggregate`` methods support ``engine='numba'`` and ``engine_kwargs`` arguments. The ``engine_kwargs``
+argument is a dictionary of keyword arguments that will be passed into the
+`numba.jit decorator <https://numba.pydata.org/numba-doc/latest/reference/jit-compilation.html#numba.jit>`__.
+These keyword arguments will be applied to the passed function. Currently only ``nogil``, ``nopython``,
+and ``parallel`` are supported, and their default values are set to ``False``, ``True`` and ``False`` respectively.
+
+The function signature must start with ``values, index`` **exactly**, as the data belonging to each group
+will be passed into ``values``, and the group index will be passed into ``index``.
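+
+For example, a minimal conforming function might look like the sketch below
+(``df`` and the ``"key"`` column are illustrative placeholders, not part of the API):
+
+.. code-block:: python
+
+   def f(values, index):
+       # values: the group's data as a NumPy array
+       # index: the group's index as a NumPy array
+       return values.sum()
+
+   df.groupby("key").aggregate(f, engine="numba", engine_kwargs={"nopython": True})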
+
+.. warning::
+
+ When using ``engine='numba'``, there will be no "fall back" behavior internally. The group
+    data and group index will be passed as numpy arrays to the JITed user-defined function, and no
+    alternative execution paths will be attempted.
+
+.. note::
+
+ In terms of performance, **the first time a function is run using the Numba engine will be slow**
+ as Numba will have some function compilation overhead. However, the compiled functions are cached,
+ and subsequent calls will be fast. In general, the Numba engine is performant with
+    a larger number of data points (e.g. 1+ million).
+
+.. code-block:: ipython
+
+ In [1]: N = 10 ** 3
+
+ In [2]: data = {0: [str(i) for i in range(100)] * N, 1: list(range(100)) * N}
+
+ In [3]: df = pd.DataFrame(data, columns=[0, 1])
+
+ In [4]: def f_numba(values, index):
+ ...: total = 0
+ ...: for i, value in enumerate(values):
+ ...: if i % 2:
+ ...: total += value + 5
+ ...: else:
+ ...: total += value * 2
+ ...: return total
+ ...:
+
+ In [5]: def f_cython(values):
+ ...: total = 0
+ ...: for i, value in enumerate(values):
+ ...: if i % 2:
+ ...: total += value + 5
+ ...: else:
+ ...: total += value * 2
+ ...: return total
+ ...:
+
+ In [6]: groupby = df.groupby(0)
+    # Run the first time; compilation time will affect performance
+ In [7]: %timeit -r 1 -n 1 groupby.aggregate(f_numba, engine='numba') # noqa: E225
+ 2.14 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
+ # Function is cached and performance will improve
+ In [8]: %timeit groupby.aggregate(f_numba, engine='numba')
+ 4.93 ms ± 32.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
+
+ In [9]: %timeit groupby.aggregate(f_cython, engine='cython')
+ 18.6 ms ± 84.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
+
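+In this example, once the compilation overhead is amortized, the jitted function
+runs roughly four times faster than the same logic under ``engine='cython'``.
+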
Other useful features
---------------------
diff --git a/doc/source/whatsnew/v1.1.0.rst b/doc/source/whatsnew/v1.1.0.rst
index cd1cb0b64f74a..0f766414b20e8 100644
--- a/doc/source/whatsnew/v1.1.0.rst
+++ b/doc/source/whatsnew/v1.1.0.rst
@@ -98,7 +98,7 @@ Other enhancements
This can be used to set a custom compression level, e.g.,
``df.to_csv(path, compression={'method': 'gzip', 'compresslevel': 1}``
(:issue:`33196`)
-- :meth:`~pandas.core.groupby.GroupBy.transform` has gained ``engine`` and ``engine_kwargs`` arguments that supports executing functions with ``Numba`` (:issue:`32854`)
+- :meth:`~pandas.core.groupby.GroupBy.transform` and :meth:`~pandas.core.groupby.GroupBy.aggregate` have gained ``engine`` and ``engine_kwargs`` arguments that support executing functions with ``Numba`` (:issue:`32854`, :issue:`33388`)
- :meth:`~pandas.core.resample.Resampler.interpolate` now supports SciPy interpolation method :class:`scipy.interpolate.CubicSpline` as method ``cubicspline`` (:issue:`33670`)
-
diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py
index 504de404b2509..18752cdc1642e 100644
--- a/pandas/core/groupby/generic.py
+++ b/pandas/core/groupby/generic.py
@@ -79,6 +79,7 @@
NUMBA_FUNC_CACHE,
check_kwargs_and_nopython,
get_jit_arguments,
+ is_numba_util_related_error,
jit_user_function,
split_for_numba,
validate_udf,
@@ -244,7 +245,9 @@ def apply(self, func, *args, **kwargs):
axis="",
)
@Appender(_shared_docs["aggregate"])
- def aggregate(self, func=None, *args, **kwargs):
+ def aggregate(
+ self, func=None, *args, engine="cython", engine_kwargs=None, **kwargs
+ ):
relabeling = func is None
columns = None
@@ -272,11 +275,18 @@ def aggregate(self, func=None, *args, **kwargs):
return getattr(self, cyfunc)()
if self.grouper.nkeys > 1:
- return self._python_agg_general(func, *args, **kwargs)
+ return self._python_agg_general(
+ func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
+ )
try:
- return self._python_agg_general(func, *args, **kwargs)
- except (ValueError, KeyError):
+ return self._python_agg_general(
+ func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
+ )
+ except (ValueError, KeyError) as err:
+            # Do not catch Numba errors here; we want to raise, not fall back.
+ if is_numba_util_related_error(str(err)):
+ raise err
# TODO: KeyError is raised in _python_agg_general,
# see test_groupby.test_basic
result = self._aggregate_named(func, *args, **kwargs)
@@ -941,7 +951,9 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
axis="",
)
@Appender(_shared_docs["aggregate"])
- def aggregate(self, func=None, *args, **kwargs):
+ def aggregate(
+ self, func=None, *args, engine="cython", engine_kwargs=None, **kwargs
+ ):
relabeling = func is None and is_multi_agg_with_relabel(**kwargs)
if relabeling:
@@ -962,6 +974,11 @@ def aggregate(self, func=None, *args, **kwargs):
func = maybe_mangle_lambdas(func)
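+        # with engine="numba", dispatch straight to _python_agg_general and skip _aggregate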
+ if engine == "numba":
+ return self._python_agg_general(
+ func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
+ )
+
result, how = self._aggregate(func, *args, **kwargs)
if how is None:
return result
diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py
index 154af3981a5ff..6924c7d320bc4 100644
--- a/pandas/core/groupby/groupby.py
+++ b/pandas/core/groupby/groupby.py
@@ -910,9 +910,12 @@ def _cython_agg_general(
return self._wrap_aggregated_output(output)
- def _python_agg_general(self, func, *args, **kwargs):
+ def _python_agg_general(
+ self, func, *args, engine="cython", engine_kwargs=None, **kwargs
+ ):
func = self._is_builtin_func(func)
- f = lambda x: func(x, *args, **kwargs)
+ if engine != "numba":
+ f = lambda x: func(x, *args, **kwargs)
# iterate through "columns" ex exclusions to populate output dict
output: Dict[base.OutputKey, np.ndarray] = {}
@@ -923,11 +926,21 @@ def _python_agg_general(self, func, *args, **kwargs):
# agg_series below assumes ngroups > 0
continue
- try:
- # if this function is invalid for this dtype, we will ignore it.
- result, counts = self.grouper.agg_series(obj, f)
- except TypeError:
- continue
+ if engine == "numba":
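+            # the grouper routes numba execution through its pure-python aggregation path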
+ result, counts = self.grouper.agg_series(
+ obj,
+ func,
+ *args,
+ engine=engine,
+ engine_kwargs=engine_kwargs,
+ **kwargs,
+ )
+ else:
+ try:
+ # if this function is invalid for this dtype, we will ignore it.
+ result, counts = self.grouper.agg_series(obj, f)
+ except TypeError:
+ continue
assert result is not None
key = base.OutputKey(label=name, position=idx)
diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py
index 8d535374a083f..3c7794fa52d86 100644
--- a/pandas/core/groupby/ops.py
+++ b/pandas/core/groupby/ops.py
@@ -54,6 +54,14 @@
get_group_index_sorter,
get_indexer_dict,
)
+from pandas.core.util.numba_ import (
+ NUMBA_FUNC_CACHE,
+ check_kwargs_and_nopython,
+ get_jit_arguments,
+ jit_user_function,
+ split_for_numba,
+ validate_udf,
+)
class BaseGrouper:
@@ -608,10 +616,16 @@ def _transform(
return result
- def agg_series(self, obj: Series, func):
+ def agg_series(
+ self, obj: Series, func, *args, engine="cython", engine_kwargs=None, **kwargs
+ ):
# Caller is responsible for checking ngroups != 0
assert self.ngroups != 0
+ if engine == "numba":
+ return self._aggregate_series_pure_python(
+ obj, func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
+ )
if len(obj) == 0:
# SeriesGrouper would raise if we were to call _aggregate_series_fast
return self._aggregate_series_pure_python(obj, func)
@@ -656,7 +670,18 @@ def _aggregate_series_fast(self, obj: Series, func):
result, counts = grouper.get_result()
return result, counts
- def _aggregate_series_pure_python(self, obj: Series, func):
+ def _aggregate_series_pure_python(
+ self, obj: Series, func, *args, engine="cython", engine_kwargs=None, **kwargs
+ ):
+
+ if engine == "numba":
+ nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
+ check_kwargs_and_nopython(kwargs, nopython)
+ validate_udf(func)
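+            # reuse a previously jitted version of this UDF if one is cached;
+            # otherwise wrap it with numba.jit now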
+ cache_key = (func, "groupby_agg")
+ numba_func = NUMBA_FUNC_CACHE.get(
+ cache_key, jit_user_function(func, nopython, nogil, parallel)
+ )
group_index, _, ngroups = self.group_info
@@ -666,7 +691,14 @@ def _aggregate_series_pure_python(self, obj: Series, func):
splitter = get_splitter(obj, group_index, ngroups, axis=0)
for label, group in splitter:
- res = func(group)
+ if engine == "numba":
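+                # pass the group's values and index to the jitted UDF as NumPy arrays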
+ values, index = split_for_numba(group)
+ res = numba_func(values, index, *args)
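+                # cache the jitted function after its first successful execution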
+ if cache_key not in NUMBA_FUNC_CACHE:
+ NUMBA_FUNC_CACHE[cache_key] = numba_func
+ else:
+ res = func(group, *args, **kwargs)
+
if result is None:
if isinstance(res, (Series, Index, np.ndarray)):
if len(res) == 1:
@@ -842,7 +874,9 @@ def groupings(self) -> "List[grouper.Grouping]":
for lvl, name in zip(self.levels, self.names)
]
- def agg_series(self, obj: Series, func):
+ def agg_series(
+ self, obj: Series, func, *args, engine="cython", engine_kwargs=None, **kwargs
+ ):
# Caller is responsible for checking ngroups != 0
assert self.ngroups != 0
assert len(self.bins) > 0 # otherwise we'd get IndexError in get_result
diff --git a/pandas/core/util/numba_.py b/pandas/core/util/numba_.py
index 29e74747881ae..215248f5a43c2 100644
--- a/pandas/core/util/numba_.py
+++ b/pandas/core/util/numba_.py
@@ -12,6 +12,25 @@
NUMBA_FUNC_CACHE: Dict[Tuple[Callable, str], Callable] = dict()
+def is_numba_util_related_error(err_message: str) -> bool:
+ """
+ Check if an error was raised from one of the numba utility functions
+
+    For cases where a try/except block has mistakenly caught the error
+    and we want to re-raise it.
+
+ Parameters
+ ----------
+    err_message : str
+ exception error message
+
+ Returns
+ -------
+ bool
+ """
+ return "The first" in err_message or "numba does not" in err_message
+
+
def check_kwargs_and_nopython(
kwargs: Optional[Dict] = None, nopython: Optional[bool] = None
) -> None:
@@ -76,7 +95,6 @@ def jit_user_function(
----------
func : function
user defined function
-
nopython : bool
nopython parameter for numba.JIT
nogil : bool
diff --git a/pandas/tests/groupby/aggregate/test_numba.py b/pandas/tests/groupby/aggregate/test_numba.py
new file mode 100644
index 0000000000000..70b0a027f1bd1
--- /dev/null
+++ b/pandas/tests/groupby/aggregate/test_numba.py
@@ -0,0 +1,114 @@
+import numpy as np
+import pytest
+
+import pandas.util._test_decorators as td
+
+from pandas import DataFrame
+import pandas._testing as tm
+from pandas.core.util.numba_ import NUMBA_FUNC_CACHE
+
+
+@td.skip_if_no("numba", "0.46.0")
+def test_correct_function_signature():
+ def incorrect_function(x):
+ return sum(x) * 2.7
+
+ data = DataFrame(
+ {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]},
+ columns=["key", "data"],
+ )
+    with pytest.raises(ValueError, match="The first 2"):
+ data.groupby("key").agg(incorrect_function, engine="numba")
+
+    with pytest.raises(ValueError, match="The first 2"):
+ data.groupby("key")["data"].agg(incorrect_function, engine="numba")
+
+
+@td.skip_if_no("numba", "0.46.0")
+def test_check_nopython_kwargs():
+ def incorrect_function(x, **kwargs):
+ return sum(x) * 2.7
+
+ data = DataFrame(
+ {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]},
+ columns=["key", "data"],
+ )
+ with pytest.raises(ValueError, match="numba does not support"):
+ data.groupby("key").agg(incorrect_function, engine="numba", a=1)
+
+ with pytest.raises(ValueError, match="numba does not support"):
+ data.groupby("key")["data"].agg(incorrect_function, engine="numba", a=1)
+
+
+@td.skip_if_no("numba", "0.46.0")
+@pytest.mark.filterwarnings("ignore:\\nThe keyword argument")
+# Filter warnings when parallel=True and the function can't be parallelized by Numba
+@pytest.mark.parametrize("jit", [True, False])
+@pytest.mark.parametrize("pandas_obj", ["Series", "DataFrame"])
+def test_numba_vs_cython(jit, pandas_obj, nogil, parallel, nopython):
+ def func_numba(values, index):
+ return np.mean(values) * 2.7
+
+ if jit:
+ # Test accepted jitted functions
+ import numba
+
+ func_numba = numba.jit(func_numba)
+
+ data = DataFrame(
+ {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1],
+ )
+ engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
+ grouped = data.groupby(0)
+ if pandas_obj == "Series":
+ grouped = grouped[1]
+
+ result = grouped.agg(func_numba, engine="numba", engine_kwargs=engine_kwargs)
+ expected = grouped.agg(lambda x: np.mean(x) * 2.7, engine="cython")
+
+ tm.assert_equal(result, expected)
+
+
+@td.skip_if_no("numba", "0.46.0")
+@pytest.mark.filterwarnings("ignore:\\nThe keyword argument")
+# Filter warnings when parallel=True and the function can't be parallelized by Numba
+@pytest.mark.parametrize("jit", [True, False])
+@pytest.mark.parametrize("pandas_obj", ["Series", "DataFrame"])
+def test_cache(jit, pandas_obj, nogil, parallel, nopython):
+ # Test that the functions are cached correctly if we switch functions
+ def func_1(values, index):
+ return np.mean(values) - 3.4
+
+ def func_2(values, index):
+ return np.mean(values) * 2.7
+
+ if jit:
+ import numba
+
+ func_1 = numba.jit(func_1)
+ func_2 = numba.jit(func_2)
+
+ data = DataFrame(
+ {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1],
+ )
+ engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
+ grouped = data.groupby(0)
+ if pandas_obj == "Series":
+ grouped = grouped[1]
+
+ result = grouped.agg(func_1, engine="numba", engine_kwargs=engine_kwargs)
+ expected = grouped.agg(lambda x: np.mean(x) - 3.4, engine="cython")
+ tm.assert_equal(result, expected)
+ # func_1 should be in the cache now
+ assert (func_1, "groupby_agg") in NUMBA_FUNC_CACHE
+
+ # Add func_2 to the cache
+ result = grouped.agg(func_2, engine="numba", engine_kwargs=engine_kwargs)
+ expected = grouped.agg(lambda x: np.mean(x) * 2.7, engine="cython")
+ tm.assert_equal(result, expected)
+ assert (func_2, "groupby_agg") in NUMBA_FUNC_CACHE
+
+ # Retest func_1 which should use the cache
+ result = grouped.agg(func_1, engine="numba", engine_kwargs=engine_kwargs)
+ expected = grouped.agg(lambda x: np.mean(x) - 3.4, engine="cython")
+ tm.assert_equal(result, expected)