BUG/REF: Use lru_cache instead of NUMBA_FUNC_CACHE #46086


Merged: 12 commits, Feb 27, 2022
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v1.5.0.rst
@@ -382,6 +382,8 @@ Groupby/resample/rolling
- Bug in :meth:`DataFrame.resample` ignoring ``closed="right"`` on :class:`TimedeltaIndex` (:issue:`45414`)
- Bug in :meth:`.DataFrameGroupBy.transform` fails when ``func="size"`` and the input DataFrame has multiple columns (:issue:`27469`)
- Bug in :meth:`.DataFrameGroupBy.size` and :meth:`.DataFrameGroupBy.transform` with ``func="size"`` produced incorrect results when ``axis=1`` (:issue:`45715`)
- Bug in :meth:`.ExponentialMovingWindow.mean` with ``axis=1`` and ``engine='numba'`` when the :class:`DataFrame` has more columns than rows (:issue:`46086`)
- Bug where ``engine="numba"`` would reuse the same jitted function even after ``engine_kwargs`` was modified (:issue:`46086`)
- Bug in :meth:`.DataFrameGroupBy.transform` fails when ``axis=1`` and ``func`` is ``"first"`` or ``"last"`` (:issue:`45986`)

Reshaping
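For context on the second entry added above, a minimal sketch of the failure mode it describes follows. Everything here is illustrative rather than pandas code: the names are hypothetical and the inner function stands in for the `numba.jit` call. The point is only that a cache keyed on `(func, name)` never sees `engine_kwargs`, so the first compiled variant is returned for every later call.

```python
# Hypothetical sketch of the stale-cache behaviour fixed by this PR (not pandas code).
from typing import Callable

FUNC_CACHE: dict[tuple[Callable, str], Callable] = {}

def generate(func: Callable, engine_kwargs: dict | None, key: str) -> Callable:
    cache_key = (func, key)
    if cache_key in FUNC_CACHE:       # engine_kwargs is not part of the key,
        return FUNC_CACHE[cache_key]  # so changed jit options are silently ignored

    def compiled(values):
        # stand-in for numba.jit(**engine_kwargs)(func)
        return func(values)

    FUNC_CACHE[cache_key] = compiled
    return compiled

first = generate(sum, {"parallel": False}, "groupby_mean")
second = generate(sum, {"parallel": True}, "groupby_mean")
assert first is second  # same compiled object despite different engine_kwargs
```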
29 changes: 11 additions & 18 deletions pandas/core/_numba/executor.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import functools
from typing import (
TYPE_CHECKING,
Callable,
@@ -10,16 +11,13 @@
from pandas._typing import Scalar
from pandas.compat._optional import import_optional_dependency

from pandas.core.util.numba_ import (
NUMBA_FUNC_CACHE,
get_jit_arguments,
)


@functools.lru_cache(maxsize=None)
def generate_shared_aggregator(
func: Callable[..., Scalar],
engine_kwargs: dict[str, bool] | None,
cache_key_str: str,
nopython: bool,
nogil: bool,
parallel: bool,
):
"""
Generate a Numba function that loops over the columns of a 2D object and applies
@@ -29,22 +27,17 @@ def generate_shared_aggregator(
----------
func : function
aggregation function to be applied to each column
engine_kwargs : dict
dictionary of arguments to be passed into numba.jit
cache_key_str: str
string to access the compiled function of the form
<caller_type>_<aggregation_type> e.g. rolling_mean, groupby_mean
nopython : bool
nopython to be passed into numba.jit
nogil : bool
nogil to be passed into numba.jit
parallel : bool
parallel to be passed into numba.jit

Returns
-------
Numba function
"""
nopython, nogil, parallel = get_jit_arguments(engine_kwargs, None)

cache_key = (func, cache_key_str)
if cache_key in NUMBA_FUNC_CACHE:
return NUMBA_FUNC_CACHE[cache_key]

if TYPE_CHECKING:
import numba
else:
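A standalone sketch of the pattern this file now uses, assuming only the three boolean jit flags matter for identity: once the options are plain hashable arguments, `functools.lru_cache` memoizes the generator, so repeated calls with the same flags reuse one compiled function and a different flag set produces a new one. The closure stands in for the real numba-jitted aggregator so the sketch runs without numba installed.

```python
import functools
from typing import Callable

@functools.lru_cache(maxsize=None)
def generate_aggregator(
    func: Callable, nopython: bool, nogil: bool, parallel: bool
) -> Callable:
    # In pandas this would wrap a numba.jit-compiled loop over the columns;
    # a plain closure keeps the example self-contained and runnable.
    def aggregator(values):
        return func(values)
    return aggregator

a = generate_aggregator(sum, True, False, False)
b = generate_aggregator(sum, True, False, False)  # cache hit: identical object
c = generate_aggregator(sum, True, False, True)   # different flags: new compilation
assert a is b and a is not c
```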
60 changes: 17 additions & 43 deletions pandas/core/groupby/groupby.py
@@ -114,7 +114,7 @@ class providing the base-class of operations.
from pandas.core.series import Series
from pandas.core.sorting import get_group_index_sorter
from pandas.core.util.numba_ import (
NUMBA_FUNC_CACHE,
get_jit_arguments,
maybe_use_numba,
)

@@ -1247,11 +1247,7 @@ def _resolve_numeric_only(self, numeric_only: bool | lib.NoDefault) -> bool:
# numba

@final
def _numba_prep(self, func, data):
if not callable(func):
raise NotImplementedError(
"Numba engine can only be used with a single function."
)
def _numba_prep(self, data):
ids, _, ngroups = self.grouper.group_info
sorted_index = get_group_index_sorter(ids, ngroups)
sorted_ids = algorithms.take_nd(ids, sorted_index, allow_fill=False)
@@ -1271,7 +1267,6 @@ def _numba_agg_general(
self,
func: Callable,
engine_kwargs: dict[str, bool] | None,
numba_cache_key_str: str,
*aggregator_args,
):
"""
@@ -1288,16 +1283,12 @@
with self._group_selection_context():
data = self._selected_obj
df = data if data.ndim == 2 else data.to_frame()
starts, ends, sorted_index, sorted_data = self._numba_prep(func, df)
starts, ends, sorted_index, sorted_data = self._numba_prep(df)
aggregator = executor.generate_shared_aggregator(
func, engine_kwargs, numba_cache_key_str
func, **get_jit_arguments(engine_kwargs)
)
result = aggregator(sorted_data, starts, ends, 0, *aggregator_args)

cache_key = (func, numba_cache_key_str)
if cache_key not in NUMBA_FUNC_CACHE:
NUMBA_FUNC_CACHE[cache_key] = aggregator

index = self.grouper.result_index
if data.ndim == 1:
result_kwargs = {"name": data.name}
@@ -1315,10 +1306,10 @@ def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
to generate the indices of each group in the sorted data and then passes the
data and indices into a Numba jitted function.
"""
starts, ends, sorted_index, sorted_data = self._numba_prep(func, data)

starts, ends, sorted_index, sorted_data = self._numba_prep(data)
numba_.validate_udf(func)
numba_transform_func = numba_.generate_numba_transform_func(
kwargs, func, engine_kwargs
func, **get_jit_arguments(engine_kwargs, kwargs)
)
result = numba_transform_func(
sorted_data,
Expand All @@ -1328,11 +1319,6 @@ def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
len(data.columns),
*args,
)

cache_key = (func, "groupby_transform")
if cache_key not in NUMBA_FUNC_CACHE:
NUMBA_FUNC_CACHE[cache_key] = numba_transform_func

# result values needs to be resorted to their original positions since we
# evaluated the data sorted by group
return result.take(np.argsort(sorted_index), axis=0)
@@ -1346,9 +1332,11 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
to generate the indices of each group in the sorted data and then passes the
data and indices into a Numba jitted function.
"""
starts, ends, sorted_index, sorted_data = self._numba_prep(func, data)

numba_agg_func = numba_.generate_numba_agg_func(kwargs, func, engine_kwargs)
starts, ends, sorted_index, sorted_data = self._numba_prep(data)
numba_.validate_udf(func)
numba_agg_func = numba_.generate_numba_agg_func(
func, **get_jit_arguments(engine_kwargs, kwargs)
)
result = numba_agg_func(
sorted_data,
sorted_index,
Expand All @@ -1357,11 +1345,6 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
len(data.columns),
*args,
)

cache_key = (func, "groupby_agg")
if cache_key not in NUMBA_FUNC_CACHE:
NUMBA_FUNC_CACHE[cache_key] = numba_agg_func

return result

# -----------------------------------------------------------------
@@ -1947,7 +1930,7 @@ def mean(
if maybe_use_numba(engine):
from pandas.core._numba.kernels import sliding_mean

return self._numba_agg_general(sliding_mean, engine_kwargs, "groupby_mean")
return self._numba_agg_general(sliding_mean, engine_kwargs)
else:
result = self._cython_agg_general(
"mean",
@@ -2029,9 +2012,7 @@ def std(
if maybe_use_numba(engine):
from pandas.core._numba.kernels import sliding_var

return np.sqrt(
self._numba_agg_general(sliding_var, engine_kwargs, "groupby_std", ddof)
)
return np.sqrt(self._numba_agg_general(sliding_var, engine_kwargs, ddof))
else:
return self._get_cythonized_result(
libgroupby.group_var,
@@ -2085,9 +2066,7 @@ def var(
if maybe_use_numba(engine):
from pandas.core._numba.kernels import sliding_var

return self._numba_agg_general(
sliding_var, engine_kwargs, "groupby_var", ddof
)
return self._numba_agg_general(sliding_var, engine_kwargs, ddof)
else:
if ddof == 1:
numeric_only = self._resolve_numeric_only(lib.no_default)
@@ -2180,7 +2159,6 @@ def sum(
return self._numba_agg_general(
sliding_sum,
engine_kwargs,
"groupby_sum",
)
else:
numeric_only = self._resolve_numeric_only(numeric_only)
@@ -2221,9 +2199,7 @@ def min(
if maybe_use_numba(engine):
from pandas.core._numba.kernels import sliding_min_max

return self._numba_agg_general(
sliding_min_max, engine_kwargs, "groupby_min", False
)
return self._numba_agg_general(sliding_min_max, engine_kwargs, False)
else:
return self._agg_general(
numeric_only=numeric_only,
@@ -2244,9 +2220,7 @@ def max(
if maybe_use_numba(engine):
from pandas.core._numba.kernels import sliding_min_max

return self._numba_agg_general(
sliding_min_max, engine_kwargs, "groupby_max", True
)
return self._numba_agg_general(sliding_min_max, engine_kwargs, True)
else:
return self._agg_general(
numeric_only=numeric_only,
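The call sites rewritten in this file all follow one split: per-call arguments such as `ddof` for `std`/`var` or the min/max flag travel through `*aggregator_args` at invocation time, while the cache key consists only of the kernel and the jit flags. A toy sketch of that split is below; `toy_sliding_var` is a deliberately simplified stand-in, not the kernel from `pandas.core._numba.kernels`.

```python
import functools

@functools.lru_cache(maxsize=None)
def generate_shared_aggregator(func, nopython, nogil, parallel):
    # Cache key: (func, nopython, nogil, parallel). Per-call arguments are not
    # part of the key; they are forwarded when the aggregator is invoked.
    def aggregator(values, *aggregator_args):
        return func(values, *aggregator_args)
    return aggregator

def toy_sliding_var(values, ddof):
    # Simplified variance kernel, only to show the ddof pass-through.
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / (len(values) - ddof)

agg = generate_shared_aggregator(toy_sliding_var, True, False, False)
print(agg([1.0, 2.0, 3.0, 4.0], 1))  # ddof supplied per call, as in _numba_agg_general
```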
55 changes: 26 additions & 29 deletions pandas/core/groupby/numba_.py
@@ -1,6 +1,7 @@
"""Common utilities for Numba operations with groupby ops"""
from __future__ import annotations

import functools
import inspect
from typing import (
TYPE_CHECKING,
@@ -14,9 +15,7 @@
from pandas.compat._optional import import_optional_dependency

from pandas.core.util.numba_ import (
NUMBA_FUNC_CACHE,
NumbaUtilError,
get_jit_arguments,
jit_user_function,
)

@@ -43,6 +42,10 @@ def f(values, index, ...):
------
NumbaUtilError
"""
if not callable(func):
raise NotImplementedError(
"Numba engine can only be used with a single function."
)
udf_signature = list(inspect.signature(func).parameters.keys())
expected_args = ["values", "index"]
min_number_args = len(expected_args)
@@ -56,10 +59,12 @@
)


@functools.lru_cache(maxsize=None)
def generate_numba_agg_func(
kwargs: dict[str, Any],
func: Callable[..., Scalar],
engine_kwargs: dict[str, bool] | None,
nopython: bool,
nogil: bool,
parallel: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
"""
Generate a numba jitted agg function specified by values from engine_kwargs.
@@ -72,24 +77,19 @@ def generate_numba_agg_func(

Parameters
----------
kwargs : dict
**kwargs to be passed into the function
func : function
function to be applied to each window and will be JITed
engine_kwargs : dict
dictionary of arguments to be passed into numba.jit
function to be applied to each group and will be JITed
nopython : bool
nopython to be passed into numba.jit
nogil : bool
nogil to be passed into numba.jit
parallel : bool
parallel to be passed into numba.jit

Returns
-------
Numba function
"""
nopython, nogil, parallel = get_jit_arguments(engine_kwargs, kwargs)

validate_udf(func)
cache_key = (func, "groupby_agg")
if cache_key in NUMBA_FUNC_CACHE:
return NUMBA_FUNC_CACHE[cache_key]

numba_func = jit_user_function(func, nopython, nogil, parallel)
if TYPE_CHECKING:
import numba
@@ -120,10 +120,12 @@ def group_agg(
return group_agg


@functools.lru_cache(maxsize=None)
def generate_numba_transform_func(
kwargs: dict[str, Any],
func: Callable[..., np.ndarray],
engine_kwargs: dict[str, bool] | None,
nopython: bool,
nogil: bool,
parallel: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
"""
Generate a numba jitted transform function specified by values from engine_kwargs.
@@ -136,24 +138,19 @@ def generate_numba_transform_func(

Parameters
----------
kwargs : dict
**kwargs to be passed into the function
func : function
function to be applied to each window and will be JITed
engine_kwargs : dict
dictionary of arguments to be passed into numba.jit
nopython : bool
nopython to be passed into numba.jit
nogil : bool
nogil to be passed into numba.jit
parallel : bool
parallel to be passed into numba.jit

Returns
-------
Numba function
"""
nopython, nogil, parallel = get_jit_arguments(engine_kwargs, kwargs)

validate_udf(func)
cache_key = (func, "groupby_transform")
if cache_key in NUMBA_FUNC_CACHE:
return NUMBA_FUNC_CACHE[cache_key]

numba_func = jit_user_function(func, nopython, nogil, parallel)
if TYPE_CHECKING:
import numba
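Two changes in this file go together: `validate_udf` absorbs the callable check that previously lived in `_numba_prep`, and the generator signatures drop their `kwargs`/`engine_kwargs` dict parameters. The second is forced by `lru_cache`, whose arguments must be hashable. A hedged sketch of both points, with simplified bodies rather than the real pandas implementations:

```python
import functools

def validate_udf(func):
    # Simplified: the real validate_udf also checks for ("values", "index") parameters.
    if not callable(func):
        raise NotImplementedError(
            "Numba engine can only be used with a single function."
        )

@functools.lru_cache(maxsize=None)
def generate_numba_agg_func(func, nopython, nogil, parallel):
    def group_agg(values, index):
        return func(values, index)
    return group_agg

def udf(values, index):
    return sum(values)

validate_udf(udf)  # validation now happens at the call site, before the cached generator
ok = generate_numba_agg_func(udf, True, False, False)

try:
    generate_numba_agg_func(udf, {"nopython": True}, False, False)
except TypeError as exc:
    print(exc)  # unhashable type: 'dict' -- why dict parameters had to leave the signature
```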
7 changes: 3 additions & 4 deletions pandas/core/util/numba_.py
@@ -13,7 +13,6 @@
from pandas.errors import NumbaUtilError

GLOBAL_USE_NUMBA: bool = False
NUMBA_FUNC_CACHE: dict[tuple[Callable, str], Callable] = {}


def maybe_use_numba(engine: str | None) -> bool:
@@ -30,7 +29,7 @@ def set_use_numba(enable: bool = False) -> None:

def get_jit_arguments(
engine_kwargs: dict[str, bool] | None = None, kwargs: dict | None = None
) -> tuple[bool, bool, bool]:
) -> dict[str, bool]:
"""
Return arguments to pass to numba.jit, falling back on pandas default JIT settings.

@@ -43,7 +42,7 @@ def get_jit_arguments(

Returns
-------
(bool, bool, bool)
dict[str, bool]
nopython, nogil, parallel

Raises
@@ -61,7 +60,7 @@
)
nogil = engine_kwargs.get("nogil", False)
parallel = engine_kwargs.get("parallel", False)
return nopython, nogil, parallel
return {"nopython": nopython, "nogil": nogil, "parallel": parallel}


def jit_user_function(
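A short usage note on the return-type change above: handing back a keyword dict instead of a positional tuple lets every caller write `**get_jit_arguments(...)` and keeps each flag bound to its name. The condensed stand-in below assumes `True` as the `nopython` default (that line falls outside the hunk shown) and elides the kwargs error handling.

```python
def get_jit_arguments(engine_kwargs=None, kwargs=None):
    # Condensed stand-in for pandas.core.util.numba_.get_jit_arguments.
    engine_kwargs = engine_kwargs or {}
    return {
        "nopython": engine_kwargs.get("nopython", True),
        "nogil": engine_kwargs.get("nogil", False),
        "parallel": engine_kwargs.get("parallel", False),
    }

# Before this PR: nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
# After: the dict is unpacked by keyword wherever a jitted function is built.
print(get_jit_arguments({"parallel": True}))
# {'nopython': True, 'nogil': False, 'parallel': True}
```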