ENH: Add numba engine to groupby.aggregate #33388


Merged: 25 commits, Apr 26, 2020

Commits
194bf7f
Add engine keywords to aggregate signature
Apr 6, 2020
7d42379
ENH: Add numba engine to groupby.transform
Apr 6, 2020
2124b81
include numba jitted func in agg routine
Apr 8, 2020
2321693
Merge remote-tracking branch 'upstream/master' into groupby_agg_numba
Apr 13, 2020
0f8a692
Add util functions
Apr 13, 2020
1d09ce1
Add cache and more routines
Apr 13, 2020
38e4485
Merge remote-tracking branch 'upstream/master' into groupby_agg_numba
Apr 17, 2020
79ee638
Merge remote-tracking branch 'upstream/master' into groupby_agg_numba
Apr 17, 2020
b43f183
minimize whitespace diff
Apr 17, 2020
632fb0c
Merge remote-tracking branch 'upstream/master' into groupby_agg_numba
Apr 17, 2020
f30ba2b
fix split by numba call
Apr 17, 2020
a8b7fdd
Merge remote-tracking branch 'upstream/master' into groupby_agg_numba
Apr 20, 2020
dadba23
Merge remote-tracking branch 'upstream/master' into groupby_agg_numba
Apr 21, 2020
6e4cdd1
Use global cache correctly
Apr 21, 2020
7ffe304
Raise for numba specific errors, add tests
Apr 21, 2020
9fc1068
Add benchmarks for new engine
Apr 21, 2020
599a640
Merge remote-tracking branch 'upstream/master' into groupby_agg_numba
Apr 21, 2020
4d1cbd5
Add whatsnew entry
Apr 21, 2020
7554190
Fix benchmarks and lint
Apr 21, 2020
0729230
Reorder function arguments
Apr 22, 2020
4092882
Merge remote-tracking branch 'upstream/master' into groupby_agg_numba
Apr 22, 2020
42b5171
Merge remote-tracking branch 'upstream/master' into groupby_agg_numba
Apr 24, 2020
7a9055c
Add documentation about groupby functions with numba access
Apr 24, 2020
3004046
Add warning about no fall back behavior
Apr 24, 2020
123e53a
Add noqa to timeit
Apr 24, 2020
58 changes: 58 additions & 0 deletions asv_bench/benchmarks/groupby.py
@@ -660,4 +660,62 @@ def function(values):
        self.grouper.transform(function, engine="cython")


class AggEngine:
    def setup(self):
        N = 10 ** 3
        data = DataFrame(
            {0: [str(i) for i in range(100)] * N, 1: list(range(100)) * N},
            columns=[0, 1],
        )
        self.grouper = data.groupby(0)

    def time_series_numba(self):
        def function(values, index):
            total = 0
            for i, value in enumerate(values):
                if i % 2:
                    total += value + 5
                else:
                    total += value * 2
            return total

        self.grouper[1].agg(function, engine="numba")

    def time_series_cython(self):
        def function(values):
            total = 0
            for i, value in enumerate(values):
                if i % 2:
                    total += value + 5
                else:
                    total += value * 2
            return total

        self.grouper[1].agg(function, engine="cython")

    def time_dataframe_numba(self):
        def function(values, index):
            total = 0
            for i, value in enumerate(values):
                if i % 2:
                    total += value + 5
                else:
                    total += value * 2
            return total

        self.grouper.agg(function, engine="numba")

    def time_dataframe_cython(self):
        def function(values):
            total = 0
            for i, value in enumerate(values):
                if i % 2:
                    total += value + 5
                else:
                    total += value * 2
            return total

        self.grouper.agg(function, engine="cython")


from .pandas_vb_common import setup # noqa: F401 isort:skip
4 changes: 2 additions & 2 deletions doc/source/user_guide/computation.rst
@@ -380,8 +380,8 @@ and their default values are set to ``False``, ``True`` and ``False`` respectively
.. note::

    In terms of performance, **the first time a function is run using the Numba engine will be slow**
-   as Numba will have some function compilation overhead. However, ``rolling`` objects will cache
-   the function and subsequent calls will be fast. In general, the Numba engine is performant with
+   as Numba will have some function compilation overhead. However, the compiled functions are cached,
+   and subsequent calls will be fast. In general, the Numba engine is performant with
    a larger amount of data points (e.g. 1+ million).

.. code-block:: ipython
67 changes: 67 additions & 0 deletions doc/source/user_guide/groupby.rst
@@ -1021,6 +1021,73 @@ that is itself a series, and possibly upcast the result to a DataFrame:
the output as well as set the indices.


Numba Accelerated Routines
--------------------------

.. versionadded:: 1.1

If `Numba <https://numba.pydata.org/>`__ is installed as an optional dependency, the ``transform`` and
``aggregate`` methods support ``engine='numba'`` and ``engine_kwargs`` arguments. The ``engine_kwargs``
argument is a dictionary of keyword arguments that will be passed into the
`numba.jit decorator <https://numba.pydata.org/numba-doc/latest/reference/jit-compilation.html#numba.jit>`__.
These keyword arguments will be applied to the passed function. Currently only ``nogil``, ``nopython``,
and ``parallel`` are supported, and their default values are set to ``False``, ``True`` and ``False`` respectively.
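For illustration, the way these defaults could be resolved can be sketched as follows (a minimal stand-in mirroring the ``get_jit_arguments`` helper this PR uses internally; the body here is illustrative, not pandas' actual implementation):

```python
def get_jit_arguments(engine_kwargs=None):
    # Resolve numba.jit keyword arguments with the documented defaults:
    # nopython=True, nogil=False, parallel=False.
    if engine_kwargs is None:
        engine_kwargs = {}
    nopython = engine_kwargs.get("nopython", True)
    nogil = engine_kwargs.get("nogil", False)
    parallel = engine_kwargs.get("parallel", False)
    return nopython, nogil, parallel
```

Under this sketch, ``df.groupby(0).aggregate(f, engine='numba', engine_kwargs={'nogil': True})`` would compile ``f`` with ``nopython=True, nogil=True, parallel=False``.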

The function signature must start with ``values, index`` **exactly**, as the data belonging to each group
will be passed into ``values``, and the group index will be passed into ``index``.

.. warning::

   When using ``engine='numba'``, there will be no "fall back" behavior internally. The group
   data and group index will be passed as numpy arrays to the JITed user defined function, and no
   alternative execution attempts will be tried.

.. note::

   In terms of performance, **the first time a function is run using the Numba engine will be slow**
   as Numba will have some function compilation overhead. However, the compiled functions are cached,
   and subsequent calls will be fast. In general, the Numba engine is performant with
   a larger amount of data points (e.g. 1+ million).

.. code-block:: ipython

   In [1]: N = 10 ** 3

   In [2]: data = {0: [str(i) for i in range(100)] * N, 1: list(range(100)) * N}

   In [3]: df = pd.DataFrame(data, columns=[0, 1])

   In [4]: def f_numba(values, index):
      ...:     total = 0
      ...:     for i, value in enumerate(values):
      ...:         if i % 2:
      ...:             total += value + 5
      ...:         else:
      ...:             total += value * 2
      ...:     return total
      ...:

   In [5]: def f_cython(values):
      ...:     total = 0
      ...:     for i, value in enumerate(values):
      ...:         if i % 2:
      ...:             total += value + 5
      ...:         else:
      ...:             total += value * 2
      ...:     return total
      ...:

   In [6]: groupby = df.groupby(0)
   # Run the first time, compilation time will affect performance
   In [7]: %timeit -r 1 -n 1 groupby.aggregate(f_numba, engine='numba')  # noqa: E225
   2.14 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
   # Function is cached and performance will improve
   In [8]: %timeit groupby.aggregate(f_numba, engine='numba')
   4.93 ms ± 32.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

   In [9]: %timeit groupby.aggregate(f_cython, engine='cython')
   18.6 ms ± 84.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Other useful features
---------------------

2 changes: 1 addition & 1 deletion doc/source/whatsnew/v1.1.0.rst
@@ -98,7 +98,7 @@ Other enhancements
This can be used to set a custom compression level, e.g.,
``df.to_csv(path, compression={'method': 'gzip', 'compresslevel': 1})``
(:issue:`33196`)
- - :meth:`~pandas.core.groupby.GroupBy.transform` has gained ``engine`` and ``engine_kwargs`` arguments that supports executing functions with ``Numba`` (:issue:`32854`)
+ - :meth:`~pandas.core.groupby.GroupBy.transform` and :meth:`~pandas.core.groupby.GroupBy.aggregate` have gained ``engine`` and ``engine_kwargs`` arguments that support executing functions with ``Numba`` (:issue:`32854`, :issue:`33388`)
- :meth:`~pandas.core.resample.Resampler.interpolate` now supports SciPy interpolation method :class:`scipy.interpolate.CubicSpline` as method ``cubicspline`` (:issue:`33670`)
-

27 changes: 22 additions & 5 deletions pandas/core/groupby/generic.py
@@ -79,6 +79,7 @@
    NUMBA_FUNC_CACHE,
    check_kwargs_and_nopython,
    get_jit_arguments,
+    is_numba_util_related_error,
    jit_user_function,
    split_for_numba,
    validate_udf,
@@ -244,7 +245,9 @@ def apply(self, func, *args, **kwargs):
        axis="",
    )
    @Appender(_shared_docs["aggregate"])
-    def aggregate(self, func=None, *args, **kwargs):
+    def aggregate(
+        self, func=None, *args, engine="cython", engine_kwargs=None, **kwargs
+    ):

        relabeling = func is None
        columns = None
@@ -272,11 +275,18 @@ def aggregate(self, func=None, *args, **kwargs):
            return getattr(self, cyfunc)()

        if self.grouper.nkeys > 1:
-            return self._python_agg_general(func, *args, **kwargs)
+            return self._python_agg_general(
+                func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
+            )

        try:
-            return self._python_agg_general(func, *args, **kwargs)
-        except (ValueError, KeyError):
+            return self._python_agg_general(
+                func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
+            )
+        except (ValueError, KeyError) as err:
+            # Do not catch Numba errors here, we want to raise and not fall back.
+            if is_numba_util_related_error(str(err)):
+                raise err
            # TODO: KeyError is raised in _python_agg_general,
            # see test_groupby.test_basic
            result = self._aggregate_named(func, *args, **kwargs)
@@ -941,7 +951,9 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
        axis="",
    )
    @Appender(_shared_docs["aggregate"])
-    def aggregate(self, func=None, *args, **kwargs):
+    def aggregate(
+        self, func=None, *args, engine="cython", engine_kwargs=None, **kwargs
+    ):

        relabeling = func is None and is_multi_agg_with_relabel(**kwargs)
        if relabeling:
@@ -962,6 +974,11 @@ def aggregate(self, func=None, *args, **kwargs):

        func = maybe_mangle_lambdas(func)

+        if engine == "numba":
+            return self._python_agg_general(
+                func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
+            )
+
        result, how = self._aggregate(func, *args, **kwargs)
        if how is None:
            return result
27 changes: 20 additions & 7 deletions pandas/core/groupby/groupby.py
@@ -910,9 +910,12 @@ def _cython_agg_general(

        return self._wrap_aggregated_output(output)

-    def _python_agg_general(self, func, *args, **kwargs):
+    def _python_agg_general(
+        self, func, *args, engine="cython", engine_kwargs=None, **kwargs
+    ):
        func = self._is_builtin_func(func)
-        f = lambda x: func(x, *args, **kwargs)
+        if engine != "numba":
+            f = lambda x: func(x, *args, **kwargs)

        # iterate through "columns" ex exclusions to populate output dict
        output: Dict[base.OutputKey, np.ndarray] = {}
@@ -923,11 +926,21 @@ def _python_agg_general(self, func, *args, **kwargs):
                # agg_series below assumes ngroups > 0
                continue

-            try:
-                # if this function is invalid for this dtype, we will ignore it.
-                result, counts = self.grouper.agg_series(obj, f)
-            except TypeError:
-                continue
+            if engine == "numba":
+                result, counts = self.grouper.agg_series(
+                    obj,
+                    func,
+                    *args,
+                    engine=engine,
+                    engine_kwargs=engine_kwargs,
+                    **kwargs,
+                )
+            else:
+                try:
+                    # if this function is invalid for this dtype, we will ignore it.
+                    result, counts = self.grouper.agg_series(obj, f)
+                except TypeError:
+                    continue

            assert result is not None
            key = base.OutputKey(label=name, position=idx)
42 changes: 38 additions & 4 deletions pandas/core/groupby/ops.py
@@ -54,6 +54,14 @@
    get_group_index_sorter,
    get_indexer_dict,
)
+from pandas.core.util.numba_ import (
+    NUMBA_FUNC_CACHE,
+    check_kwargs_and_nopython,
+    get_jit_arguments,
+    jit_user_function,
+    split_for_numba,
+    validate_udf,
+)


class BaseGrouper:
@@ -608,10 +616,16 @@ def _transform(

        return result

-    def agg_series(self, obj: Series, func):
+    def agg_series(
+        self, obj: Series, func, *args, engine="cython", engine_kwargs=None, **kwargs
+    ):
        # Caller is responsible for checking ngroups != 0
        assert self.ngroups != 0

+        if engine == "numba":
+            return self._aggregate_series_pure_python(
+                obj, func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
+            )
        if len(obj) == 0:
            # SeriesGrouper would raise if we were to call _aggregate_series_fast
            return self._aggregate_series_pure_python(obj, func)
@@ -656,7 +670,18 @@ def _aggregate_series_fast(self, obj: Series, func):
        result, counts = grouper.get_result()
        return result, counts

-    def _aggregate_series_pure_python(self, obj: Series, func):
+    def _aggregate_series_pure_python(
+        self, obj: Series, func, *args, engine="cython", engine_kwargs=None, **kwargs
+    ):
+
+        if engine == "numba":
+            nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
[Review comment on the lines above]

Contributor: Can you condense these into a single function call (with whatever return args you need later on)? You can certainly leave the functions individually in core.util.numba_; just calling a single function would make the API simpler here. (If possible, apply this simplification in the other places we call numba as well.) Can do this in a follow-up.

Member (Author): Sure, can follow up with this cleanup.

+            check_kwargs_and_nopython(kwargs, nopython)
+            validate_udf(func)
+            cache_key = (func, "groupby_agg")
+            numba_func = NUMBA_FUNC_CACHE.get(
+                cache_key, jit_user_function(func, nopython, nogil, parallel)
+            )

        group_index, _, ngroups = self.group_info

@@ -666,7 +691,14 @@ def _aggregate_series_pure_python(self, obj: Series, func):
        splitter = get_splitter(obj, group_index, ngroups, axis=0)

        for label, group in splitter:
-            res = func(group)
+            if engine == "numba":
+                values, index = split_for_numba(group)
+                res = numba_func(values, index, *args)
+                if cache_key not in NUMBA_FUNC_CACHE:
+                    NUMBA_FUNC_CACHE[cache_key] = numba_func
+            else:
+                res = func(group, *args, **kwargs)

            if result is None:
                if isinstance(res, (Series, Index, np.ndarray)):
                    if len(res) == 1:
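Standalone, the jit-and-cache flow in this hunk looks roughly like the sketch below. ``jit_user_function`` is replaced here by an identity stand-in so the example runs without numba; the cache key ``(func, "groupby_agg")`` and the store-after-first-successful-call behavior match the diff, while ``agg_with_cache`` is a hypothetical driver, not pandas code:

```python
NUMBA_FUNC_CACHE = {}


def jit_user_function(func, nopython, nogil, parallel):
    # Stand-in: in pandas this wraps ``func`` with
    # numba.jit(nopython=..., nogil=..., parallel=...).
    return func


def agg_with_cache(func, groups):
    # Look up a previously compiled version; compile (here: no-op) on miss.
    cache_key = (func, "groupby_agg")
    numba_func = NUMBA_FUNC_CACHE.get(
        cache_key, jit_user_function(func, True, False, False)
    )
    results = []
    for values, index in groups:
        results.append(numba_func(values, index))
        # Cache only after the first successful call, as in the loop above.
        if cache_key not in NUMBA_FUNC_CACHE:
            NUMBA_FUNC_CACHE[cache_key] = numba_func
    return results
```

Because the cache is keyed on the user function itself, a second ``aggregate`` call with the same function skips compilation entirely.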
@@ -842,7 +874,9 @@ def groupings(self) -> "List[grouper.Grouping]":
            for lvl, name in zip(self.levels, self.names)
        ]

-    def agg_series(self, obj: Series, func):
+    def agg_series(
+        self, obj: Series, func, *args, engine="cython", engine_kwargs=None, **kwargs
+    ):
        # Caller is responsible for checking ngroups != 0
        assert self.ngroups != 0
        assert len(self.bins) > 0  # otherwise we'd get IndexError in get_result
20 changes: 19 additions & 1 deletion pandas/core/util/numba_.py
@@ -12,6 +12,25 @@
NUMBA_FUNC_CACHE: Dict[Tuple[Callable, str], Callable] = dict()


+def is_numba_util_related_error(err_message: str) -> bool:
+    """
+    Check if an error was raised from one of the numba utility functions.
+
+    For cases where a try/except block has mistakenly caught the error
+    and we want to re-raise.
+
+    Parameters
+    ----------
+    err_message : str
+        exception error message
+
+    Returns
+    -------
+    bool
+    """
+    return "The first" in err_message or "numba does not" in err_message
[Review comment on the lines above]

Contributor: Can you actually just check the class type here?

Member (Author): I raise ValueError from the numba utilities, so I need to distinguish between ValueErrors from another op vs. the numba utilities. Should I make a NumbaUtilError(ValueError) exception? Then we would also need to expose it to users.

Contributor: Yes, let's have a separate exception (but in a follow-on).


def check_kwargs_and_nopython(
    kwargs: Optional[Dict] = None, nopython: Optional[bool] = None
) -> None:
@@ -76,7 +95,6 @@ def jit_user_function(
    ----------
    func : function
        user defined function
-
    nopython : bool
        nopython parameter for numba.JIT
    nogil : bool