diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index d4e673d2e538c..a931221ef3ce1 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -74,12 +74,11 @@ get_groupby, group_selection_context, ) -from pandas.core.groupby.numba_ import generate_numba_func, split_for_numba from pandas.core.indexes.api import Index, MultiIndex, all_indexes_same import pandas.core.indexes.base as ibase from pandas.core.internals import BlockManager from pandas.core.series import Series -from pandas.core.util.numba_ import NUMBA_FUNC_CACHE, maybe_use_numba +from pandas.core.util.numba_ import maybe_use_numba from pandas.plotting import boxplot_frame_groupby @@ -518,29 +517,16 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): result = getattr(self, func)(*args, **kwargs) return self._transform_fast(result) - def _transform_general( - self, func, *args, engine="cython", engine_kwargs=None, **kwargs - ): + def _transform_general(self, func, *args, **kwargs): """ Transform with a non-str `func`. """ - if maybe_use_numba(engine): - numba_func, cache_key = generate_numba_func( - func, engine_kwargs, kwargs, "groupby_transform" - ) - klass = type(self._selected_obj) results = [] for name, group in self: object.__setattr__(group, "name", name) - if maybe_use_numba(engine): - values, index = split_for_numba(group) - res = numba_func(values, index, *args) - if cache_key not in NUMBA_FUNC_CACHE: - NUMBA_FUNC_CACHE[cache_key] = numba_func - else: - res = func(group, *args, **kwargs) + res = func(group, *args, **kwargs) if isinstance(res, (ABCDataFrame, ABCSeries)): res = res._values diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index ceee78bfebe68..9a14323dd8c3a 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1071,16 +1071,15 @@ def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs) sorted_labels = algorithms.take_nd(labels, sorted_index, allow_fill=False) sorted_data = data.take(sorted_index, axis=self.axis).to_numpy() starts, ends = lib.generate_slices(sorted_labels, n_groups) - cache_key = (func, "groupby_transform") - if cache_key in NUMBA_FUNC_CACHE: - numba_transform_func = NUMBA_FUNC_CACHE[cache_key] - else: - numba_transform_func = numba_.generate_numba_transform_func( - tuple(args), kwargs, func, engine_kwargs - ) + + numba_transform_func = numba_.generate_numba_transform_func( + tuple(args), kwargs, func, engine_kwargs + ) result = numba_transform_func( sorted_data, sorted_index, starts, ends, len(group_keys), len(data.columns) ) + + cache_key = (func, "groupby_transform") if cache_key not in NUMBA_FUNC_CACHE: NUMBA_FUNC_CACHE[cache_key] = numba_transform_func @@ -1106,16 +1105,15 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs) sorted_labels = algorithms.take_nd(labels, sorted_index, allow_fill=False) sorted_data = data.take(sorted_index, axis=self.axis).to_numpy() starts, ends = lib.generate_slices(sorted_labels, n_groups) - cache_key = (func, "groupby_agg") - if cache_key in NUMBA_FUNC_CACHE: - numba_agg_func = NUMBA_FUNC_CACHE[cache_key] - else: - numba_agg_func = numba_.generate_numba_agg_func( - tuple(args), kwargs, func, engine_kwargs - ) + + numba_agg_func = numba_.generate_numba_agg_func( + tuple(args), kwargs, func, engine_kwargs + ) result = numba_agg_func( sorted_data, sorted_index, starts, ends, len(group_keys), len(data.columns) ) + + cache_key = (func, "groupby_agg") if cache_key not in NUMBA_FUNC_CACHE: NUMBA_FUNC_CACHE[cache_key] = numba_agg_func diff --git a/pandas/core/groupby/numba_.py b/pandas/core/groupby/numba_.py index a2dfcd7bddd53..76f50f1387196 100644 --- a/pandas/core/groupby/numba_.py +++ b/pandas/core/groupby/numba_.py @@ -4,34 +4,17 @@ import numpy as np -from pandas._typing import FrameOrSeries, Scalar +from pandas._typing import Scalar from pandas.compat._optional import import_optional_dependency from pandas.core.util.numba_ import ( NUMBA_FUNC_CACHE, NumbaUtilError, - check_kwargs_and_nopython, get_jit_arguments, jit_user_function, ) -def split_for_numba(arg: FrameOrSeries) -> Tuple[np.ndarray, np.ndarray]: - """ - Split pandas object into its components as numpy arrays for numba functions. - - Parameters - ---------- - arg : Series or DataFrame - - Returns - ------- - (ndarray, ndarray) - values, index - """ - return arg.to_numpy(), arg.index.to_numpy() - - def validate_udf(func: Callable) -> None: """ Validate user defined function for ops when using Numba with groupby ops. @@ -67,46 +50,6 @@ def f(values, index, ...): ) -def generate_numba_func( - func: Callable, - engine_kwargs: Optional[Dict[str, bool]], - kwargs: dict, - cache_key_str: str, -) -> Tuple[Callable, Tuple[Callable, str]]: - """ - Return a JITed function and cache key for the NUMBA_FUNC_CACHE - - This _may_ be specific to groupby (as it's only used there currently). - - Parameters - ---------- - func : function - user defined function - engine_kwargs : dict or None - numba.jit arguments - kwargs : dict - kwargs for func - cache_key_str : str - string representing the second part of the cache key tuple - - Returns - ------- - (JITed function, cache key) - - Raises - ------ - NumbaUtilError - """ - nopython, nogil, parallel = get_jit_arguments(engine_kwargs) - check_kwargs_and_nopython(kwargs, nopython) - validate_udf(func) - cache_key = (func, cache_key_str) - numba_func = NUMBA_FUNC_CACHE.get( - cache_key, jit_user_function(func, nopython, nogil, parallel) - ) - return numba_func, cache_key - - def generate_numba_agg_func( args: Tuple, kwargs: Dict[str, Any], @@ -120,7 +63,7 @@ def generate_numba_agg_func( 2. Return a groupby agg function with the jitted function inline Configurations specified in engine_kwargs apply to both the user's - function _AND_ the rolling apply function. + function _AND_ the groupby evaluation loop. Parameters ---------- @@ -137,16 +80,15 @@ def generate_numba_agg_func( ------- Numba function """ - nopython, nogil, parallel = get_jit_arguments(engine_kwargs) - - check_kwargs_and_nopython(kwargs, nopython) + nopython, nogil, parallel = get_jit_arguments(engine_kwargs, kwargs) validate_udf(func) + cache_key = (func, "groupby_agg") + if cache_key in NUMBA_FUNC_CACHE: + return NUMBA_FUNC_CACHE[cache_key] numba_func = jit_user_function(func, nopython, nogil, parallel) - numba = import_optional_dependency("numba") - if parallel: loop_range = numba.prange else: @@ -175,17 +117,17 @@ def group_agg( def generate_numba_transform_func( args: Tuple, kwargs: Dict[str, Any], - func: Callable[..., Scalar], + func: Callable[..., np.ndarray], engine_kwargs: Optional[Dict[str, bool]], ) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int], np.ndarray]: """ Generate a numba jitted transform function specified by values from engine_kwargs. 1. jit the user's function - 2. Return a groupby agg function with the jitted function inline + 2. Return a groupby transform function with the jitted function inline Configurations specified in engine_kwargs apply to both the user's - function _AND_ the rolling apply function. + function _AND_ the groupby evaluation loop. Parameters ---------- @@ -202,16 +144,15 @@ def generate_numba_transform_func( ------- Numba function """ - nopython, nogil, parallel = get_jit_arguments(engine_kwargs) - - check_kwargs_and_nopython(kwargs, nopython) + nopython, nogil, parallel = get_jit_arguments(engine_kwargs, kwargs) validate_udf(func) + cache_key = (func, "groupby_transform") + if cache_key in NUMBA_FUNC_CACHE: + return NUMBA_FUNC_CACHE[cache_key] numba_func = jit_user_function(func, nopython, nogil, parallel) - numba = import_optional_dependency("numba") - if parallel: loop_range = numba.prange else: diff --git a/pandas/core/util/numba_.py b/pandas/core/util/numba_.py index b951cd4f0cc2a..f06dd10d0e497 100644 --- a/pandas/core/util/numba_.py +++ b/pandas/core/util/numba_.py @@ -24,37 +24,8 @@ def set_use_numba(enable: bool = False) -> None: GLOBAL_USE_NUMBA = enable -def check_kwargs_and_nopython( - kwargs: Optional[Dict] = None, nopython: Optional[bool] = None -) -> None: - """ - Validate that **kwargs and nopython=True was passed - https://github.com/numba/numba/issues/2916 - - Parameters - ---------- - kwargs : dict, default None - user passed keyword arguments to pass into the JITed function - nopython : bool, default None - nopython parameter - - Returns - ------- - None - - Raises - ------ - NumbaUtilError - """ - if kwargs and nopython: - raise NumbaUtilError( - "numba does not support kwargs with nopython=True: " - "https://github.com/numba/numba/issues/2916" - ) - - def get_jit_arguments( - engine_kwargs: Optional[Dict[str, bool]] = None + engine_kwargs: Optional[Dict[str, bool]] = None, kwargs: Optional[Dict] = None, ) -> Tuple[bool, bool, bool]: """ Return arguments to pass to numba.JIT, falling back on pandas default JIT settings. @@ -63,16 +34,27 @@ def get_jit_arguments( ---------- engine_kwargs : dict, default None user passed keyword arguments for numba.JIT + kwargs : dict, default None + user passed keyword arguments to pass into the JITed function Returns ------- (bool, bool, bool) nopython, nogil, parallel + + Raises + ------ + NumbaUtilError """ if engine_kwargs is None: engine_kwargs = {} nopython = engine_kwargs.get("nopython", True) + if kwargs and nopython: + raise NumbaUtilError( + "numba does not support kwargs with nopython=True: " + "https://github.com/numba/numba/issues/2916" + ) nogil = engine_kwargs.get("nogil", False) parallel = engine_kwargs.get("parallel", False) return nopython, nogil, parallel diff --git a/pandas/core/window/numba_.py b/pandas/core/window/numba_.py index aec294c3c84c2..c4858b6e5a4ab 100644 --- a/pandas/core/window/numba_.py +++ b/pandas/core/window/numba_.py @@ -6,7 +6,7 @@ from pandas.compat._optional import import_optional_dependency from pandas.core.util.numba_ import ( - check_kwargs_and_nopython, + NUMBA_FUNC_CACHE, get_jit_arguments, jit_user_function, ) @@ -42,14 +42,14 @@ def generate_numba_apply_func( ------- Numba function """ - nopython, nogil, parallel = get_jit_arguments(engine_kwargs) + nopython, nogil, parallel = get_jit_arguments(engine_kwargs, kwargs) - check_kwargs_and_nopython(kwargs, nopython) + cache_key = (func, "rolling_apply") + if cache_key in NUMBA_FUNC_CACHE: + return NUMBA_FUNC_CACHE[cache_key] numba_func = jit_user_function(func, nopython, nogil, parallel) - numba = import_optional_dependency("numba") - if parallel: loop_range = numba.prange else: diff --git a/pandas/core/window/rolling.py b/pandas/core/window/rolling.py index 00fdf0813b027..21a7164411fb7 100644 --- a/pandas/core/window/rolling.py +++ b/pandas/core/window/rolling.py @@ -1374,14 +1374,7 @@ def apply( if maybe_use_numba(engine): if raw is False: raise ValueError("raw must be `True` when using the numba engine") - cache_key = (func, "rolling_apply") - if cache_key in NUMBA_FUNC_CACHE: - # Return an already compiled version of roll_apply if available - apply_func = NUMBA_FUNC_CACHE[cache_key] - else: - apply_func = generate_numba_apply_func( - args, kwargs, func, engine_kwargs - ) + apply_func = generate_numba_apply_func(args, kwargs, func, engine_kwargs) center = self.center elif engine in ("cython", None): if engine_kwargs is not None: @@ -1403,7 +1396,7 @@ def apply( center=center, floor=0, name=func, - use_numba_cache=engine == "numba", + use_numba_cache=maybe_use_numba(engine), raw=raw, original_func=func, args=args,