Skip to content

ENH: Add numba engine to groupby.transform #32854

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 42 commits into from
Apr 16, 2020
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
db1b3aa
Add numba engine to groupby.transform for series
Mar 18, 2020
e82ede7
new_func -> func
Mar 18, 2020
b0f4faa
Merge remote-tracking branch 'upstream/master' into groupby_transform
Mar 19, 2020
1316662
Adjust inputs for groupby.transform for series
Mar 19, 2020
8d7343e
Merge remote-tracking branch 'upstream/master' into groupby_transform
Mar 20, 2020
3e12a51
Fix typo in func
Mar 20, 2020
6e6891f
Add udf validation function
Mar 20, 2020
73be08d
Add numba engine for dataframe objects
Mar 20, 2020
e85dcd5
Lint
Mar 20, 2020
0760c44
Merge remote-tracking branch 'upstream/master' into groupby_transform
Mar 22, 2020
52fff03
Add separate folder + file for tests
Mar 22, 2020
156a2b4
isort
Mar 22, 2020
6e9d2bf
Add tests and reorder parameters
Mar 22, 2020
2a57a14
Merge remote-tracking branch 'upstream/master' into groupby_transform
Mar 23, 2020
e1e5f73
Remove unused path variable
Mar 23, 2020
ce3b2b3
Make tests more explicit
Mar 23, 2020
195c35f
Black
Mar 23, 2020
d8ea389
Merge remote-tracking branch 'upstream/master' into groupby_transform
Mar 27, 2020
a9ece86
Add numba cache
Mar 27, 2020
367dc12
Add ASV bench
Mar 27, 2020
8256a0a
Add whatsnew enhancement entry
Mar 27, 2020
2c5543d
Lint and add typing
Mar 27, 2020
80e0ddc
Merge remote-tracking branch 'upstream/master' into groupby_transform
Mar 28, 2020
9c4fa56
fix benchmark
Mar 28, 2020
1c71c9b
Fix benchmarks again
Mar 29, 2020
6c0a573
black and fix benchmarks again
Mar 29, 2020
538f4fa
Merge remote-tracking branch 'upstream/master' into groupby_transform
Apr 3, 2020
083ead7
Merge remote-tracking branch 'upstream/master' into groupby_transform
Apr 4, 2020
3132a51
Merge remote-tracking branch 'upstream/master' into groupby_transform
Apr 8, 2020
1de2cf1
Add more typing to numba utils
Apr 8, 2020
d4d58f9
Merge remote-tracking branch 'upstream/master' into groupby_transform
Apr 9, 2020
f63e3e7
Merge remote-tracking branch 'upstream/master' into groupby_transform
Apr 13, 2020
e984283
Add more docstrings
Apr 13, 2020
909e92e
Have benchmark contain more groups
Apr 13, 2020
e2f2a54
Merge remote-tracking branch 'upstream/master' into groupby_transform
Apr 14, 2020
cd7a0be
Remove columns as a required argument for udf
Apr 14, 2020
930466a
Simplify tests
Apr 14, 2020
145ca50
Add one more test and commentary
Apr 14, 2020
fc0654d
Expand docstring
Apr 14, 2020
6d5c63d
Merge remote-tracking branch 'upstream/master' into groupby_transform
Apr 15, 2020
9dbded0
lint
Apr 15, 2020
5909abb
Merge remote-tracking branch 'upstream/master' into groupby_transform
Apr 16, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions asv_bench/benchmarks/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,4 +626,38 @@ def time_first(self):
self.df_nans.groupby("key").transform("first")


class TransformEngine:
    """
    Time ``groupby.transform`` with a user-defined function under both engines.

    Every ``time_*`` method multiplies the grouped values by 5, either through
    the ``"numba"`` engine or the ``"cython"`` engine, for one selected column
    (Series path) and for the whole grouped frame (DataFrame path).
    """

    def setup(self):
        # A five-row pattern repeated 1000 times: string keys in column 0 and
        # float values in column 1. The frame is grouped once, up front.
        repeats = 10 ** 3
        keys = ["a", "a", "b", "b", "a"] * repeats
        floats = [1.0, 2.0, 3.0, 4.0, 5.0] * repeats
        frame = DataFrame({0: keys, 1: floats}, columns=[0, 1])
        self.grouper = frame.groupby(0)

    def time_series_numba(self):
        # UDF for the Series path of the numba engine: (values, index).
        def function(values, index):
            return values * 5

        series_groups = self.grouper[1]
        series_groups.transform(function, engine="numba")

    def time_series_cython(self):
        # Same computation through the default engine on a single column.
        def function(values):
            return values * 5

        series_groups = self.grouper[1]
        series_groups.transform(function, engine="cython")

    def time_dataframe_numba(self):
        # UDF for the DataFrame path of the numba engine:
        # (values, index, columns).
        def function(values, index, columns):
            return values * 5

        frame_groups = self.grouper
        frame_groups.transform(function, engine="numba")

    def time_dataframe_cython(self):
        # Same computation through the default engine on the whole frame.
        def function(values):
            return values * 5

        frame_groups = self.grouper
        frame_groups.transform(function, engine="cython")


from .pandas_vb_common import setup # noqa: F401 isort:skip
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ Other enhancements
- :class:`Series.str` now has a `fullmatch` method that matches a regular expression against the entire string in each row of the series, similar to `re.fullmatch` (:issue:`32806`).
- :meth:`DataFrame.sample` will now also allow array-like and BitGenerator objects to be passed to ``random_state`` as seeds (:issue:`32503`)
- :meth:`MultiIndex.union` will now raise `RuntimeWarning` if the objects inside are unsortable; pass `sort=False` to suppress this warning (:issue:`33015`)
- :meth:`~pandas.core.groupby.GroupBy.transform` has gained ``engine`` and ``engine_kwargs`` arguments that support executing functions with ``Numba`` (:issue:`32854`)
-

.. ---------------------------------------------------------------------------
Expand Down
74 changes: 61 additions & 13 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@
import pandas.core.indexes.base as ibase
from pandas.core.internals import BlockManager, make_block
from pandas.core.series import Series
from pandas.core.util.numba_ import (
check_kwargs_and_nopython,
get_jit_arguments,
jit_user_function,
split_for_numba,
validate_udf,
)

from pandas.plotting import boxplot_frame_groupby

Expand Down Expand Up @@ -154,6 +161,8 @@ def pinner(cls):
class SeriesGroupBy(GroupBy[Series]):
_apply_whitelist = base.series_apply_whitelist

_numba_func_cache: Dict[Callable, Callable] = {}

def _iterate_slices(self) -> Iterable[Series]:
yield self._selected_obj

Expand Down Expand Up @@ -463,11 +472,13 @@ def _aggregate_named(self, func, *args, **kwargs):

@Substitution(klass="Series", selected="A.")
@Appender(_transform_template)
def transform(self, func, *args, **kwargs):
def transform(self, func, *args, engine="cython", engine_kwargs=None, **kwargs):
func = self._get_cython_func(func) or func

if not isinstance(func, str):
return self._transform_general(func, *args, **kwargs)
return self._transform_general(
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
)

elif func not in base.transform_kernel_whitelist:
msg = f"'{func}' is not a valid function name for transform(name)"
Expand All @@ -482,16 +493,33 @@ def transform(self, func, *args, **kwargs):
result = getattr(self, func)(*args, **kwargs)
return self._transform_fast(result, func)

def _transform_general(self, func, *args, **kwargs):
def _transform_general(
self, func, *args, engine="cython", engine_kwargs=None, **kwargs
):
"""
Transform with a non-str `func`.
"""

if engine == "numba":
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
check_kwargs_and_nopython(kwargs, nopython)
validate_udf(func)
numba_func = self._numba_func_cache.get(
func, jit_user_function(func, nopython, nogil, parallel)
)

klass = type(self._selected_obj)

results = []
for name, group in self:
object.__setattr__(group, "name", name)
res = func(group, *args, **kwargs)
if engine == "numba":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

like to see this as

def _evaluate_udf

values, index, _ = split_for_numba(group)
res = numba_func(values, index, *args)
if func not in self._numba_func_cache:
self._numba_func_cache[func] = numba_func
else:
res = func(group, *args, **kwargs)

if isinstance(res, (ABCDataFrame, ABCSeries)):
res = res._values
Expand Down Expand Up @@ -819,6 +847,8 @@ class DataFrameGroupBy(GroupBy[DataFrame]):

_apply_whitelist = base.dataframe_apply_whitelist

_numba_func_cache: Dict[Callable, Callable] = {}

_agg_see_also_doc = dedent(
"""
See Also
Expand Down Expand Up @@ -1359,19 +1389,35 @@ def first_not_none(values):
# Handle cases like BinGrouper
return self._concat_objects(keys, values, not_indexed_same=not_indexed_same)

def _transform_general(self, func, *args, **kwargs):
def _transform_general(
self, func, *args, engine="cython", engine_kwargs=None, **kwargs
):
from pandas.core.reshape.concat import concat

applied = []
obj = self._obj_with_exclusions
gen = self.grouper.get_iterator(obj, axis=self.axis)
fast_path, slow_path = self._define_paths(func, *args, **kwargs)
if engine == "numba":
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
check_kwargs_and_nopython(kwargs, nopython)
validate_udf(func, include_columns=True)
numba_func = self._numba_func_cache.get(
func, jit_user_function(func, nopython, nogil, parallel)
)
else:
fast_path, slow_path = self._define_paths(func, *args, **kwargs)

path = None
for name, group in gen:
object.__setattr__(group, "name", name)

if path is None:
if engine == "numba":
values, index, columns = split_for_numba(group)
res = numba_func(values, index, columns, *args)
if func not in self._numba_func_cache:
self._numba_func_cache[func] = numba_func
# Return the result as a DataFrame for concatenation later
res = DataFrame(res, index=group.index, columns=group.columns)
else:
# Try slow path and fast path.
try:
path, res = self._choose_path(fast_path, slow_path, group)
Expand All @@ -1380,8 +1426,6 @@ def _transform_general(self, func, *args, **kwargs):
except ValueError as err:
msg = "transform must return a scalar value for each group"
raise ValueError(msg) from err
else:
res = path(group)

if isinstance(res, Series):

Expand Down Expand Up @@ -1415,13 +1459,15 @@ def _transform_general(self, func, *args, **kwargs):

@Substitution(klass="DataFrame", selected="")
@Appender(_transform_template)
def transform(self, func, *args, **kwargs):
def transform(self, func, *args, engine="cython", engine_kwargs=None, **kwargs):

# optimized transforms
func = self._get_cython_func(func) or func

if not isinstance(func, str):
return self._transform_general(func, *args, **kwargs)
return self._transform_general(
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
)

elif func not in base.transform_kernel_whitelist:
msg = f"'{func}' is not a valid function name for transform(name)"
Expand All @@ -1443,7 +1489,9 @@ def transform(self, func, *args, **kwargs):
):
return self._transform_fast(result, func)

return self._transform_general(func, *args, **kwargs)
return self._transform_general(
func, engine=engine, engine_kwargs=engine_kwargs, *args, **kwargs
)

def _transform_fast(self, result: DataFrame, func_nm: str) -> DataFrame:
"""
Expand Down
55 changes: 51 additions & 4 deletions pandas/core/util/numba_.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
"""Common utilities for Numba operations"""
import inspect
import types
from typing import Callable, Dict, Optional
from typing import Callable, Dict, Optional, Tuple

import numpy as np

from pandas._typing import FrameOrSeries
from pandas.compat._optional import import_optional_dependency


def check_kwargs_and_nopython(
    kwargs: Optional[Dict] = None, nopython: Optional[bool] = None
) -> None:
    """
    Reject keyword arguments when nopython mode is requested.

    Parameters
    ----------
    kwargs : dict, default None
        Keyword arguments that would be forwarded to the jitted function.
    nopython : bool, default None
        Whether nopython mode was requested.

    Raises
    ------
    ValueError
        If ``kwargs`` is non-empty and ``nopython`` is truthy.
    """
    if kwargs and nopython:
        raise ValueError(
            "numba does not support kwargs with nopython=True: "
            "https://github.com/numba/numba/issues/2916"
        )


def get_jit_arguments(engine_kwargs: Optional[Dict[str, bool]] = None):
def get_jit_arguments(
engine_kwargs: Optional[Dict[str, bool]] = None
) -> Tuple[bool, bool, bool]:
"""
Return arguments to pass to numba.JIT, falling back on pandas default JIT settings.
"""
Expand All @@ -30,7 +34,9 @@ def get_jit_arguments(engine_kwargs: Optional[Dict[str, bool]] = None):
return nopython, nogil, parallel


def jit_user_function(func: Callable, nopython: bool, nogil: bool, parallel: bool):
def jit_user_function(
func: Callable, nopython: bool, nogil: bool, parallel: bool
) -> Callable:
"""
JIT the user's function given the configurable arguments.
"""
Expand All @@ -56,3 +62,44 @@ def impl(data, *_args):
return impl

return numba_func


def split_for_numba(
    arg: "FrameOrSeries"
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
    """
    Split pandas object into its components as numpy arrays for numba functions.

    Parameters
    ----------
    arg : Series or DataFrame

    Returns
    -------
    (values, index, columns) : tuple
        ``arg.to_numpy()``, ``arg.index.to_numpy()`` and
        ``arg.columns.to_numpy()``. The third element is None when ``arg``
        has no ``columns`` attribute (or it is None).
    """
    columns = getattr(arg, "columns", None)
    columns_as_array = None if columns is None else columns.to_numpy()
    return arg.to_numpy(), arg.index.to_numpy(), columns_as_array


def validate_udf(func: Callable, include_columns: bool = False) -> None:
    """
    Validate user defined function for ops when using Numba.

    For routines that pass Series objects, the first signature arguments should include:

    def f(values, index, ...):
        ...

    For routines that pass DataFrame objects, the first signature arguments should
    include:

    def f(values, index, columns, ...):
        ...

    Parameters
    ----------
    func : callable
        User defined function whose signature is inspected.
    include_columns : bool, default False
        Whether ``columns`` is required as the third argument.

    Raises
    ------
    ValueError
        If the leading argument names of ``func`` are not the expected ones.
    """
    expected_args = ["values", "index"]
    if include_columns:
        expected_args.append("columns")
    min_number_args = len(expected_args)

    udf_signature = list(inspect.signature(func).parameters)
    # Slicing a too-short signature yields a shorter list, so this single
    # comparison also rejects functions with fewer arguments than required.
    if udf_signature[:min_number_args] != expected_args:
        # Callables such as functools.partial objects or callable instances
        # have no __name__; fall back to repr so a ValueError is still raised.
        func_name = getattr(func, "__name__", repr(func))
        raise ValueError(
            f"The first {min_number_args} arguments to {func_name} must be "
            f"{expected_args}"
        )
18 changes: 18 additions & 0 deletions pandas/tests/groupby/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,21 @@ def transformation_func(request):
def groupby_func(request):
"""yields both aggregation and transformation functions."""
return request.param


@pytest.fixture(params=[True, False])
def parallel(request):
    """
    parallel keyword argument for numba.jit

    Parametrized over True and False; returns the value for the current run.
    """
    return request.param


@pytest.fixture(params=[True, False])
def nogil(request):
    """
    nogil keyword argument for numba.jit

    Parametrized over True and False; returns the value for the current run.
    """
    return request.param


@pytest.fixture(params=[True, False])
def nopython(request):
    """
    nopython keyword argument for numba.jit

    Parametrized over True and False; returns the value for the current run.
    """
    return request.param
Empty file.
Loading