From db1b3aa17ab633b46d24970ec5126fbcc4f08f23 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Tue, 17 Mar 2020 22:02:14 -0700 Subject: [PATCH 01/28] Add numba engine to groupby.transform for series --- pandas/core/groupby/generic.py | 39 +++++++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 4102b8527b6aa..b70fa72546214 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -76,6 +76,9 @@ from pandas.plotting import boxplot_frame_groupby +# from pandas.core.util.numba_ import check_kwargs_and_nopython, get_jit_arguments, jit_user_function, split_for_numba + + if TYPE_CHECKING: from pandas.core.internals import Block @@ -461,11 +464,13 @@ def _aggregate_named(self, func, *args, **kwargs): @Substitution(klass="Series", selected="A.") @Appender(_transform_template) - def transform(self, func, *args, **kwargs): + def transform(self, func, engine="cython", engine_kwargs=None, *args, **kwargs): func = self._get_cython_func(func) or func if not isinstance(func, str): - return self._transform_general(func, *args, **kwargs) + return self._transform_general( + func, engine=engine, engine_kwargs=engine_kwargs, *args, **kwargs + ) elif func not in base.transform_kernel_whitelist: msg = f"'{func}' is not a valid function name for transform(name)" @@ -480,16 +485,28 @@ def transform(self, func, *args, **kwargs): result = getattr(self, func)(*args, **kwargs) return self._transform_fast(result, func) - def _transform_general(self, func, *args, **kwargs): + def _transform_general( + self, func, engine="cython", engine_kwargs=None, *args, **kwargs + ): """ Transform with a non-str `func`. """ + + if engine == "numba": + nopython, nogil, parallel = get_jit_arguments(engine_kwargs) + check_kwargs_and_nopython(kwargs, nopython) + new_func = jit_user_function(func, nopython, nogil, parallel) + klass = type(self._selected_obj) results = [] for name, group in self: object.__setattr__(group, "name", name) - res = func(group, *args, **kwargs) + if engine == "numba": + values, index, _ = split_for_numba(group) + res = func(group, index, *args) + else: + res = func(group, *args, **kwargs) if isinstance(res, (ABCDataFrame, ABCSeries)): res = res._values @@ -1355,7 +1372,9 @@ def first_not_none(values): # Handle cases like BinGrouper return self._concat_objects(keys, values, not_indexed_same=not_indexed_same) - def _transform_general(self, func, *args, **kwargs): + def _transform_general( + self, func, engine="cython", engine_kwargs=None, *args, **kwargs + ): from pandas.core.reshape.concat import concat applied = [] @@ -1411,13 +1430,15 @@ def _transform_general(self, func, *args, **kwargs): @Substitution(klass="DataFrame", selected="") @Appender(_transform_template) - def transform(self, func, *args, **kwargs): + def transform(self, func, engine="cython", engine_kwargs=None, *args, **kwargs): # optimized transforms func = self._get_cython_func(func) or func if not isinstance(func, str): - return self._transform_general(func, *args, **kwargs) + return self._transform_general( + func, engine=engine, engine_kwargs=engine_kwargs, *args, **kwargs + ) elif func not in base.transform_kernel_whitelist: msg = f"'{func}' is not a valid function name for transform(name)" @@ -1439,7 +1460,9 @@ def transform(self, func, *args, **kwargs): ): return self._transform_fast(result, func) - return self._transform_general(func, *args, **kwargs) + return self._transform_general( + func, engine=engine, engine_kwargs=engine_kwargs, *args, **kwargs + ) def _transform_fast(self, result: DataFrame, func_nm: str) -> DataFrame: """ From e82ede77387011baa5fc6f1b296a533eba57d795 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Tue, 17 Mar 2020 22:02:38 -0700 Subject: [PATCH 02/28] new_func -> func --- pandas/core/groupby/generic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index b70fa72546214..bc00029726bad 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -495,7 +495,7 @@ def _transform_general( if engine == "numba": nopython, nogil, parallel = get_jit_arguments(engine_kwargs) check_kwargs_and_nopython(kwargs, nopython) - new_func = jit_user_function(func, nopython, nogil, parallel) + func = jit_user_function(func, nopython, nogil, parallel) klass = type(self._selected_obj) From 1316662353433f3361bd9fcd8eeedb38696db0aa Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Wed, 18 Mar 2020 21:24:56 -0700 Subject: [PATCH 03/28] Adjust inputs for groupby.transform for series --- pandas/core/groupby/generic.py | 7 ++++++- pandas/core/util/numba_.py | 12 ++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index bc00029726bad..95182a8ed01ad 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -76,7 +76,12 @@ from pandas.plotting import boxplot_frame_groupby -# from pandas.core.util.numba_ import check_kwargs_and_nopython, get_jit_arguments, jit_user_function, split_for_numba +from pandas.core.util.numba_ import ( + check_kwargs_and_nopython, + get_jit_arguments, + jit_user_function, + split_for_numba, +) if TYPE_CHECKING: diff --git a/pandas/core/util/numba_.py b/pandas/core/util/numba_.py index e4debab2c22ee..d664717e28073 100644 --- a/pandas/core/util/numba_.py +++ b/pandas/core/util/numba_.py @@ -5,6 +5,7 @@ import numpy as np from pandas.compat._optional import import_optional_dependency +from pandas._typing import FrameOrSeries def check_kwargs_and_nopython( @@ -56,3 +57,14 @@ def impl(data, *_args): return impl return numba_func + + +def split_for_numba(arg: FrameOrSeries): + """ + Split pandas object into its components as numpy arrays for numba functions. + """ + if getattr(arg, "columns", None) is not None: + columns_as_array = arg.columns.to_numpy() + else: + columns_as_array = None + return arg.to_numpy(), arg.index.to_numpy(), columns_as_array From 3e12a51b130ea508e12d02c2276f4aac6a4e336d Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Thu, 19 Mar 2020 18:18:57 -0700 Subject: [PATCH 04/28] Fix typo in func --- pandas/core/groupby/generic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 95182a8ed01ad..2e0b741e03726 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -509,7 +509,7 @@ def _transform_general( object.__setattr__(group, "name", name) if engine == "numba": values, index, _ = split_for_numba(group) - res = func(group, index, *args) + res = func(values, index, *args) else: res = func(group, *args, **kwargs) From 6e6891fe2ecce9f7b318c029c851448fe609cd1f Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Thu, 19 Mar 2020 18:42:13 -0700 Subject: [PATCH 05/28] Add udf validation function --- pandas/core/groupby/generic.py | 2 ++ pandas/core/util/numba_.py | 27 +++++++++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 2e0b741e03726..54b03745213b0 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -81,6 +81,7 @@ get_jit_arguments, jit_user_function, split_for_numba, + validate_udf, ) @@ -500,6 +501,7 @@ def _transform_general( if engine == "numba": nopython, nogil, parallel = get_jit_arguments(engine_kwargs) check_kwargs_and_nopython(kwargs, nopython) + validate_udf(func) func = jit_user_function(func, nopython, nogil, parallel) klass = type(self._selected_obj) diff --git a/pandas/core/util/numba_.py b/pandas/core/util/numba_.py index d664717e28073..527bf942ec743 100644 --- a/pandas/core/util/numba_.py +++ b/pandas/core/util/numba_.py @@ -1,4 +1,5 @@ """Common utilities for Numba operations""" +import inspect import types from typing import Callable, Dict, Optional @@ -68,3 +69,29 @@ def split_for_numba(arg: FrameOrSeries): else: columns_as_array = None return arg.to_numpy(), arg.index.to_numpy(), columns_as_array + + +def validate_udf(func: Callable, include_columns: bool = False): + """ + Validate user defined function for ops when using Numba. + + For routines that pass Series objects, the first signature arguments should include: + def f(values, index, ...): + ... + + For routines that pass DataFrame objects, the first signature arguments should include: + def f(values, index, columns, ...): + ... + """ + udf_signature = list(inspect.signature(func).parameters.keys()) + expected_args = ["values", "index"] + if include_columns: + expected_args.append("columns") + min_number_args = len(expected_args) + if ( + len(udf_signature) < min_number_args + or udf_signature[:min_number_args] != expected_args + ): + raise ValueError( + f"The first {min_number_args} arguments to {func.__name__} must be {expected_args}" + ) From 73be08d2c405b04846cd3c2d8a295e01df10549c Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Thu, 19 Mar 2020 19:15:21 -0700 Subject: [PATCH 06/28] Add numba engine for dataframe objects --- pandas/core/groupby/generic.py | 36 ++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 54b03745213b0..67ccddceec648 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1387,23 +1387,35 @@ def _transform_general( applied = [] obj = self._obj_with_exclusions gen = self.grouper.get_iterator(obj, axis=self.axis) - fast_path, slow_path = self._define_paths(func, *args, **kwargs) + if engine == "numba": + nopython, nogil, parallel = get_jit_arguments(engine_kwargs) + check_kwargs_and_nopython(kwargs, nopython) + validate_udf(func) + func = jit_user_function(func, nopython, nogil, parallel) + else: + path = None + fast_path, slow_path = self._define_paths(func, *args, **kwargs) - path = None for name, group in gen: object.__setattr__(group, "name", name) - if path is None: - # Try slow path and fast path. - try: - path, res = self._choose_path(fast_path, slow_path, group) - except TypeError: - return self._transform_item_by_item(obj, fast_path) - except ValueError as err: - msg = "transform must return a scalar value for each group" - raise ValueError(msg) from err + if engine == "numba": + values, index, columns = split_for_numba(group) + res = func(values, index, columns, *args) + # Return the result as a DataFrame for concatenation later + res = DataFrame(res, index=group.index, columns=group.columns) else: - res = path(group) + if path is None: + # Try slow path and fast path. + try: + path, res = self._choose_path(fast_path, slow_path, group) + except TypeError: + return self._transform_item_by_item(obj, fast_path) + except ValueError as err: + msg = "transform must return a scalar value for each group" + raise ValueError(msg) from err + else: + res = path(group) if isinstance(res, Series): From e85dcd58b0968f0e5b7631ed973f3e120c965ade Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Thu, 19 Mar 2020 21:41:09 -0700 Subject: [PATCH 07/28] Lint --- pandas/core/util/numba_.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pandas/core/util/numba_.py b/pandas/core/util/numba_.py index 527bf942ec743..111a1eefb8fdb 100644 --- a/pandas/core/util/numba_.py +++ b/pandas/core/util/numba_.py @@ -76,10 +76,13 @@ def validate_udf(func: Callable, include_columns: bool = False): Validate user defined function for ops when using Numba. For routines that pass Series objects, the first signature arguments should include: + def f(values, index, ...): ... - For routines that pass DataFrame objects, the first signature arguments should include: + For routines that pass DataFrame objects, the first signature arguments should + include: + def f(values, index, columns, ...): ... """ @@ -93,5 +96,6 @@ def f(values, index, columns, ...): or udf_signature[:min_number_args] != expected_args ): raise ValueError( - f"The first {min_number_args} arguments to {func.__name__} must be {expected_args}" + f"The first {min_number_args} arguments to {func.__name__} must be " + f"{expected_args}" ) From 52fff03f9abced5d8586403c0e51c0151efbb137 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Sat, 21 Mar 2020 22:46:59 -0700 Subject: [PATCH 08/28] Add separate folder + file for tests --- pandas/tests/groupby/transform/__init__.py | 0 pandas/tests/groupby/transform/test_numba.py | 11 +++++++++++ .../tests/groupby/{ => transform}/test_transform.py | 0 3 files changed, 11 insertions(+) create mode 100644 pandas/tests/groupby/transform/__init__.py create mode 100644 pandas/tests/groupby/transform/test_numba.py rename pandas/tests/groupby/{ => transform}/test_transform.py (100%) diff --git a/pandas/tests/groupby/transform/__init__.py b/pandas/tests/groupby/transform/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py new file mode 100644 index 0000000000000..5183f8e471c8b --- /dev/null +++ b/pandas/tests/groupby/transform/test_numba.py @@ -0,0 +1,11 @@ +import numpy as np +import pytest + +import pandas.util._test_decorators as td + +from pandas import DataFrame, Series +import pandas._testing as tm + +@td.skip_if_no("numba", "0.46.0") +def test_correct_function_signature(): + pass \ No newline at end of file diff --git a/pandas/tests/groupby/test_transform.py b/pandas/tests/groupby/transform/test_transform.py similarity index 100% rename from pandas/tests/groupby/test_transform.py rename to pandas/tests/groupby/transform/test_transform.py From 156a2b47fe025f6c47ebb1998145e5ae832be239 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Sat, 21 Mar 2020 22:47:32 -0700 Subject: [PATCH 09/28] isort --- pandas/core/groupby/generic.py | 4 +--- pandas/core/util/numba_.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 67ccddceec648..164a3a9856c07 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -73,9 +73,6 @@ import pandas.core.indexes.base as ibase from pandas.core.internals import BlockManager, make_block from pandas.core.series import Series - -from pandas.plotting import boxplot_frame_groupby - from pandas.core.util.numba_ import ( check_kwargs_and_nopython, get_jit_arguments, @@ -84,6 +81,7 @@ validate_udf, ) +from pandas.plotting import boxplot_frame_groupby if TYPE_CHECKING: from pandas.core.internals import Block diff --git a/pandas/core/util/numba_.py b/pandas/core/util/numba_.py index 111a1eefb8fdb..320b579435738 100644 --- a/pandas/core/util/numba_.py +++ b/pandas/core/util/numba_.py @@ -5,8 +5,8 @@ import numpy as np -from pandas.compat._optional import import_optional_dependency from pandas._typing import FrameOrSeries +from pandas.compat._optional import import_optional_dependency def check_kwargs_and_nopython( From 6e9d2bf509df4dd7ee79168b5a95f61dc34415d0 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Sat, 21 Mar 2020 23:31:17 -0700 Subject: [PATCH 10/28] Add tests and reorder parameters --- pandas/core/groupby/generic.py | 14 ++-- pandas/tests/groupby/conftest.py | 18 +++++ pandas/tests/groupby/transform/test_numba.py | 72 +++++++++++++++++++- 3 files changed, 94 insertions(+), 10 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 164a3a9856c07..bdd5204df6424 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -468,12 +468,12 @@ def _aggregate_named(self, func, *args, **kwargs): @Substitution(klass="Series", selected="A.") @Appender(_transform_template) - def transform(self, func, engine="cython", engine_kwargs=None, *args, **kwargs): + def transform(self, func, *args, engine="cython", engine_kwargs=None, **kwargs): func = self._get_cython_func(func) or func if not isinstance(func, str): return self._transform_general( - func, engine=engine, engine_kwargs=engine_kwargs, *args, **kwargs + func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs ) elif func not in base.transform_kernel_whitelist: @@ -490,7 +490,7 @@ def transform(self, func, engine="cython", engine_kwargs=None, *args, **kwargs): return self._transform_fast(result, func) def _transform_general( - self, func, engine="cython", engine_kwargs=None, *args, **kwargs + self, func, *args, engine="cython", engine_kwargs=None, **kwargs ): """ Transform with a non-str `func`. @@ -1378,7 +1378,7 @@ def first_not_none(values): return self._concat_objects(keys, values, not_indexed_same=not_indexed_same) def _transform_general( - self, func, engine="cython", engine_kwargs=None, *args, **kwargs + self, func, *args, engine="cython", engine_kwargs=None, **kwargs ): from pandas.core.reshape.concat import concat @@ -1388,7 +1388,7 @@ def _transform_general( if engine == "numba": nopython, nogil, parallel = get_jit_arguments(engine_kwargs) check_kwargs_and_nopython(kwargs, nopython) - validate_udf(func) + validate_udf(func, include_columns=True) func = jit_user_function(func, nopython, nogil, parallel) else: path = None @@ -1447,14 +1447,14 @@ def _transform_general( @Substitution(klass="DataFrame", selected="") @Appender(_transform_template) - def transform(self, func, engine="cython", engine_kwargs=None, *args, **kwargs): + def transform(self, func, *args, engine="cython", engine_kwargs=None, **kwargs): # optimized transforms func = self._get_cython_func(func) or func if not isinstance(func, str): return self._transform_general( - func, engine=engine, engine_kwargs=engine_kwargs, *args, **kwargs + func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs ) elif func not in base.transform_kernel_whitelist: diff --git a/pandas/tests/groupby/conftest.py b/pandas/tests/groupby/conftest.py index 1214734358c80..0b9721968a881 100644 --- a/pandas/tests/groupby/conftest.py +++ b/pandas/tests/groupby/conftest.py @@ -123,3 +123,21 @@ def transformation_func(request): def groupby_func(request): """yields both aggregation and transformation functions.""" return request.param + + +@pytest.fixture(params=[True, False]) +def parallel(request): + """parallel keyword argument for numba.jit""" + return request.param + + +@pytest.fixture(params=[True, False]) +def nogil(request): + """nogil keyword argument for numba.jit""" + return request.param + + +@pytest.fixture(params=[True, False]) +def nopython(request): + """nopython keyword argument for numba.jit""" + return request.param diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py index 5183f8e471c8b..2d8d094e12413 100644 --- a/pandas/tests/groupby/transform/test_numba.py +++ b/pandas/tests/groupby/transform/test_numba.py @@ -1,11 +1,77 @@ -import numpy as np import pytest import pandas.util._test_decorators as td -from pandas import DataFrame, Series +from pandas import DataFrame import pandas._testing as tm + @td.skip_if_no("numba", "0.46.0") def test_correct_function_signature(): - pass \ No newline at end of file + def incorrect_function(x): + return x + 1 + + data = DataFrame( + {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]}, + columns=["key", "data"], + ) + with pytest.raises(ValueError, match=f"The first 3"): + data.groupby("key").transform(incorrect_function, engine="numba") + + with pytest.raises(ValueError, match=f"The first 2"): + data.groupby("key")["data"].transform(incorrect_function, engine="numba") + + +@td.skip_if_no("numba", "0.46.0") +def test_check_nopython_kwargs(): + def incorrect_function(x, **kwargs): + return x + 1 + + data = DataFrame( + {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]}, + columns=["key", "data"], + ) + with pytest.raises(ValueError, match="numba does not support"): + data.groupby("key").transform(incorrect_function, engine="numba", a=1) + + with pytest.raises(ValueError, match="numba does not support"): + data.groupby("key")["data"].transform(incorrect_function, engine="numba", a=1) + + +@td.skip_if_no("numba", "0.46.0") +@pytest.mark.filterwarnings("ignore:\\nThe keyword argument") +# Filter warnings when parallel=True and the function can't be parallelized by Numba +@pytest.mark.parametrize("jit", [True, False]) +def test_numba_vs_cython(jit, nogil, parallel, nopython): + def series_func(values, index): + return values + 1 + + def dataframe_func(values, index, columns): + return values + 1 + + if jit: + # Test accepted jitted functions + import numba + + series_func = numba.jit(series_func) + dataframe_func = numba.jit(dataframe_func) + + data = DataFrame( + {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1], + ) + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + grouped = data.groupby(0) + + result = grouped.transform( + dataframe_func, engine="numba", engine_kwargs=engine_kwargs + ) + expected = grouped.transform(lambda x: x + 1) + + tm.assert_frame_equal(result, expected) + + result = grouped[1].transform( + series_func, engine="numba", engine_kwargs=engine_kwargs + ) + expected = grouped[1].transform(lambda x: x + 1) + + tm.assert_series_equal(result, expected) From e1e5f73832a7125a511f9cfd5d4acfd3865773ea Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Mon, 23 Mar 2020 14:21:28 -0700 Subject: [PATCH 11/28] Remove usused path variable --- pandas/core/groupby/generic.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index bdd5204df6424..1b7bc30340c8d 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1391,7 +1391,6 @@ def _transform_general( validate_udf(func, include_columns=True) func = jit_user_function(func, nopython, nogil, parallel) else: - path = None fast_path, slow_path = self._define_paths(func, *args, **kwargs) for name, group in gen: @@ -1403,17 +1402,14 @@ def _transform_general( # Return the result as a DataFrame for concatenation later res = DataFrame(res, index=group.index, columns=group.columns) else: - if path is None: - # Try slow path and fast path. - try: - path, res = self._choose_path(fast_path, slow_path, group) - except TypeError: - return self._transform_item_by_item(obj, fast_path) - except ValueError as err: - msg = "transform must return a scalar value for each group" - raise ValueError(msg) from err - else: - res = path(group) + # Try slow path and fast path. + try: + path, res = self._choose_path(fast_path, slow_path, group) + except TypeError: + return self._transform_item_by_item(obj, fast_path) + except ValueError as err: + msg = "transform must return a scalar value for each group" + raise ValueError(msg) from err if isinstance(res, Series): From ce3b2b3ca6e6d2095b6c34afe5d63a3b70b7c861 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Mon, 23 Mar 2020 14:22:32 -0700 Subject: [PATCH 12/28] Make tests more explicit --- pandas/tests/groupby/transform/test_numba.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py index 2d8d094e12413..4f53036e08fe2 100644 --- a/pandas/tests/groupby/transform/test_numba.py +++ b/pandas/tests/groupby/transform/test_numba.py @@ -65,13 +65,13 @@ def dataframe_func(values, index, columns): result = grouped.transform( dataframe_func, engine="numba", engine_kwargs=engine_kwargs ) - expected = grouped.transform(lambda x: x + 1) + expected = grouped.transform(lambda x: x + 1, engine='cython') tm.assert_frame_equal(result, expected) result = grouped[1].transform( series_func, engine="numba", engine_kwargs=engine_kwargs ) - expected = grouped[1].transform(lambda x: x + 1) + expected = grouped[1].transform(lambda x: x + 1, engine='cython') tm.assert_series_equal(result, expected) From 195c35f9fa1d9006e78a51bb181132cf8aaf3da1 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Mon, 23 Mar 2020 14:48:38 -0700 Subject: [PATCH 13/28] Black --- pandas/tests/groupby/transform/test_numba.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py index 4f53036e08fe2..757281b6a6a24 100644 --- a/pandas/tests/groupby/transform/test_numba.py +++ b/pandas/tests/groupby/transform/test_numba.py @@ -65,13 +65,13 @@ def dataframe_func(values, index, columns): result = grouped.transform( dataframe_func, engine="numba", engine_kwargs=engine_kwargs ) - expected = grouped.transform(lambda x: x + 1, engine='cython') + expected = grouped.transform(lambda x: x + 1, engine="cython") tm.assert_frame_equal(result, expected) result = grouped[1].transform( series_func, engine="numba", engine_kwargs=engine_kwargs ) - expected = grouped[1].transform(lambda x: x + 1, engine='cython') + expected = grouped[1].transform(lambda x: x + 1, engine="cython") tm.assert_series_equal(result, expected) From a9ece86927f1abff5a3fd82f751e6016992a1d7e Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Thu, 26 Mar 2020 21:49:48 -0700 Subject: [PATCH 14/28] Add numba cache --- pandas/core/groupby/generic.py | 20 ++++- pandas/tests/groupby/transform/test_numba.py | 93 ++++++++++++++++++++ 2 files changed, 109 insertions(+), 4 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 498136ddf1a86..dc17a6cf248c9 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -161,6 +161,8 @@ def pinner(cls): class SeriesGroupBy(GroupBy): _apply_whitelist = base.series_apply_whitelist + _numba_func_cache = {} + def _iterate_slices(self) -> Iterable[Series]: yield self._selected_obj @@ -502,7 +504,9 @@ def _transform_general( nopython, nogil, parallel = get_jit_arguments(engine_kwargs) check_kwargs_and_nopython(kwargs, nopython) validate_udf(func) - func = jit_user_function(func, nopython, nogil, parallel) + numba_func = self._numba_func_cache.get( + func, jit_user_function(func, nopython, nogil, parallel) + ) klass = type(self._selected_obj) @@ -511,7 +515,9 @@ def _transform_general( object.__setattr__(group, "name", name) if engine == "numba": values, index, _ = split_for_numba(group) - res = func(values, index, *args) + res = numba_func(values, index, *args) + if func not in self._numba_func_cache: + self._numba_func_cache[func] = numba_func else: res = func(group, *args, **kwargs) @@ -841,6 +847,8 @@ class DataFrameGroupBy(GroupBy): _apply_whitelist = base.dataframe_apply_whitelist + _numba_func_cache = {} + _agg_see_also_doc = dedent( """ See Also @@ -1393,7 +1401,9 @@ def _transform_general( nopython, nogil, parallel = get_jit_arguments(engine_kwargs) check_kwargs_and_nopython(kwargs, nopython) validate_udf(func, include_columns=True) - func = jit_user_function(func, nopython, nogil, parallel) + numba_func = self._numba_func_cache.get( + func, jit_user_function(func, nopython, nogil, parallel) + ) else: fast_path, slow_path = self._define_paths(func, *args, **kwargs) @@ -1402,7 +1412,9 @@ def _transform_general( if engine == "numba": values, index, columns = split_for_numba(group) - res = func(values, index, columns, *args) + res = numba_func(values, index, columns, *args) + if func not in self._numba_func_cache: + self._numba_func_cache[func] = numba_func # Return the result as a DataFrame for concatenation later res = DataFrame(res, index=group.index, columns=group.columns) else: diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py index 757281b6a6a24..506a6065bfd19 100644 --- a/pandas/tests/groupby/transform/test_numba.py +++ b/pandas/tests/groupby/transform/test_numba.py @@ -75,3 +75,96 @@ def dataframe_func(values, index, columns): expected = grouped[1].transform(lambda x: x + 1, engine="cython") tm.assert_series_equal(result, expected) + + +@td.skip_if_no("numba", "0.46.0") +@pytest.mark.filterwarnings("ignore:\\nThe keyword argument") +# Filter warnings when parallel=True and the function can't be parallelized by Numba +@pytest.mark.parametrize("jit", [True, False]) +def test_cache_series(jit, nogil, parallel, nopython): + # Test that the functions are cached correctly if we switch functions + def series_func_1(values, index): + return values + 1 + + def series_func_2(values, index): + return values * 5 + + if jit: + import numba + + series_func_1 = numba.jit(series_func_1) + series_func_2 = numba.jit(series_func_2) + + data = DataFrame( + {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1], + ) + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + grouped = data.groupby(0) + + result = grouped[1].transform( + series_func_1, engine="numba", engine_kwargs=engine_kwargs + ) + expected = grouped[1].transform(lambda x: x + 1, engine="cython") + tm.assert_series_equal(result, expected) + # series_func_1 should be in the cache now + assert series_func_1 in grouped[1]._numba_func_cache + + # Add dataframe_func_2 to the cache + result = grouped[1].transform( + series_func_2, engine="numba", engine_kwargs=engine_kwargs + ) + expected = grouped[1].transform(lambda x: x * 5, engine="cython") + tm.assert_series_equal(result, expected) + + result = grouped[1].transform( + series_func_1, engine="numba", engine_kwargs=engine_kwargs + ) + expected = grouped[1].transform(lambda x: x + 1, engine="cython") + tm.assert_series_equal(result, expected) + + +@td.skip_if_no("numba", "0.46.0") +@pytest.mark.filterwarnings("ignore:\\nThe keyword argument") +# Filter warnings when parallel=True and the function can't be parallelized by Numba +@pytest.mark.parametrize("jit", [True, False]) +def test_cache_dataframe(jit, nogil, parallel, nopython): + # Test that the functions are cached correctly if we switch functions + def dataframe_func_1(values, index, columns): + return values + 1 + + def dataframe_func_2(values, index, columns): + return values * 5 + + if jit: + import numba + + dataframe_func_1 = numba.jit(dataframe_func_1) + dataframe_func_2 = numba.jit(dataframe_func_2) + + data = DataFrame( + {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1], + ) + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + grouped = data.groupby(0) + + result = grouped.transform( + dataframe_func_1, engine="numba", engine_kwargs=engine_kwargs + ) + expected = grouped.transform(lambda x: x + 1, engine="cython") + tm.assert_frame_equal(result, expected) + # dataframe_func_1 should be in the cache now + assert dataframe_func_1 in grouped._numba_func_cache + + # Add dataframe_func_2 to the cache + result = grouped.transform( + dataframe_func_2, engine="numba", engine_kwargs=engine_kwargs + ) + expected = grouped.transform(lambda x: x * 5, engine="cython") + tm.assert_frame_equal(result, expected) + + # This run should use the cached dataframe_func_1 + result = grouped.transform( + dataframe_func_1, engine="numba", engine_kwargs=engine_kwargs + ) + expected = grouped.transform(lambda x: x + 1, engine="cython") + tm.assert_frame_equal(result, expected) From 367dc124def8058639688274f2e5416d41497e04 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Thu, 26 Mar 2020 21:58:13 -0700 Subject: [PATCH 15/28] Add ASV bench --- asv_bench/benchmarks/groupby.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/asv_bench/benchmarks/groupby.py b/asv_bench/benchmarks/groupby.py index 28e0dcc5d9b13..ecd2ab18e1fff 100644 --- a/asv_bench/benchmarks/groupby.py +++ b/asv_bench/benchmarks/groupby.py @@ -626,4 +626,28 @@ def time_first(self): self.df_nans.groupby("key").transform("first") +class TransformEngine: + + params = ( + [np.sum, lambda x: np.sum(x) + 5], + ["cython", "numba"], + ) + param_names = ["function", "engine"] + + def setup(self, function, engine): + N = 10 ** 3 + arr = 100 * np.random.random(N) + data = DataFrame( + {0: ["a", "a", "b", "b", "a"] * N, 1: [1.0, 2.0, 3.0, 4.0, 5.0] * N}, + columns=[0, 1], + ) + self.grouper = data.groupby(0) + + def time_series(self, function, engine): + self.grouper[0].transform(function, engine=engine) + + def time_dataframe(self, function, engine): + self.grouper.transform(function, engine=engine) + + from .pandas_vb_common import setup # noqa: F401 isort:skip From 8256a0a5265298025c61686145d2ffb166b6fe75 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Thu, 26 Mar 2020 22:01:31 -0700 Subject: [PATCH 16/28] Add whatsnew enhancement entry --- doc/source/whatsnew/v1.1.0.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/source/whatsnew/v1.1.0.rst b/doc/source/whatsnew/v1.1.0.rst index 7920820b32620..93f1a6983b8ea 100644 --- a/doc/source/whatsnew/v1.1.0.rst +++ b/doc/source/whatsnew/v1.1.0.rst @@ -71,6 +71,7 @@ Other enhancements - Positional slicing on a :class:`IntervalIndex` now supports slices with ``step > 1`` (:issue:`31658`) - :class:`Series.str` now has a `fullmatch` method that matches a regular expression against the entire string in each row of the series, similar to `re.fullmatch` (:issue:`32806`). - :meth:`DataFrame.sample` will now also allow array-like and BitGenerator objects to be passed to ``random_state`` as seeds (:issue:`32503`) +- :meth:`~pandas.core.groupby.GroupBy.transform` has gained ``engine`` and ``engine_kwargs`` arguments that supports executing functions with ``Numba`` (:issue:`32854`) - .. --------------------------------------------------------------------------- From 2c5543ddbb13db4bb24094b42dcfc09b6e83b038 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Fri, 27 Mar 2020 11:18:44 -0700 Subject: [PATCH 17/28] Lint and add typing --- asv_bench/benchmarks/groupby.py | 1 - pandas/core/groupby/generic.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/asv_bench/benchmarks/groupby.py b/asv_bench/benchmarks/groupby.py index ecd2ab18e1fff..45cfd5df26081 100644 --- a/asv_bench/benchmarks/groupby.py +++ b/asv_bench/benchmarks/groupby.py @@ -636,7 +636,6 @@ class TransformEngine: def setup(self, function, engine): N = 10 ** 3 - arr = 100 * np.random.random(N) data = DataFrame( {0: ["a", "a", "b", "b", "a"] * N, 1: [1.0, 2.0, 3.0, 4.0, 5.0] * N}, columns=[0, 1], diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index dc17a6cf248c9..38551641f3756 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -161,7 +161,7 @@ def pinner(cls): class SeriesGroupBy(GroupBy): _apply_whitelist = base.series_apply_whitelist - _numba_func_cache = {} + _numba_func_cache: Dict[Callable, Callable] = {} def _iterate_slices(self) -> Iterable[Series]: yield self._selected_obj @@ -847,7 +847,7 @@ class DataFrameGroupBy(GroupBy): _apply_whitelist = base.dataframe_apply_whitelist - _numba_func_cache = {} + _numba_func_cache: Dict[Callable, Callable] = {} _agg_see_also_doc = dedent( """ From 9c4fa567748c39a6777a9a0a3e71d5da510da97e Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Sat, 28 Mar 2020 15:03:56 -0700 Subject: [PATCH 18/28] fix benchmark --- asv_bench/benchmarks/groupby.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/asv_bench/benchmarks/groupby.py b/asv_bench/benchmarks/groupby.py index 45cfd5df26081..8cabae0cf572f 100644 --- a/asv_bench/benchmarks/groupby.py +++ b/asv_bench/benchmarks/groupby.py @@ -628,13 +628,10 @@ def time_first(self): class TransformEngine: - params = ( - [np.sum, lambda x: np.sum(x) + 5], - ["cython", "numba"], - ) - param_names = ["function", "engine"] + params = ["cython", "numba"] + param_names = ["engine"] - def setup(self, function, engine): + def setup(self, engine): N = 10 ** 3 data = DataFrame( {0: ["a", "a", "b", "b", "a"] * N, 1: [1.0, 2.0, 3.0, 4.0, 5.0] * N}, @@ -642,10 +639,16 @@ def setup(self, function, engine): ) self.grouper = data.groupby(0) - def time_series(self, function, engine): + def time_series(self, engine): + def function(values, index): + return values * 5 + self.grouper[0].transform(function, engine=engine) - def time_dataframe(self, function, engine): + def time_dataframe(self, engine): + def function(values, index, columns): + return values * 5 + self.grouper.transform(function, engine=engine) From 1c71c9b54ce7de24546bbdad521e235425f52c9f Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Sat, 28 Mar 2020 17:58:33 -0700 Subject: [PATCH 19/28] Fix benchmarks again --- asv_bench/benchmarks/groupby.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/asv_bench/benchmarks/groupby.py b/asv_bench/benchmarks/groupby.py index 8cabae0cf572f..09721d561d738 100644 --- a/asv_bench/benchmarks/groupby.py +++ b/asv_bench/benchmarks/groupby.py @@ -639,17 +639,29 @@ def setup(self, engine): ) self.grouper = data.groupby(0) - def time_series(self, engine): + def time_series_numba(self, engine): def function(values, index): return values * 5 - self.grouper[0].transform(function, engine=engine) + self.grouper[0].transform(function, engine='numba') - def time_dataframe(self, engine): + def time_series_cython(self, engine): + def function(values): + return values * 5 + + self.grouper[0].transform(function, engine='cython') + + def time_dataframe_numba(self, engine): def function(values, index, columns): return values * 5 - self.grouper.transform(function, engine=engine) + self.grouper.transform(function, engine='numba') + + def time_dataframe_cython(self, engine): + def function(values): + return values * 5 + + self.grouper.transform(function, engine='cython') from .pandas_vb_common import setup # noqa: F401 isort:skip From 6c0a5737aacd58e5ab86095a2e1b401e6167c9f5 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Sat, 28 Mar 2020 18:47:44 -0700 Subject: [PATCH 20/28] black and fix benchmarks again --- asv_bench/benchmarks/groupby.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/asv_bench/benchmarks/groupby.py b/asv_bench/benchmarks/groupby.py index 09721d561d738..4ba1ce1440955 100644 --- a/asv_bench/benchmarks/groupby.py +++ b/asv_bench/benchmarks/groupby.py @@ -627,11 +627,7 @@ def time_first(self): class TransformEngine: - - params = ["cython", "numba"] - param_names = ["engine"] - - def setup(self, engine): + def setup(self): N = 10 ** 3 data = DataFrame( {0: ["a", "a", "b", "b", "a"] * N, 1: [1.0, 2.0, 3.0, 4.0, 5.0] * N}, @@ -639,29 +635,29 @@ def setup(self, engine): ) self.grouper = data.groupby(0) - def time_series_numba(self, engine): + def time_series_numba(self): def function(values, index): return values * 5 - self.grouper[0].transform(function, engine='numba') + self.grouper[1].transform(function, engine="numba") - def time_series_cython(self, engine): + def time_series_cython(self): def function(values): return values * 5 - self.grouper[0].transform(function, engine='cython') + self.grouper[1].transform(function, engine="cython") - def time_dataframe_numba(self, engine): + def time_dataframe_numba(self): def function(values, index, columns): return values * 5 - self.grouper.transform(function, engine='numba') + self.grouper.transform(function, engine="numba") - def time_dataframe_cython(self, engine): + def time_dataframe_cython(self): def function(values): return values * 5 - self.grouper.transform(function, engine='cython') + self.grouper.transform(function, engine="cython") from .pandas_vb_common import setup # noqa: F401 isort:skip From 1de2cf1ac4d6b94c0c69b90539ec9a6257b5a0fb Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Wed, 8 Apr 2020 09:44:42 -0700 Subject: [PATCH 21/28] Add more typing to numba utils --- pandas/core/util/numba_.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/pandas/core/util/numba_.py b/pandas/core/util/numba_.py index 320b579435738..36af1bae82434 100644 --- a/pandas/core/util/numba_.py +++ b/pandas/core/util/numba_.py @@ -1,7 +1,7 @@ """Common utilities for Numba operations""" import inspect import types -from typing import Callable, Dict, Optional +from typing import Callable, Dict, Optional, Tuple import numpy as np @@ -11,7 +11,7 @@ def check_kwargs_and_nopython( kwargs: Optional[Dict] = None, nopython: Optional[bool] = None -): +) -> None: if kwargs and nopython: raise ValueError( "numba does not support kwargs with nopython=True: " @@ -19,7 +19,9 @@ def check_kwargs_and_nopython( ) -def get_jit_arguments(engine_kwargs: Optional[Dict[str, bool]] = None): +def get_jit_arguments( + engine_kwargs: Optional[Dict[str, bool]] = None +) -> Tuple[bool, bool, bool]: """ Return arguments to pass to numba.JIT, falling back on pandas default JIT settings. """ @@ -32,7 +34,9 @@ def get_jit_arguments(engine_kwargs: Optional[Dict[str, bool]] = None): return nopython, nogil, parallel -def jit_user_function(func: Callable, nopython: bool, nogil: bool, parallel: bool): +def jit_user_function( + func: Callable, nopython: bool, nogil: bool, parallel: bool +) -> Callable: """ JIT the user's function given the configurable arguments. """ @@ -60,7 +64,7 @@ def impl(data, *_args): return numba_func -def split_for_numba(arg: FrameOrSeries): +def split_for_numba(arg: FrameOrSeries) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Split pandas object into its components as numpy arrays for numba functions. """ @@ -71,7 +75,7 @@ def split_for_numba(arg: FrameOrSeries): return arg.to_numpy(), arg.index.to_numpy(), columns_as_array -def validate_udf(func: Callable, include_columns: bool = False): +def validate_udf(func: Callable, include_columns: bool = False) -> None: """ Validate user defined function for ops when using Numba. From e984283d5e627ed6d5151c3087119f0757752b3a Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Sun, 12 Apr 2020 18:00:59 -0700 Subject: [PATCH 22/28] Add more docstrings --- pandas/core/util/numba_.py | 67 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/pandas/core/util/numba_.py b/pandas/core/util/numba_.py index 36af1bae82434..d215decfa7ba7 100644 --- a/pandas/core/util/numba_.py +++ b/pandas/core/util/numba_.py @@ -12,6 +12,25 @@ def check_kwargs_and_nopython( kwargs: Optional[Dict] = None, nopython: Optional[bool] = None ) -> None: + """ + Validate that **kwargs and nopython=True was passed + https://github.com/numba/numba/issues/2916 + + Parameters + ---------- + kwargs : dict, default None + user passed keyword arguments to pass into the JITed function + nopython : bool, default None + nopython parameter + + Returns + ------- + None + + Raises + ------ + ValueError + """ if kwargs and nopython: raise ValueError( "numba does not support kwargs with nopython=True: " @@ -24,6 +43,16 @@ def get_jit_arguments( ) -> Tuple[bool, bool, bool]: """ Return arguments to pass to numba.JIT, falling back on pandas default JIT settings. + + Parameters + ---------- + engine_kwargs : dict, default None + user passed keyword arguments for numba.JIT + + Returns + ------- + (bool, bool, bool) + nopython, nogil, parallel """ if engine_kwargs is None: engine_kwargs = {} @@ -39,6 +68,23 @@ def jit_user_function( ) -> Callable: """ JIT the user's function given the configurable arguments. + + Parameters + ---------- + func : function + user defined function + + nopython : bool + nopython parameter for numba.JIT + nogil : bool + nogil parameter for numba.JIT + parallel : bool + parallel parameter for numba.JIT + + Returns + ------- + function + Numba JITed function """ numba = import_optional_dependency("numba") @@ -67,6 +113,15 @@ def impl(data, *_args): def split_for_numba(arg: FrameOrSeries) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Split pandas object into its components as numpy arrays for numba functions. + + Parameters + ---------- + arg : Series or DataFrame + + Returns + ------- + (ndarray, ndarray, ndarray) + values, index, columns """ if getattr(arg, "columns", None) is not None: columns_as_array = arg.columns.to_numpy() @@ -89,6 +144,18 @@ def f(values, index, ...): def f(values, index, columns, ...): ... + + Parameters + ---------- + func : function, default False + user defined function + + include_columns : bool + whether 'columns' should be in the signature + + Returns + ------- + None """ udf_signature = list(inspect.signature(func).parameters.keys()) expected_args = ["values", "index"] From 909e92e5d6183c8358fea964b2c436c98ff561e1 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Sun, 12 Apr 2020 18:34:54 -0700 Subject: [PATCH 23/28] Have benchmark contain more groups --- asv_bench/benchmarks/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asv_bench/benchmarks/groupby.py b/asv_bench/benchmarks/groupby.py index 4ba1ce1440955..9957de7227d67 100644 --- a/asv_bench/benchmarks/groupby.py +++ b/asv_bench/benchmarks/groupby.py @@ -630,7 +630,7 @@ class TransformEngine: def setup(self): N = 10 ** 3 data = DataFrame( - {0: ["a", "a", "b", "b", "a"] * N, 1: [1.0, 2.0, 3.0, 4.0, 5.0] * N}, + {0: [str(i) for i in range(100)] * N, 1: list(range(100)) * N}, columns=[0, 1], ) self.grouper = data.groupby(0) From cd7a0beac5a8f22f298961453d11c26f0692d91c Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Mon, 13 Apr 2020 19:05:21 -0700 Subject: [PATCH 24/28] Remove columns as a required argument for udf --- asv_bench/benchmarks/groupby.py | 2 +- pandas/core/groupby/generic.py | 8 +++--- pandas/core/util/numba_.py | 27 +++++--------------- pandas/tests/groupby/transform/test_numba.py | 22 +++++----------- 4 files changed, 18 insertions(+), 41 deletions(-) diff --git a/asv_bench/benchmarks/groupby.py b/asv_bench/benchmarks/groupby.py index 9957de7227d67..eb637c78806c0 100644 --- a/asv_bench/benchmarks/groupby.py +++ b/asv_bench/benchmarks/groupby.py @@ -648,7 +648,7 @@ def function(values): self.grouper[1].transform(function, engine="cython") def time_dataframe_numba(self): - def function(values, index, columns): + def function(values, index): return values * 5 self.grouper.transform(function, engine="numba") diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 37ef2340caa8a..c007d4920cbe7 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -514,7 +514,7 @@ def _transform_general( for name, group in self: object.__setattr__(group, "name", name) if engine == "numba": - values, index, _ = split_for_numba(group) + values, index = split_for_numba(group) res = numba_func(values, index, *args) if func not in self._numba_func_cache: self._numba_func_cache[func] = numba_func @@ -1396,7 +1396,7 @@ def _transform_general( if engine == "numba": nopython, nogil, parallel = get_jit_arguments(engine_kwargs) check_kwargs_and_nopython(kwargs, nopython) - validate_udf(func, include_columns=True) + validate_udf(func) numba_func = self._numba_func_cache.get( func, jit_user_function(func, nopython, nogil, parallel) ) @@ -1407,8 +1407,8 @@ def _transform_general( object.__setattr__(group, "name", name) if engine == "numba": - values, index, columns = split_for_numba(group) - res = numba_func(values, index, columns, *args) + values, index = split_for_numba(group) + res = numba_func(values, index, *args) if func not in self._numba_func_cache: self._numba_func_cache[func] = numba_func # Return the result as a DataFrame for concatenation later diff --git a/pandas/core/util/numba_.py b/pandas/core/util/numba_.py index d215decfa7ba7..c5b27b937a05b 100644 --- a/pandas/core/util/numba_.py +++ b/pandas/core/util/numba_.py @@ -110,7 +110,7 @@ def impl(data, *_args): return numba_func -def split_for_numba(arg: FrameOrSeries) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: +def split_for_numba(arg: FrameOrSeries) -> Tuple[np.ndarray, np.ndarray]: """ Split pandas object into its components as numpy arrays for numba functions. @@ -120,47 +120,32 @@ def split_for_numba(arg: FrameOrSeries) -> Tuple[np.ndarray, np.ndarray, np.ndar Returns ------- - (ndarray, ndarray, ndarray) - values, index, columns + (ndarray, ndarray) + values, index """ - if getattr(arg, "columns", None) is not None: - columns_as_array = arg.columns.to_numpy() - else: - columns_as_array = None - return arg.to_numpy(), arg.index.to_numpy(), columns_as_array + return arg.to_numpy(), arg.index.to_numpy() -def validate_udf(func: Callable, include_columns: bool = False) -> None: +def validate_udf(func: Callable) -> None: """ Validate user defined function for ops when using Numba. - For routines that pass Series objects, the first signature arguments should include: + The first signature arguments should include: def f(values, index, ...): ... - For routines that pass DataFrame objects, the first signature arguments should - include: - - def f(values, index, columns, ...): - ... - Parameters ---------- func : function, default False user defined function - include_columns : bool - whether 'columns' should be in the signature - Returns ------- None """ udf_signature = list(inspect.signature(func).parameters.keys()) expected_args = ["values", "index"] - if include_columns: - expected_args.append("columns") min_number_args = len(expected_args) if ( len(udf_signature) < min_number_args diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py index 506a6065bfd19..ee0e0838ef464 100644 --- a/pandas/tests/groupby/transform/test_numba.py +++ b/pandas/tests/groupby/transform/test_numba.py @@ -15,7 +15,7 @@ def incorrect_function(x): {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=["key", "data"], ) - with pytest.raises(ValueError, match=f"The first 3"): + with pytest.raises(ValueError, match=f"The first 2"): data.groupby("key").transform(incorrect_function, engine="numba") with pytest.raises(ValueError, match=f"The first 2"): @@ -43,18 +43,14 @@ def incorrect_function(x, **kwargs): # Filter warnings when parallel=True and the function can't be parallelized by Numba @pytest.mark.parametrize("jit", [True, False]) def test_numba_vs_cython(jit, nogil, parallel, nopython): - def series_func(values, index): - return values + 1 - - def dataframe_func(values, index, columns): + def func(values, index): return values + 1 if jit: # Test accepted jitted functions import numba - series_func = numba.jit(series_func) - dataframe_func = numba.jit(dataframe_func) + func = numba.jit(func) data = DataFrame( {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1], @@ -62,16 +58,12 @@ def dataframe_func(values, index, columns): engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} grouped = data.groupby(0) - result = grouped.transform( - dataframe_func, engine="numba", engine_kwargs=engine_kwargs - ) + result = grouped.transform(func, engine="numba", engine_kwargs=engine_kwargs) expected = grouped.transform(lambda x: x + 1, engine="cython") tm.assert_frame_equal(result, expected) - result = grouped[1].transform( - series_func, engine="numba", engine_kwargs=engine_kwargs - ) + result = grouped[1].transform(func, engine="numba", engine_kwargs=engine_kwargs) expected = grouped[1].transform(lambda x: x + 1, engine="cython") tm.assert_series_equal(result, expected) @@ -129,10 +121,10 @@ def series_func_2(values, index): @pytest.mark.parametrize("jit", [True, False]) def test_cache_dataframe(jit, nogil, parallel, nopython): # Test that the functions are cached correctly if we switch functions - def dataframe_func_1(values, index, columns): + def dataframe_func_1(values, index): return values + 1 - def dataframe_func_2(values, index, columns): + def dataframe_func_2(values, index): return values * 5 if jit: From 930466a8e9161a94527f6bc33eec3496b1e553af Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Mon, 13 Apr 2020 21:07:44 -0700 Subject: [PATCH 25/28] Simplify tests --- pandas/tests/groupby/transform/test_numba.py | 96 +++++--------------- 1 file changed, 22 insertions(+), 74 deletions(-) diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py index ee0e0838ef464..2f35edd00a1d9 100644 --- a/pandas/tests/groupby/transform/test_numba.py +++ b/pandas/tests/groupby/transform/test_numba.py @@ -42,7 +42,8 @@ def incorrect_function(x, **kwargs): @pytest.mark.filterwarnings("ignore:\\nThe keyword argument") # Filter warnings when parallel=True and the function can't be parallelized by Numba @pytest.mark.parametrize("jit", [True, False]) -def test_numba_vs_cython(jit, nogil, parallel, nopython): +@pytest.mark.parametrize("pandas_obj", ["Series", "DataFrame"]) +def test_numba_vs_cython(jit, pandas_obj, nogil, parallel, nopython): def func(values, index): return values + 1 @@ -57,106 +58,53 @@ def func(values, index): ) engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} grouped = data.groupby(0) + if pandas_obj == "Series": + grouped = grouped[1] result = grouped.transform(func, engine="numba", engine_kwargs=engine_kwargs) expected = grouped.transform(lambda x: x + 1, engine="cython") - tm.assert_frame_equal(result, expected) - - result = grouped[1].transform(func, engine="numba", engine_kwargs=engine_kwargs) - expected = grouped[1].transform(lambda x: x + 1, engine="cython") - - tm.assert_series_equal(result, expected) - - -@td.skip_if_no("numba", "0.46.0") -@pytest.mark.filterwarnings("ignore:\\nThe keyword argument") -# Filter warnings when parallel=True and the function can't be parallelized by Numba -@pytest.mark.parametrize("jit", [True, False]) -def test_cache_series(jit, nogil, parallel, nopython): - # Test that the functions are cached correctly if we switch functions - def series_func_1(values, index): - return values + 1 - - def series_func_2(values, index): - return values * 5 - - if jit: - import numba - - series_func_1 = numba.jit(series_func_1) - series_func_2 = numba.jit(series_func_2) - - data = DataFrame( - {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1], - ) - engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} - grouped = data.groupby(0) - - result = grouped[1].transform( - series_func_1, engine="numba", engine_kwargs=engine_kwargs - ) - expected = grouped[1].transform(lambda x: x + 1, engine="cython") - tm.assert_series_equal(result, expected) - # series_func_1 should be in the cache now - assert series_func_1 in grouped[1]._numba_func_cache - - # Add dataframe_func_2 to the cache - result = grouped[1].transform( - series_func_2, engine="numba", engine_kwargs=engine_kwargs - ) - expected = grouped[1].transform(lambda x: x * 5, engine="cython") - tm.assert_series_equal(result, expected) - - result = grouped[1].transform( - series_func_1, engine="numba", engine_kwargs=engine_kwargs - ) - expected = grouped[1].transform(lambda x: x + 1, engine="cython") - tm.assert_series_equal(result, expected) + tm.assert_equal(result, expected) @td.skip_if_no("numba", "0.46.0") @pytest.mark.filterwarnings("ignore:\\nThe keyword argument") # Filter warnings when parallel=True and the function can't be parallelized by Numba @pytest.mark.parametrize("jit", [True, False]) -def test_cache_dataframe(jit, nogil, parallel, nopython): +@pytest.mark.parametrize("pandas_obj", ["Series", "DataFrame"]) +def test_cache(jit, pandas_obj, nogil, parallel, nopython): # Test that the functions are cached correctly if we switch functions - def dataframe_func_1(values, index): + def func_1(values, index): return values + 1 - def dataframe_func_2(values, index): + def func_2(values, index): return values * 5 if jit: import numba - dataframe_func_1 = numba.jit(dataframe_func_1) - dataframe_func_2 = numba.jit(dataframe_func_2) + func_1 = numba.jit(func_1) + func_2 = numba.jit(func_2) data = DataFrame( {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1], ) engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} grouped = data.groupby(0) + if pandas_obj == "Series": + grouped = grouped[1] - result = grouped.transform( - dataframe_func_1, engine="numba", engine_kwargs=engine_kwargs - ) + result = grouped.transform(func_1, engine="numba", engine_kwargs=engine_kwargs) expected = grouped.transform(lambda x: x + 1, engine="cython") - tm.assert_frame_equal(result, expected) - # dataframe_func_1 should be in the cache now - assert dataframe_func_1 in grouped._numba_func_cache + tm.assert_equal(result, expected) + # func_1 should be in the cache now + assert func_1 in grouped._numba_func_cache - # Add dataframe_func_2 to the cache - result = grouped.transform( - dataframe_func_2, engine="numba", engine_kwargs=engine_kwargs - ) + # Add func_2 to the cache + result = grouped.transform(func_2, engine="numba", engine_kwargs=engine_kwargs) expected = grouped.transform(lambda x: x * 5, engine="cython") - tm.assert_frame_equal(result, expected) + tm.assert_equal(result, expected) - # This run should use the cached dataframe_func_1 - result = grouped.transform( - dataframe_func_1, engine="numba", engine_kwargs=engine_kwargs - ) + result = grouped.transform(func_1, engine="numba", engine_kwargs=engine_kwargs) expected = grouped.transform(lambda x: x + 1, engine="cython") - tm.assert_frame_equal(result, expected) + tm.assert_equal(result, expected) From 145ca5081241b79cc5bdd197b74a3536ceb87e0b Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Mon, 13 Apr 2020 21:11:49 -0700 Subject: [PATCH 26/28] Add one more test and commentary --- pandas/tests/groupby/transform/test_numba.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py index 2f35edd00a1d9..96078d0aa3662 100644 --- a/pandas/tests/groupby/transform/test_numba.py +++ b/pandas/tests/groupby/transform/test_numba.py @@ -104,7 +104,9 @@ def func_2(values, index): result = grouped.transform(func_2, engine="numba", engine_kwargs=engine_kwargs) expected = grouped.transform(lambda x: x * 5, engine="cython") tm.assert_equal(result, expected) + assert func_2 in grouped._numba_func_cache + # Retest func_1 which should use the cache result = grouped.transform(func_1, engine="numba", engine_kwargs=engine_kwargs) expected = grouped.transform(lambda x: x + 1, engine="cython") tm.assert_equal(result, expected) From fc0654d3e34e872f5efdf94328bb91420552b93b Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Mon, 13 Apr 2020 21:30:06 -0700 Subject: [PATCH 27/28] Expand docstring --- pandas/core/groupby/groupby.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 873f24b9685e3..f0e7614effe98 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -254,7 +254,36 @@ class providing the base-class of operations. Parameters ---------- f : function - Function to apply to each group + Function to apply to each group. + + Can also accept a Numba JIT function with + ``engine='numba'`` specified. + + If the ``'numba'`` engine is chosen, the function must be + a user defined function with ``values`` and ``index`` as the + first and second arguments respectively in the function signature. + Each group's index will be passed to the user defined function + and optionally available for use. + + .. versionchanged:: 1.1.0 +*args + Positional arguments to pass to func +engine : str, default 'cython' + * ``'cython'`` : Runs the function through C-extensions from cython. + * ``'numba'`` : Runs the function through JIT compiled code from numba. + + .. versionadded:: 1.1.0 +engine_kwargs : dict, default None + * For ``'cython'`` engine, there are no accepted ``engine_kwargs`` + * For ``'numba'`` engine, the engine can accept ``nopython``, ``nogil`` + and ``parallel`` dictionary keys. The values must either be ``True`` or + ``False``. The default ``engine_kwargs`` for the ``'numba'`` engine is + ``{'nopython': True, 'nogil': False, 'parallel': False}`` and will be + applied to the function + + .. versionadded:: 1.1.0 +**kwargs + Keyword arguments to be passed into func. Returns ------- From 9dbded0ac715fbbee6c8405054c57f1caf2e0b70 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Tue, 14 Apr 2020 22:06:16 -0700 Subject: [PATCH 28/28] lint --- pandas/core/groupby/groupby.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index f0e7614effe98..154af3981a5ff 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -255,16 +255,16 @@ class providing the base-class of operations. ---------- f : function Function to apply to each group. - - Can also accept a Numba JIT function with + + Can also accept a Numba JIT function with ``engine='numba'`` specified. - + If the ``'numba'`` engine is chosen, the function must be - a user defined function with ``values`` and ``index`` as the + a user defined function with ``values`` and ``index`` as the first and second arguments respectively in the function signature. Each group's index will be passed to the user defined function and optionally available for use. - + .. versionchanged:: 1.1.0 *args Positional arguments to pass to func