WIP: ENH: Add engine keyword argument to groupby.apply to leverage Numba #32428

Closed · wants to merge 9 commits
6 changes: 4 additions & 2 deletions pandas/core/groupby/generic.py
@@ -223,8 +223,10 @@ def _selection_name(self):
input="series", examples=_apply_docs["series_examples"]
)
)
def apply(self, func, *args, **kwargs):
return super().apply(func, *args, **kwargs)
def apply(self, func, engine="cython", engine_kwargs=None, *args, **kwargs):
Contributor: will want to update the doc-strings as well
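For reference, the new keywords could be documented along these lines; this is only a sketch of possible wording, not text that exists in the PR:

```python
# Illustrative only: a possible numpydoc fragment for the new keywords.
_engine_doc = """
engine : str, default 'cython'
    * 'cython' : Runs the function through C-extensions from cython.
    * 'numba' : Runs the function through JIT compiled code from numba.
engine_kwargs : dict, default None
    * For 'cython' engine, there are no accepted engine_kwargs.
    * For 'numba' engine, the engine can accept nopython, nogil and parallel
      dictionary keys. The values must either be True or False.
"""
```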

return super().apply(
func, engine=engine, engine_kwargs=engine_kwargs, *args, **kwargs
)

@Substitution(
see_also=_agg_see_also_doc,
62 changes: 39 additions & 23 deletions pandas/core/groupby/groupby.py
@@ -58,7 +58,7 @@ class providing the base-class of operations.
import pandas.core.common as com
from pandas.core.frame import DataFrame
from pandas.core.generic import NDFrame
from pandas.core.groupby import base, ops
from pandas.core.groupby import base, ops, numba_
from pandas.core.indexes.api import CategoricalIndex, Index, MultiIndex
from pandas.core.series import Series
from pandas.core.sorting import get_group_index_sorter
@@ -703,36 +703,50 @@ def __iter__(self):
input="dataframe", examples=_apply_docs["dataframe_examples"]
)
)
def apply(self, func, *args, **kwargs):
def apply(self, func, engine="cython", engine_kwargs=None, *args, **kwargs):

func = self._is_builtin_func(func)

# this is needed so we don't try and wrap strings. If we could
# resolve functions to their callable functions prior, this
# wouldn't be needed
if args or kwargs:
if callable(func):

@wraps(func)
def f(g):
with np.errstate(all="ignore"):
return func(g, *args, **kwargs)
if engine == "cython":
# this is needed so we don't try and wrap strings. If we could
# resolve functions to their callable functions prior, this
# wouldn't be needed
if args or kwargs:
if callable(func):

@wraps(func)
def f(g):
with np.errstate(all="ignore"):
return func(g, *args, **kwargs)

elif hasattr(nanops, "nan" + func):
# TODO: should we wrap this in to e.g. _is_builtin_func?
f = getattr(nanops, "nan" + func)

else:
raise ValueError(
"func must be a callable if args or kwargs are supplied"
)
else:
f = func
elif engine == "numba":

elif hasattr(nanops, "nan" + func):
# TODO: should we wrap this in to e.g. _is_builtin_func?
f = getattr(nanops, "nan" + func)
numba_.validate_apply_function_signature(func)

if func in self.grouper._numba_apply_cache:
# Return an already compiled version of the function if available
# TODO: this cache needs to be populated
f = self.grouper._numba_apply_cache[func]
else:
raise ValueError(
"func must be a callable if args or kwargs are supplied"
)
# TODO: support args
f = numba_.generate_numba_apply_func(args, kwargs, func, engine_kwargs)
else:
f = func
raise ValueError("engine must be either 'numba' or 'cython'")

# ignore SettingWithCopy here in case the user mutates
with option_context("mode.chained_assignment", None):
try:
result = self._python_apply_general(f)
result = self._python_apply_general(f, engine)
except TypeError:
# gh-20949
# try again, with .apply acting as a filtering
@@ -743,12 +757,14 @@ def f(g):
# on a string grouper column

with _group_selection_context(self):
return self._python_apply_general(f)
return self._python_apply_general(f, engine)

return result

def _python_apply_general(self, f):
keys, values, mutated = self.grouper.apply(f, self._selected_obj, self.axis)
def _python_apply_general(self, f, engine="cython"):
keys, values, mutated = self.grouper.apply(
f, self._selected_obj, self.axis, engine=engine
)

return self._wrap_applied_output(
keys, values, not_indexed_same=mutated or self.mutated
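A minimal usage sketch of the proposed keyword, assuming this PR's branch and an installed numba; the `(values, index)` signature is the one enforced by `validate_apply_function_signature` in the new `numba_.py` module below:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"key": ["a", "a", "b"], "val": [1.0, 2.0, 3.0]})

# With engine="numba" each group reaches the function as plain numpy arrays:
# the group's values and its index.
def group_mean(values, index):
    return np.mean(values)

# Existing behaviour (default engine="cython").
res_cython = df.groupby("key")["val"].apply(np.mean)

# Proposed behaviour; still WIP in this PR.
res_numba = df.groupby("key")["val"].apply(
    group_mean, engine="numba", engine_kwargs={"nopython": True}
)
```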
138 changes: 138 additions & 0 deletions pandas/core/groupby/numba_.py
@@ -0,0 +1,138 @@
import inspect
import types

import numpy as np

from pandas.compat._optional import import_optional_dependency


class InvalidApply(Exception):
Contributor: why do you need a new exception?

mroeschke (Member, Author), Mar 4, 2020: Sure, will want to import this from its original location in Cython. Or we can move InvalidApply to pandas.errors.
pass


def execute_groupby_function(splitter, f):
"""Mimics apply_frame_axis0 which is the Cython equivalent of this function."""
results = []
for _, group in splitter:
# TODO: what about series names/dataframe columns
index = group.index
values_as_array = group.to_numpy()
index_as_array = index.to_numpy()
try:
# TODO: support *args, **kwargs here
group_result = f(values_as_array, index_as_array)
except Exception:
# We can't be more specific without knowing something about `f`,
# like we do in Cython
raise InvalidApply("Let this error raise above us")
# Reconstruct the pandas object (expected downstream)
# This construction will fail if there is mutation,
# but we're banning it with numba?
group_result = group._constructor(group_result, index=index)
results.append(group_result)

return results


def validate_apply_function_signature(func):
"""
Validate that the apply function's first 2 arguments are 'values' and 'index'.

Parameters
----------
func : function
    function to be applied to each group and will be JITed
"""
apply_function_signature = list(inspect.signature(func).parameters.keys())[:2]
if apply_function_signature != ["values", "index"]:
raise ValueError(
"The apply function's first 2 arguments must be 'values' and 'index'"
)
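To illustrate what this check accepts and rejects (function names here are made up for the example):

```python
from pandas.core.groupby.numba_ import validate_apply_function_signature  # this PR's module

def ok(values, index):      # first two parameters named as required
    return values.sum()

def not_ok(data, idx):      # parameter names differ
    return data.sum()

validate_apply_function_signature(ok)        # passes silently
# validate_apply_function_signature(not_ok)  # raises the ValueError above
```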


def make_groupby_apply(
Contributor: can we share pandas/core/window/numba_.py at all? Maybe put common functions in pandas/core/algos/numba_.py (we have another issue about creating this and moving pandas/core/aggregation.py and algorithms.py there).

Member (Author): Definitely, will consolidate in another commit.

func, args, nogil, parallel, nopython,
):
"""
Creates a JITted groupby apply function with a JITted version of
the user's function.

Parameters
----------
func : function
function to be applied to each group and will be JITed
args : tuple
*args to be passed into the function
nogil : bool
nogil parameter from engine_kwargs for numba.jit
parallel : bool
parallel parameter from engine_kwargs for numba.jit
nopython : bool
nopython parameter from engine_kwargs for numba.jit

Returns
-------
Numba function
"""
numba = import_optional_dependency("numba")

if isinstance(func, numba.targets.registry.CPUDispatcher):
# Don't jit a user passed jitted function
numba_func = func
else:

@numba.generated_jit(nopython=nopython, nogil=nogil, parallel=parallel)
def numba_func(group, *_args):
if getattr(np, func.__name__, False) is func or isinstance(
func, types.BuiltinFunctionType
):
jf = func
else:
jf = numba.jit(func, nopython=nopython, nogil=nogil)

def impl(group, *_args):
return jf(group, *_args)

return impl

return numba_func
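A small illustration of the dispatcher check above, assuming numba versions current at the time of this PR (< 0.49), where jitted functions are instances of `numba.targets.registry.CPUDispatcher`:

```python
import numba
import numpy as np

@numba.njit
def jitted_mean(values, index):
    return np.mean(values)

# A user-jitted function is already a CPUDispatcher, so make_groupby_apply
# keeps it as-is instead of wrapping it in a second jit layer.
isinstance(jitted_mean, numba.targets.registry.CPUDispatcher)  # True on numba < 0.49
```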


def generate_numba_apply_func(
args, kwargs, func, engine_kwargs,
):
"""
Generate a numba jitted apply function specified by values from engine_kwargs.

1. jit the user's function

Configurations specified in engine_kwargs apply to both the user's
function _AND_ the groupby apply function.

Parameters
----------
args : tuple
*args to be passed into the function
kwargs : dict
**kwargs to be passed into the function
func : function
function to be applied to each group and will be JITed
engine_kwargs : dict
dictionary of arguments to be passed into numba.jit

Returns
-------
Numba function
"""
if engine_kwargs is None:
engine_kwargs = {}

nopython = engine_kwargs.get("nopython", True)
nogil = engine_kwargs.get("nogil", False)
parallel = engine_kwargs.get("parallel", False)

if kwargs and nopython:
raise ValueError(
"numba does not support kwargs with nopython=True: "
"https://github.com/numba/numba/issues/2916"
)

return make_groupby_apply(func, args, nogil, parallel, nopython)
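For reference, a sketch of how `engine_kwargs` maps onto the numba.jit options above, using the defaults and the kwargs restriction defined in this function:

```python
from pandas.core.groupby.numba_ import generate_numba_apply_func  # this PR's module

user_func = lambda values, index: values.sum()

# engine_kwargs=None falls back to nopython=True, nogil=False, parallel=False.
f_default = generate_numba_apply_func((), {}, user_func, None)

# Keys can be overridden individually.
f_parallel = generate_numba_apply_func(
    (), {}, user_func, {"nopython": True, "nogil": True, "parallel": True}
)

# Passing **kwargs together with nopython=True raises ValueError
# (https://github.com/numba/numba/issues/2916):
# generate_numba_apply_func((), {"ddof": 1}, user_func, {"nopython": True})
```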
15 changes: 13 additions & 2 deletions pandas/core/groupby/ops.py
@@ -43,7 +43,7 @@
import pandas.core.common as com
from pandas.core.frame import DataFrame
from pandas.core.generic import NDFrame
from pandas.core.groupby import base, grouper
from pandas.core.groupby import base, grouper, numba_
from pandas.core.indexes.api import Index, MultiIndex, ensure_index
from pandas.core.series import Series
from pandas.core.sorting import (
@@ -96,6 +96,7 @@ def __init__(
self.group_keys = group_keys
self.mutated = mutated
self.indexer = indexer
self._numba_apply_cache = dict()

@property
def groupings(self) -> List["grouper.Grouping"]:
@@ -148,13 +149,23 @@ def _get_group_keys(self):
# provide "flattened" iterator for multi-group setting
return get_flattened_iterator(comp_ids, ngroups, self.levels, self.codes)

def apply(self, f, data: FrameOrSeries, axis: int = 0):
def apply(self, f, data: FrameOrSeries, axis: int = 0, engine="cython"):
mutated = self.mutated
splitter = self._get_splitter(data, axis=axis)
group_keys = self._get_group_keys()
result_values = None

sdata: FrameOrSeries = splitter._get_sorted_data()

if engine == "numba":
result_values = numba_.execute_groupby_function(splitter, f)

# mutation is determined based on index alignment
# numba functions always return numpy arrays w/o indexes
# therefore, mutated=False?
# or just ban mutation so mutated=False always
return group_keys, result_values, False

if sdata.ndim == 2 and np.any(sdata.dtypes.apply(is_extension_array_dtype)):
# calling splitter.fast_apply will raise TypeError via apply_frame_axis0
# if we pass EA instead of ndarray
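The `_numba_apply_cache` added above is not yet populated (see the TODO in `GroupBy.apply`); one possible, purely illustrative completion of the numba branch there would be:

```python
# Hypothetical sketch of the numba branch in GroupBy.apply -- not part of this diff.
if func in self.grouper._numba_apply_cache:
    # Reuse the wrapper already generated for this user function.
    f = self.grouper._numba_apply_cache[func]
else:
    f = numba_.generate_numba_apply_func(args, kwargs, func, engine_kwargs)
    # Store it so repeated .apply calls skip regenerating the wrapper.
    self.grouper._numba_apply_cache[func] = f
```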