Skip to content

WIP: ENH: Add numba engine to groupby apply #35445

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
42 changes: 38 additions & 4 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class providing the base-class of operations.

from pandas._config.config import option_context

from pandas._libs import Timestamp
from pandas._libs import Timestamp, lib
import pandas._libs.groupby as libgroupby
from pandas._typing import F, FrameOrSeries, FrameOrSeriesUnion, Scalar
from pandas.compat.numpy import function as nv
Expand All @@ -61,11 +61,11 @@ class providing the base-class of operations.
import pandas.core.common as com
from pandas.core.frame import DataFrame
from pandas.core.generic import NDFrame
from pandas.core.groupby import base, ops
from pandas.core.groupby import base, numba_, ops
from pandas.core.indexes.api import CategoricalIndex, Index, MultiIndex
from pandas.core.series import Series
from pandas.core.sorting import get_group_index_sorter
from pandas.core.util.numba_ import maybe_use_numba
from pandas.core.util.numba_ import NUMBA_FUNC_CACHE, maybe_use_numba

_common_see_also = """
See Also
Expand Down Expand Up @@ -827,7 +827,12 @@ def __iter__(self):
input="dataframe", examples=_apply_docs["dataframe_examples"]
)
)
def apply(self, func, *args, **kwargs):
def apply(self, func, *args, engine=None, engine_kwargs=None, **kwargs):

if maybe_use_numba(engine):
return self._apply_with_numba(
func, *args, engine_kwargs=engine_kwargs, **kwargs
)

func = self._is_builtin_func(func)

Expand Down Expand Up @@ -871,6 +876,35 @@ def f(g):

return result

def _apply_with_numba(self, func, *args, engine_kwargs=None, **kwargs):
    """
    Apply ``func`` to every column of every group through a numba-jitted loop.

    The selected data is reordered so that the rows of each group are
    contiguous, converted to a NumPy array and handed, together with the
    per-group slice bounds, to the function built by
    ``numba_.generate_numba_apply_func``.

    Parameters
    ----------
    func : callable
        User function to be jitted and applied.
    *args
        Extra positional arguments, forwarded as a tuple to
        ``numba_.generate_numba_apply_func``.
    engine_kwargs : dict, optional
        Forwarded to ``numba_.generate_numba_apply_func``.
    **kwargs
        Forwarded as a dict to ``numba_.generate_numba_apply_func``.

    Returns
    -------
    Object built with ``self.obj._constructor``
        Indexed by the group keys, with the columns of the selected data.
    """
    group_keys = self.grouper._get_group_keys()

    with _group_selection_context(self):
        # We always drop the column with the groupby key
        data = self._selected_obj

    # Reorder the rows so each group occupies one contiguous slice.
    labels, _, n_groups = self.grouper.group_info
    sorted_index = get_group_index_sorter(labels, n_groups)
    sorted_labels = algorithms.take_nd(labels, sorted_index, allow_fill=False)
    sorted_data = data.take(sorted_index, axis=self.axis)
    starts, ends = lib.generate_slices(sorted_labels, n_groups)

    cache_key = (func, "groupby_apply")
    # ``args``, ``kwargs`` and ``engine_kwargs`` are consumed when the apply
    # function is generated, while the cache key only identifies ``func``.
    # Sharing a generated function is therefore only safe for calls that
    # supply none of them.
    shareable = not args and not kwargs and engine_kwargs is None

    if shareable and cache_key in NUMBA_FUNC_CACHE:
        # Reuse the groupby apply function generated by an earlier call
        apply_func = NUMBA_FUNC_CACHE[cache_key]
    else:
        apply_func = numba_.generate_numba_apply_func(
            tuple(args), kwargs, func, engine_kwargs
        )

    # NOTE(review): ``data.columns`` presumes the selected object is a
    # DataFrame -- confirm how Series input should be handled.
    result = apply_func(
        sorted_data.to_numpy(), starts, ends, len(group_keys), len(data.columns)
    )

    if shareable:
        # Store only after a successful call so that a function which fails
        # on first use is never cached.
        NUMBA_FUNC_CACHE[cache_key] = apply_func

    if self.grouper.nkeys > 1:
        index = MultiIndex.from_tuples(group_keys, names=self.grouper.names)
    else:
        index = Index(group_keys, name=self.grouper.names[0])
    return self.obj._constructor(result, index=index, columns=data.columns)

def _python_apply_general(
self, f: F, data: FrameOrSeriesUnion
) -> FrameOrSeriesUnion:
Expand Down
73 changes: 73 additions & 0 deletions pandas/core/groupby/numba_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import Any, Callable, Dict, Optional, Tuple

import numpy as np

from pandas._typing import Scalar
from pandas.compat._optional import import_optional_dependency

from pandas.core.util.numba_ import (
check_kwargs_and_nopython,
get_jit_arguments,
jit_user_function,
)


def generate_numba_apply_func(
    args: Tuple,
    kwargs: Dict[str, Any],
    func: Callable[..., Scalar],
    engine_kwargs: Optional[Dict[str, bool]],
):
    """
    Build a numba-jitted groupby apply function around a user function.

    1. jit the user's function
    2. Return a jitted loop that calls it on every (group, column) slice

    The ``nopython``, ``nogil`` and ``parallel`` settings obtained from
    ``engine_kwargs`` are used for both the user's function _AND_ the
    returned loop.

    Parameters
    ----------
    args : tuple
        *args passed to the user's function after the group values
    kwargs : dict
        **kwargs supplied by the caller; only handed to
        ``check_kwargs_and_nopython`` here, not forwarded to the function
    func : function
        function to be JITed and applied to each group of each column
    engine_kwargs : dict
        dictionary of arguments from which the numba.jit options are read

    Returns
    -------
    Numba function
        Called as ``f(values, begin, end, num_groups, num_columns)``; fills
        and returns an array of shape ``(num_groups, num_columns)``.
    """
    nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
    check_kwargs_and_nopython(kwargs, nopython)

    numba_func = jit_user_function(func, nopython, nogil, parallel)
    numba = import_optional_dependency("numba")

    # prange only when a parallel loop was requested.
    loop_range = numba.prange if parallel else range

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def group_apply(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        num_groups: int,
        num_columns: int,
    ) -> np.ndarray:
        result = np.empty((num_groups, num_columns))
        for group_idx in loop_range(num_groups):
            for col_idx in loop_range(num_columns):
                # Rows begin[group_idx]:end[group_idx] of one column.
                chunk = values[begin[group_idx] : end[group_idx], col_idx]
                result[group_idx, col_idx] = numba_func(chunk, *args)
        return result

    return group_apply