diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index fb935c9065b83..c90da5c4001f2 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -223,8 +223,10 @@ def _selection_name(self): input="series", examples=_apply_docs["series_examples"] ) ) - def apply(self, func, *args, **kwargs): - return super().apply(func, *args, **kwargs) + def apply(self, func, engine="cython", engine_kwargs=None, *args, **kwargs): + return super().apply( + func, engine=engine, engine_kwargs=engine_kwargs, *args, **kwargs + ) @Substitution( see_also=_agg_see_also_doc, diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 6362f11a3e032..f17bb03f4cffa 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -58,7 +58,7 @@ class providing the base-class of operations. import pandas.core.common as com from pandas.core.frame import DataFrame from pandas.core.generic import NDFrame -from pandas.core.groupby import base, ops +from pandas.core.groupby import base, ops, numba_ from pandas.core.indexes.api import CategoricalIndex, Index, MultiIndex from pandas.core.series import Series from pandas.core.sorting import get_group_index_sorter @@ -703,36 +703,50 @@ def __iter__(self): input="dataframe", examples=_apply_docs["dataframe_examples"] ) ) - def apply(self, func, *args, **kwargs): + def apply(self, func, engine="cython", engine_kwargs=None, *args, **kwargs): func = self._is_builtin_func(func) - # this is needed so we don't try and wrap strings. If we could - # resolve functions to their callable functions prior, this - # wouldn't be needed - if args or kwargs: - if callable(func): - - @wraps(func) - def f(g): - with np.errstate(all="ignore"): - return func(g, *args, **kwargs) + if engine == "cython": + # this is needed so we don't try and wrap strings. If we could + # resolve functions to their callable functions prior, this + # wouldn't be needed + if args or kwargs: + if callable(func): + + @wraps(func) + def f(g): + with np.errstate(all="ignore"): + return func(g, *args, **kwargs) + + elif hasattr(nanops, "nan" + func): + # TODO: should we wrap this in to e.g. _is_builtin_func? + f = getattr(nanops, "nan" + func) + + else: + raise ValueError( + "func must be a callable if args or kwargs are supplied" + ) + else: + f = func + elif engine == "numba": - elif hasattr(nanops, "nan" + func): - # TODO: should we wrap this in to e.g. _is_builtin_func? - f = getattr(nanops, "nan" + func) + numba_.validate_apply_function_signature(func) + if func in self.grouper._numba_apply_cache: + # Return an already compiled version of the function if available + # TODO: this cache needs to be populated + f = self.grouper._numba_apply_cache[func] else: - raise ValueError( - "func must be a callable if args or kwargs are supplied" - ) + # TODO: support args + f = numba_.generate_numba_apply_func(args, kwargs, func, engine_kwargs) else: - f = func + raise ValueError("engine must be either 'numba' or 'cython'") # ignore SettingWithCopy here in case the user mutates with option_context("mode.chained_assignment", None): try: - result = self._python_apply_general(f) + result = self._python_apply_general(f, engine) except TypeError: # gh-20949 # try again, with .apply acting as a filtering @@ -743,12 +757,14 @@ def f(g): # on a string grouper column with _group_selection_context(self): - return self._python_apply_general(f) + return self._python_apply_general(f, engine) return result - def _python_apply_general(self, f): - keys, values, mutated = self.grouper.apply(f, self._selected_obj, self.axis) + def _python_apply_general(self, f, engine="cython"): + keys, values, mutated = self.grouper.apply( + f, self._selected_obj, self.axis, engine=engine + ) return self._wrap_applied_output( keys, values, not_indexed_same=mutated or self.mutated diff --git a/pandas/core/groupby/numba_.py b/pandas/core/groupby/numba_.py new file mode 100644 index 0000000000000..2f71a46775c94 --- /dev/null +++ b/pandas/core/groupby/numba_.py @@ -0,0 +1,138 @@ +import inspect +import types + +import numpy as np + +from pandas.compat._optional import import_optional_dependency + + +class InvalidApply(Exception): + pass + + +def execute_groupby_function(splitter, f): + """Mimics apply_frame_axis0 which is the Cython equivalent of this function.""" + results = [] + for _, group in splitter: + # TODO: what about series names/dataframe columns + index = group.index + values_as_array = group.to_numpy() + index_as_array = index.to_numpy() + try: + # TODO: support *args, **kwargs here + group_result = f(values_as_array, index_as_array) + except Exception: + # We can't be more specific without knowing something about `f` + # Like we do in Cython + raise InvalidApply("Let this error raise above us") + # Reconstruct the pandas object (expected downstream) + # This construction will fail is there is mutation, + # but we're banning it with numba? + group_result = group._constructor(group_result, index=index) + results.append(group_result) + + return results + + +def validate_apply_function_signature(func): + """ + Validate that the apply function's first 2 arguments are 'values' and 'index'. + + func : function + function to be applied to each group and will be JITed + """ + apply_function_signature = list(inspect.signature(func).parameters.keys())[:2] + if apply_function_signature != ["values", "index"]: + raise ValueError( + "The apply function's first 2 arguments must be 'values' and 'index'" + ) + + +def make_groupby_apply( + func, args, nogil, parallel, nopython, +): + """ + Creates a JITted groupby apply function with a JITted version of + the user's function. + + Parameters + ---------- + func : function + function to be applied to each group and will be JITed + args : tuple + *args to be passed into the function + nogil : bool + nogil parameter from engine_kwargs for numba.jit + parallel : bool + parallel parameter from engine_kwargs for numba.jit + nopython : bool + nopython parameter from engine_kwargs for numba.jit + + Returns + ------- + Numba function + """ + numba = import_optional_dependency("numba") + + if isinstance(func, numba.targets.registry.CPUDispatcher): + # Don't jit a user passed jitted function + numba_func = func + else: + + @numba.generated_jit(nopython=nopython, nogil=nogil, parallel=parallel) + def numba_func(group, *_args): + if getattr(np, func.__name__, False) is func or isinstance( + func, types.BuiltinFunctionType + ): + jf = func + else: + jf = numba.jit(func, nopython=nopython, nogil=nogil) + + def impl(group, *_args): + return jf(group, *_args) + + return impl + + return numba_func + + +def generate_numba_apply_func( + args, kwargs, func, engine_kwargs, +): + """ + Generate a numba jitted apply function specified by values from engine_kwargs. + + 1. jit the user's function + + Configurations specified in engine_kwargs apply to both the user's + function _AND_ the rolling apply function. + + Parameters + ---------- + args : tuple + *args to be passed into the function + kwargs : dict + **kwargs to be passed into the function + func : function + function to be applied to each group and will be JITed + engine_kwargs : dict + dictionary of arguments to be passed into numba.jit + + Returns + ------- + Numba function + """ + if engine_kwargs is None: + engine_kwargs = {} + + nopython = engine_kwargs.get("nopython", True) + nogil = engine_kwargs.get("nogil", False) + parallel = engine_kwargs.get("parallel", False) + + if kwargs and nopython: + raise ValueError( + "numba does not support kwargs with nopython=True: " + "https://github.com/numba/numba/issues/2916" + ) + + return make_groupby_apply(func, args, nogil, parallel, nopython) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 7259268ac3f2b..2d5755c4470d8 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -43,7 +43,7 @@ import pandas.core.common as com from pandas.core.frame import DataFrame from pandas.core.generic import NDFrame -from pandas.core.groupby import base, grouper +from pandas.core.groupby import base, grouper, numba_ from pandas.core.indexes.api import Index, MultiIndex, ensure_index from pandas.core.series import Series from pandas.core.sorting import ( @@ -96,6 +96,7 @@ def __init__( self.group_keys = group_keys self.mutated = mutated self.indexer = indexer + self._numba_apply_cache = dict() @property def groupings(self) -> List["grouper.Grouping"]: @@ -148,13 +149,23 @@ def _get_group_keys(self): # provide "flattened" iterator for multi-group setting return get_flattened_iterator(comp_ids, ngroups, self.levels, self.codes) - def apply(self, f, data: FrameOrSeries, axis: int = 0): + def apply(self, f, data: FrameOrSeries, axis: int = 0, engine="cython"): mutated = self.mutated splitter = self._get_splitter(data, axis=axis) group_keys = self._get_group_keys() result_values = None sdata: FrameOrSeries = splitter._get_sorted_data() + + if engine == "numba": + result_values = numba_.execute_groupby_function(splitter, f) + + # mutation is determined based on index alignment + # numba functions always return numpy arrays w/o indexes + # therefore, mutated=False? + # or just ban mutation so mutated=False always + return group_keys, result_values, False + if sdata.ndim == 2 and np.any(sdata.dtypes.apply(is_extension_array_dtype)): # calling splitter.fast_apply will raise TypeError via apply_frame_axis0 # if we pass EA instead of ndarray