diff --git a/asv_bench/benchmarks/groupby.py b/asv_bench/benchmarks/groupby.py index 28e0dcc5d9b13..eb637c78806c0 100644 --- a/asv_bench/benchmarks/groupby.py +++ b/asv_bench/benchmarks/groupby.py @@ -626,4 +626,38 @@ def time_first(self): self.df_nans.groupby("key").transform("first") +class TransformEngine: + def setup(self): + N = 10 ** 3 + data = DataFrame( + {0: [str(i) for i in range(100)] * N, 1: list(range(100)) * N}, + columns=[0, 1], + ) + self.grouper = data.groupby(0) + + def time_series_numba(self): + def function(values, index): + return values * 5 + + self.grouper[1].transform(function, engine="numba") + + def time_series_cython(self): + def function(values): + return values * 5 + + self.grouper[1].transform(function, engine="cython") + + def time_dataframe_numba(self): + def function(values, index): + return values * 5 + + self.grouper.transform(function, engine="numba") + + def time_dataframe_cython(self): + def function(values): + return values * 5 + + self.grouper.transform(function, engine="cython") + + from .pandas_vb_common import setup # noqa: F401 isort:skip diff --git a/doc/source/whatsnew/v1.1.0.rst b/doc/source/whatsnew/v1.1.0.rst index 82c43811c0444..2300ef88d2e0d 100644 --- a/doc/source/whatsnew/v1.1.0.rst +++ b/doc/source/whatsnew/v1.1.0.rst @@ -98,6 +98,8 @@ Other enhancements This can be used to set a custom compression level, e.g., ``df.to_csv(path, compression={'method': 'gzip', 'compresslevel': 1}`` (:issue:`33196`) +- :meth:`~pandas.core.groupby.GroupBy.transform` has gained ``engine`` and ``engine_kwargs`` arguments that supports executing functions with ``Numba`` (:issue:`32854`) +- .. --------------------------------------------------------------------------- diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 13938c41a0f6b..c007d4920cbe7 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -75,6 +75,13 @@ import pandas.core.indexes.base as ibase from pandas.core.internals import BlockManager, make_block from pandas.core.series import Series +from pandas.core.util.numba_ import ( + check_kwargs_and_nopython, + get_jit_arguments, + jit_user_function, + split_for_numba, + validate_udf, +) from pandas.plotting import boxplot_frame_groupby @@ -154,6 +161,8 @@ def pinner(cls): class SeriesGroupBy(GroupBy[Series]): _apply_whitelist = base.series_apply_whitelist + _numba_func_cache: Dict[Callable, Callable] = {} + def _iterate_slices(self) -> Iterable[Series]: yield self._selected_obj @@ -463,11 +472,13 @@ def _aggregate_named(self, func, *args, **kwargs): @Substitution(klass="Series", selected="A.") @Appender(_transform_template) - def transform(self, func, *args, **kwargs): + def transform(self, func, *args, engine="cython", engine_kwargs=None, **kwargs): func = self._get_cython_func(func) or func if not isinstance(func, str): - return self._transform_general(func, *args, **kwargs) + return self._transform_general( + func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs + ) elif func not in base.transform_kernel_whitelist: msg = f"'{func}' is not a valid function name for transform(name)" @@ -482,16 +493,33 @@ def transform(self, func, *args, **kwargs): result = getattr(self, func)(*args, **kwargs) return self._transform_fast(result, func) - def _transform_general(self, func, *args, **kwargs): + def _transform_general( + self, func, *args, engine="cython", engine_kwargs=None, **kwargs + ): """ Transform with a non-str `func`. """ + + if engine == "numba": + nopython, nogil, parallel = get_jit_arguments(engine_kwargs) + check_kwargs_and_nopython(kwargs, nopython) + validate_udf(func) + numba_func = self._numba_func_cache.get( + func, jit_user_function(func, nopython, nogil, parallel) + ) + klass = type(self._selected_obj) results = [] for name, group in self: object.__setattr__(group, "name", name) - res = func(group, *args, **kwargs) + if engine == "numba": + values, index = split_for_numba(group) + res = numba_func(values, index, *args) + if func not in self._numba_func_cache: + self._numba_func_cache[func] = numba_func + else: + res = func(group, *args, **kwargs) if isinstance(res, (ABCDataFrame, ABCSeries)): res = res._values @@ -819,6 +847,8 @@ class DataFrameGroupBy(GroupBy[DataFrame]): _apply_whitelist = base.dataframe_apply_whitelist + _numba_func_cache: Dict[Callable, Callable] = {} + _agg_see_also_doc = dedent( """ See Also @@ -1355,19 +1385,35 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False): # Handle cases like BinGrouper return self._concat_objects(keys, values, not_indexed_same=not_indexed_same) - def _transform_general(self, func, *args, **kwargs): + def _transform_general( + self, func, *args, engine="cython", engine_kwargs=None, **kwargs + ): from pandas.core.reshape.concat import concat applied = [] obj = self._obj_with_exclusions gen = self.grouper.get_iterator(obj, axis=self.axis) - fast_path, slow_path = self._define_paths(func, *args, **kwargs) + if engine == "numba": + nopython, nogil, parallel = get_jit_arguments(engine_kwargs) + check_kwargs_and_nopython(kwargs, nopython) + validate_udf(func) + numba_func = self._numba_func_cache.get( + func, jit_user_function(func, nopython, nogil, parallel) + ) + else: + fast_path, slow_path = self._define_paths(func, *args, **kwargs) - path = None for name, group in gen: object.__setattr__(group, "name", name) - if path is None: + if engine == "numba": + values, index = split_for_numba(group) + res = numba_func(values, index, *args) + if func not in self._numba_func_cache: + self._numba_func_cache[func] = numba_func + # Return the result as a DataFrame for concatenation later + res = DataFrame(res, index=group.index, columns=group.columns) + else: # Try slow path and fast path. try: path, res = self._choose_path(fast_path, slow_path, group) @@ -1376,8 +1422,6 @@ def _transform_general(self, func, *args, **kwargs): except ValueError as err: msg = "transform must return a scalar value for each group" raise ValueError(msg) from err - else: - res = path(group) if isinstance(res, Series): @@ -1411,13 +1455,15 @@ def _transform_general(self, func, *args, **kwargs): @Substitution(klass="DataFrame", selected="") @Appender(_transform_template) - def transform(self, func, *args, **kwargs): + def transform(self, func, *args, engine="cython", engine_kwargs=None, **kwargs): # optimized transforms func = self._get_cython_func(func) or func if not isinstance(func, str): - return self._transform_general(func, *args, **kwargs) + return self._transform_general( + func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs + ) elif func not in base.transform_kernel_whitelist: msg = f"'{func}' is not a valid function name for transform(name)" @@ -1439,7 +1485,9 @@ def transform(self, func, *args, **kwargs): ): return self._transform_fast(result, func) - return self._transform_general(func, *args, **kwargs) + return self._transform_general( + func, engine=engine, engine_kwargs=engine_kwargs, *args, **kwargs + ) def _transform_fast(self, result: DataFrame, func_nm: str) -> DataFrame: """ diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 873f24b9685e3..154af3981a5ff 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -254,7 +254,36 @@ class providing the base-class of operations. Parameters ---------- f : function - Function to apply to each group + Function to apply to each group. + + Can also accept a Numba JIT function with + ``engine='numba'`` specified. + + If the ``'numba'`` engine is chosen, the function must be + a user defined function with ``values`` and ``index`` as the + first and second arguments respectively in the function signature. + Each group's index will be passed to the user defined function + and optionally available for use. + + .. versionchanged:: 1.1.0 +*args + Positional arguments to pass to func +engine : str, default 'cython' + * ``'cython'`` : Runs the function through C-extensions from cython. + * ``'numba'`` : Runs the function through JIT compiled code from numba. + + .. versionadded:: 1.1.0 +engine_kwargs : dict, default None + * For ``'cython'`` engine, there are no accepted ``engine_kwargs`` + * For ``'numba'`` engine, the engine can accept ``nopython``, ``nogil`` + and ``parallel`` dictionary keys. The values must either be ``True`` or + ``False``. The default ``engine_kwargs`` for the ``'numba'`` engine is + ``{'nopython': True, 'nogil': False, 'parallel': False}`` and will be + applied to the function + + .. versionadded:: 1.1.0 +**kwargs + Keyword arguments to be passed into func. Returns ------- diff --git a/pandas/core/util/numba_.py b/pandas/core/util/numba_.py index e4debab2c22ee..c5b27b937a05b 100644 --- a/pandas/core/util/numba_.py +++ b/pandas/core/util/numba_.py @@ -1,15 +1,36 @@ """Common utilities for Numba operations""" +import inspect import types -from typing import Callable, Dict, Optional +from typing import Callable, Dict, Optional, Tuple import numpy as np +from pandas._typing import FrameOrSeries from pandas.compat._optional import import_optional_dependency def check_kwargs_and_nopython( kwargs: Optional[Dict] = None, nopython: Optional[bool] = None -): +) -> None: + """ + Validate that **kwargs and nopython=True was passed + https://github.com/numba/numba/issues/2916 + + Parameters + ---------- + kwargs : dict, default None + user passed keyword arguments to pass into the JITed function + nopython : bool, default None + nopython parameter + + Returns + ------- + None + + Raises + ------ + ValueError + """ if kwargs and nopython: raise ValueError( "numba does not support kwargs with nopython=True: " @@ -17,9 +38,21 @@ def check_kwargs_and_nopython( ) -def get_jit_arguments(engine_kwargs: Optional[Dict[str, bool]] = None): +def get_jit_arguments( + engine_kwargs: Optional[Dict[str, bool]] = None +) -> Tuple[bool, bool, bool]: """ Return arguments to pass to numba.JIT, falling back on pandas default JIT settings. + + Parameters + ---------- + engine_kwargs : dict, default None + user passed keyword arguments for numba.JIT + + Returns + ------- + (bool, bool, bool) + nopython, nogil, parallel """ if engine_kwargs is None: engine_kwargs = {} @@ -30,9 +63,28 @@ def get_jit_arguments(engine_kwargs: Optional[Dict[str, bool]] = None): return nopython, nogil, parallel -def jit_user_function(func: Callable, nopython: bool, nogil: bool, parallel: bool): +def jit_user_function( + func: Callable, nopython: bool, nogil: bool, parallel: bool +) -> Callable: """ JIT the user's function given the configurable arguments. + + Parameters + ---------- + func : function + user defined function + + nopython : bool + nopython parameter for numba.JIT + nogil : bool + nogil parameter for numba.JIT + parallel : bool + parallel parameter for numba.JIT + + Returns + ------- + function + Numba JITed function """ numba = import_optional_dependency("numba") @@ -56,3 +108,50 @@ def impl(data, *_args): return impl return numba_func + + +def split_for_numba(arg: FrameOrSeries) -> Tuple[np.ndarray, np.ndarray]: + """ + Split pandas object into its components as numpy arrays for numba functions. + + Parameters + ---------- + arg : Series or DataFrame + + Returns + ------- + (ndarray, ndarray) + values, index + """ + return arg.to_numpy(), arg.index.to_numpy() + + +def validate_udf(func: Callable) -> None: + """ + Validate user defined function for ops when using Numba. + + The first signature arguments should include: + + def f(values, index, ...): + ... + + Parameters + ---------- + func : function, default False + user defined function + + Returns + ------- + None + """ + udf_signature = list(inspect.signature(func).parameters.keys()) + expected_args = ["values", "index"] + min_number_args = len(expected_args) + if ( + len(udf_signature) < min_number_args + or udf_signature[:min_number_args] != expected_args + ): + raise ValueError( + f"The first {min_number_args} arguments to {func.__name__} must be " + f"{expected_args}" + ) diff --git a/pandas/tests/groupby/conftest.py b/pandas/tests/groupby/conftest.py index 1214734358c80..0b9721968a881 100644 --- a/pandas/tests/groupby/conftest.py +++ b/pandas/tests/groupby/conftest.py @@ -123,3 +123,21 @@ def transformation_func(request): def groupby_func(request): """yields both aggregation and transformation functions.""" return request.param + + +@pytest.fixture(params=[True, False]) +def parallel(request): + """parallel keyword argument for numba.jit""" + return request.param + + +@pytest.fixture(params=[True, False]) +def nogil(request): + """nogil keyword argument for numba.jit""" + return request.param + + +@pytest.fixture(params=[True, False]) +def nopython(request): + """nopython keyword argument for numba.jit""" + return request.param diff --git a/pandas/tests/groupby/transform/__init__.py b/pandas/tests/groupby/transform/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py new file mode 100644 index 0000000000000..96078d0aa3662 --- /dev/null +++ b/pandas/tests/groupby/transform/test_numba.py @@ -0,0 +1,112 @@ +import pytest + +import pandas.util._test_decorators as td + +from pandas import DataFrame +import pandas._testing as tm + + +@td.skip_if_no("numba", "0.46.0") +def test_correct_function_signature(): + def incorrect_function(x): + return x + 1 + + data = DataFrame( + {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]}, + columns=["key", "data"], + ) + with pytest.raises(ValueError, match=f"The first 2"): + data.groupby("key").transform(incorrect_function, engine="numba") + + with pytest.raises(ValueError, match=f"The first 2"): + data.groupby("key")["data"].transform(incorrect_function, engine="numba") + + +@td.skip_if_no("numba", "0.46.0") +def test_check_nopython_kwargs(): + def incorrect_function(x, **kwargs): + return x + 1 + + data = DataFrame( + {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]}, + columns=["key", "data"], + ) + with pytest.raises(ValueError, match="numba does not support"): + data.groupby("key").transform(incorrect_function, engine="numba", a=1) + + with pytest.raises(ValueError, match="numba does not support"): + data.groupby("key")["data"].transform(incorrect_function, engine="numba", a=1) + + +@td.skip_if_no("numba", "0.46.0") +@pytest.mark.filterwarnings("ignore:\\nThe keyword argument") +# Filter warnings when parallel=True and the function can't be parallelized by Numba +@pytest.mark.parametrize("jit", [True, False]) +@pytest.mark.parametrize("pandas_obj", ["Series", "DataFrame"]) +def test_numba_vs_cython(jit, pandas_obj, nogil, parallel, nopython): + def func(values, index): + return values + 1 + + if jit: + # Test accepted jitted functions + import numba + + func = numba.jit(func) + + data = DataFrame( + {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1], + ) + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + grouped = data.groupby(0) + if pandas_obj == "Series": + grouped = grouped[1] + + result = grouped.transform(func, engine="numba", engine_kwargs=engine_kwargs) + expected = grouped.transform(lambda x: x + 1, engine="cython") + + tm.assert_equal(result, expected) + + +@td.skip_if_no("numba", "0.46.0") +@pytest.mark.filterwarnings("ignore:\\nThe keyword argument") +# Filter warnings when parallel=True and the function can't be parallelized by Numba +@pytest.mark.parametrize("jit", [True, False]) +@pytest.mark.parametrize("pandas_obj", ["Series", "DataFrame"]) +def test_cache(jit, pandas_obj, nogil, parallel, nopython): + # Test that the functions are cached correctly if we switch functions + def func_1(values, index): + return values + 1 + + def func_2(values, index): + return values * 5 + + if jit: + import numba + + func_1 = numba.jit(func_1) + func_2 = numba.jit(func_2) + + data = DataFrame( + {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1], + ) + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + grouped = data.groupby(0) + if pandas_obj == "Series": + grouped = grouped[1] + + result = grouped.transform(func_1, engine="numba", engine_kwargs=engine_kwargs) + expected = grouped.transform(lambda x: x + 1, engine="cython") + tm.assert_equal(result, expected) + # func_1 should be in the cache now + assert func_1 in grouped._numba_func_cache + + # Add func_2 to the cache + result = grouped.transform(func_2, engine="numba", engine_kwargs=engine_kwargs) + expected = grouped.transform(lambda x: x * 5, engine="cython") + tm.assert_equal(result, expected) + assert func_2 in grouped._numba_func_cache + + # Retest func_1 which should use the cache + result = grouped.transform(func_1, engine="numba", engine_kwargs=engine_kwargs) + expected = grouped.transform(lambda x: x + 1, engine="cython") + tm.assert_equal(result, expected) diff --git a/pandas/tests/groupby/test_transform.py b/pandas/tests/groupby/transform/test_transform.py similarity index 100% rename from pandas/tests/groupby/test_transform.py rename to pandas/tests/groupby/transform/test_transform.py