Skip to content

CLN: Numba internal routines #36376

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged: 2 commits merged on Sep 16, 2020
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 3 additions & 17 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,11 @@
get_groupby,
group_selection_context,
)
from pandas.core.groupby.numba_ import generate_numba_func, split_for_numba
from pandas.core.indexes.api import Index, MultiIndex, all_indexes_same
import pandas.core.indexes.base as ibase
from pandas.core.internals import BlockManager
from pandas.core.series import Series
from pandas.core.util.numba_ import NUMBA_FUNC_CACHE, maybe_use_numba
from pandas.core.util.numba_ import maybe_use_numba

from pandas.plotting import boxplot_frame_groupby

Expand Down Expand Up @@ -518,29 +517,16 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
result = getattr(self, func)(*args, **kwargs)
return self._transform_fast(result)

def _transform_general(
self, func, *args, engine="cython", engine_kwargs=None, **kwargs
):
def _transform_general(self, func, *args, **kwargs):
"""
Transform with a non-str `func`.
"""
if maybe_use_numba(engine):
numba_func, cache_key = generate_numba_func(
func, engine_kwargs, kwargs, "groupby_transform"
)

klass = type(self._selected_obj)

results = []
for name, group in self:
object.__setattr__(group, "name", name)
if maybe_use_numba(engine):
values, index = split_for_numba(group)
res = numba_func(values, index, *args)
if cache_key not in NUMBA_FUNC_CACHE:
NUMBA_FUNC_CACHE[cache_key] = numba_func
else:
res = func(group, *args, **kwargs)
res = func(group, *args, **kwargs)

if isinstance(res, (ABCDataFrame, ABCSeries)):
res = res._values
Expand Down
26 changes: 12 additions & 14 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,16 +1071,15 @@ def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
sorted_labels = algorithms.take_nd(labels, sorted_index, allow_fill=False)
sorted_data = data.take(sorted_index, axis=self.axis).to_numpy()
starts, ends = lib.generate_slices(sorted_labels, n_groups)
cache_key = (func, "groupby_transform")
if cache_key in NUMBA_FUNC_CACHE:
numba_transform_func = NUMBA_FUNC_CACHE[cache_key]
else:
numba_transform_func = numba_.generate_numba_transform_func(
tuple(args), kwargs, func, engine_kwargs
)

numba_transform_func = numba_.generate_numba_transform_func(
tuple(args), kwargs, func, engine_kwargs
)
result = numba_transform_func(
sorted_data, sorted_index, starts, ends, len(group_keys), len(data.columns)
)

cache_key = (func, "groupby_transform")
if cache_key not in NUMBA_FUNC_CACHE:
NUMBA_FUNC_CACHE[cache_key] = numba_transform_func

Expand All @@ -1106,16 +1105,15 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
sorted_labels = algorithms.take_nd(labels, sorted_index, allow_fill=False)
sorted_data = data.take(sorted_index, axis=self.axis).to_numpy()
starts, ends = lib.generate_slices(sorted_labels, n_groups)
cache_key = (func, "groupby_agg")
if cache_key in NUMBA_FUNC_CACHE:
numba_agg_func = NUMBA_FUNC_CACHE[cache_key]
else:
numba_agg_func = numba_.generate_numba_agg_func(
tuple(args), kwargs, func, engine_kwargs
)

numba_agg_func = numba_.generate_numba_agg_func(
tuple(args), kwargs, func, engine_kwargs
)
result = numba_agg_func(
sorted_data, sorted_index, starts, ends, len(group_keys), len(data.columns)
)

cache_key = (func, "groupby_agg")
if cache_key not in NUMBA_FUNC_CACHE:
NUMBA_FUNC_CACHE[cache_key] = numba_agg_func

Expand Down
85 changes: 13 additions & 72 deletions pandas/core/groupby/numba_.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,17 @@

import numpy as np

from pandas._typing import FrameOrSeries, Scalar
from pandas._typing import Scalar
from pandas.compat._optional import import_optional_dependency

from pandas.core.util.numba_ import (
NUMBA_FUNC_CACHE,
NumbaUtilError,
check_kwargs_and_nopython,
get_jit_arguments,
jit_user_function,
)


def split_for_numba(arg: FrameOrSeries) -> Tuple[np.ndarray, np.ndarray]:
    """
    Break a pandas object apart into the numpy arrays handed to numba functions.

    Parameters
    ----------
    arg : Series or DataFrame

    Returns
    -------
    (ndarray, ndarray)
        ``arg.to_numpy()`` followed by ``arg.index.to_numpy()``.
    """
    values = arg.to_numpy()
    index = arg.index.to_numpy()
    return values, index


def validate_udf(func: Callable) -> None:
"""
Validate user defined function for ops when using Numba with groupby ops.
Expand Down Expand Up @@ -67,46 +50,6 @@ def f(values, index, ...):
)


def generate_numba_func(
    func: Callable,
    engine_kwargs: Optional[Dict[str, bool]],
    kwargs: dict,
    cache_key_str: str,
) -> Tuple[Callable, Tuple[Callable, str]]:
    """
    Return a JITed function and its cache key for the NUMBA_FUNC_CACHE.

    A function already stored in the cache under ``(func, cache_key_str)`` is
    reused as-is; otherwise ``func`` is handed to ``jit_user_function``. This
    routine never writes to the cache itself: the caller stores the returned
    function under the returned key when it sees fit.

    This _may_ be specific to groupby (as it's only used there currently).

    Parameters
    ----------
    func : function
        user defined function
    engine_kwargs : dict or None
        numba.jit arguments
    kwargs : dict
        kwargs for func
    cache_key_str : str
        string representing the second part of the cache key tuple

    Returns
    -------
    (JITed function, cache key)
        The cache key is the tuple ``(func, cache_key_str)``.

    Raises
    ------
    NumbaUtilError
        Raised by the validation helpers called here, e.g. when ``kwargs``
        is non-empty while nopython mode is enabled.
    """
    nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
    check_kwargs_and_nopython(kwargs, nopython)
    validate_udf(func)

    cache_key = (func, cache_key_str)
    # Look the key up explicitly: ``dict.get(key, default)`` evaluates its
    # default eagerly, which would re-wrap the user function for JIT on every
    # call even when a cached version is about to be returned.
    if cache_key in NUMBA_FUNC_CACHE:
        numba_func = NUMBA_FUNC_CACHE[cache_key]
    else:
        numba_func = jit_user_function(func, nopython, nogil, parallel)
    return numba_func, cache_key


def generate_numba_agg_func(
args: Tuple,
kwargs: Dict[str, Any],
Expand All @@ -120,7 +63,7 @@ def generate_numba_agg_func(
2. Return a groupby agg function with the jitted function inline

Configurations specified in engine_kwargs apply to both the user's
function _AND_ the rolling apply function.
function _AND_ the groupby evaluation loop.

Parameters
----------
Expand All @@ -137,16 +80,15 @@ def generate_numba_agg_func(
-------
Numba function
"""
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)

check_kwargs_and_nopython(kwargs, nopython)
nopython, nogil, parallel = get_jit_arguments(engine_kwargs, kwargs)

validate_udf(func)
cache_key = (func, "groupby_agg")
if cache_key in NUMBA_FUNC_CACHE:
return NUMBA_FUNC_CACHE[cache_key]

numba_func = jit_user_function(func, nopython, nogil, parallel)

numba = import_optional_dependency("numba")

if parallel:
loop_range = numba.prange
else:
Expand Down Expand Up @@ -175,17 +117,17 @@ def group_agg(
def generate_numba_transform_func(
args: Tuple,
kwargs: Dict[str, Any],
func: Callable[..., Scalar],
func: Callable[..., np.ndarray],
engine_kwargs: Optional[Dict[str, bool]],
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int], np.ndarray]:
"""
Generate a numba jitted transform function specified by values from engine_kwargs.

1. jit the user's function
2. Return a groupby agg function with the jitted function inline
2. Return a groupby transform function with the jitted function inline

Configurations specified in engine_kwargs apply to both the user's
function _AND_ the rolling apply function.
function _AND_ the groupby evaluation loop.

Parameters
----------
Expand All @@ -202,16 +144,15 @@ def generate_numba_transform_func(
-------
Numba function
"""
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)

check_kwargs_and_nopython(kwargs, nopython)
nopython, nogil, parallel = get_jit_arguments(engine_kwargs, kwargs)

validate_udf(func)
cache_key = (func, "groupby_transform")
if cache_key in NUMBA_FUNC_CACHE:
return NUMBA_FUNC_CACHE[cache_key]

numba_func = jit_user_function(func, nopython, nogil, parallel)

numba = import_optional_dependency("numba")

if parallel:
loop_range = numba.prange
else:
Expand Down
42 changes: 12 additions & 30 deletions pandas/core/util/numba_.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,37 +24,8 @@ def set_use_numba(enable: bool = False) -> None:
GLOBAL_USE_NUMBA = enable


def check_kwargs_and_nopython(
    kwargs: Optional[Dict] = None, nopython: Optional[bool] = None
) -> None:
    """
    Reject user keyword arguments when ``nopython=True`` is requested.

    numba cannot forward keyword arguments in nopython mode, see
    https://github.com/numba/numba/issues/2916

    Parameters
    ----------
    kwargs : dict, default None
        user passed keyword arguments to pass into the JITed function
    nopython : bool, default None
        nopython parameter

    Returns
    -------
    None

    Raises
    ------
    NumbaUtilError
        If ``kwargs`` is non-empty and ``nopython`` is truthy.
    """
    # Nothing to complain about unless both conditions hold at once.
    if not (kwargs and nopython):
        return

    message = (
        "numba does not support kwargs with nopython=True: "
        "https://github.com/numba/numba/issues/2916"
    )
    raise NumbaUtilError(message)


def get_jit_arguments(
engine_kwargs: Optional[Dict[str, bool]] = None
engine_kwargs: Optional[Dict[str, bool]] = None, kwargs: Optional[Dict] = None,
) -> Tuple[bool, bool, bool]:
"""
Return arguments to pass to numba.JIT, falling back on pandas default JIT settings.
Expand All @@ -63,16 +34,27 @@ def get_jit_arguments(
----------
engine_kwargs : dict, default None
user passed keyword arguments for numba.JIT
kwargs : dict, default None
user passed keyword arguments to pass into the JITed function

Returns
-------
(bool, bool, bool)
nopython, nogil, parallel

Raises
------
NumbaUtilError
"""
if engine_kwargs is None:
engine_kwargs = {}

nopython = engine_kwargs.get("nopython", True)
if kwargs and nopython:
raise NumbaUtilError(
"numba does not support kwargs with nopython=True: "
"https://github.com/numba/numba/issues/2916"
)
nogil = engine_kwargs.get("nogil", False)
parallel = engine_kwargs.get("parallel", False)
return nopython, nogil, parallel
Expand Down
10 changes: 5 additions & 5 deletions pandas/core/window/numba_.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pandas.compat._optional import import_optional_dependency

from pandas.core.util.numba_ import (
check_kwargs_and_nopython,
NUMBA_FUNC_CACHE,
get_jit_arguments,
jit_user_function,
)
Expand Down Expand Up @@ -42,14 +42,14 @@ def generate_numba_apply_func(
-------
Numba function
"""
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
nopython, nogil, parallel = get_jit_arguments(engine_kwargs, kwargs)

check_kwargs_and_nopython(kwargs, nopython)
cache_key = (func, "rolling_apply")
if cache_key in NUMBA_FUNC_CACHE:
return NUMBA_FUNC_CACHE[cache_key]

numba_func = jit_user_function(func, nopython, nogil, parallel)

numba = import_optional_dependency("numba")

if parallel:
loop_range = numba.prange
else:
Expand Down
11 changes: 2 additions & 9 deletions pandas/core/window/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1374,14 +1374,7 @@ def apply(
if maybe_use_numba(engine):
if raw is False:
raise ValueError("raw must be `True` when using the numba engine")
cache_key = (func, "rolling_apply")
if cache_key in NUMBA_FUNC_CACHE:
# Return an already compiled version of roll_apply if available
apply_func = NUMBA_FUNC_CACHE[cache_key]
else:
apply_func = generate_numba_apply_func(
args, kwargs, func, engine_kwargs
)
apply_func = generate_numba_apply_func(args, kwargs, func, engine_kwargs)
center = self.center
elif engine in ("cython", None):
if engine_kwargs is not None:
Expand All @@ -1403,7 +1396,7 @@ def apply(
center=center,
floor=0,
name=func,
use_numba_cache=engine == "numba",
use_numba_cache=maybe_use_numba(engine),
raw=raw,
original_func=func,
args=args,
Expand Down