From f112fc07ea0c52b9ccf25c1346dae87da1eb2046 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Mon, 7 Sep 2020 15:51:40 -0700 Subject: [PATCH 1/9] Add pathway for groupby transform --- pandas/core/groupby/generic.py | 12 ++++++ pandas/core/groupby/groupby.py | 29 ++++++++++++++ pandas/core/groupby/numba_.py | 69 +++++++++++++++++++++++++++++++++- 3 files changed, 108 insertions(+), 2 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 72003eab24b29..be5d3d648807f 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1362,6 +1362,18 @@ def _transform_general( @Appender(_transform_template) def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): + if maybe_use_numba(engine): + if not callable(func): + raise NotImplementedError( + "Numba engine can only be used with a single function." + ) + with _group_selection_context(self): + data = self._selected_obj + result, index = self._transform_with_numba( + data, func, *args, engine_kwargs=engine_kwargs, **kwargs + ) + return self.obj._constructor(result, index=index, columns=data.columns) + # optimized transforms func = self._get_cython_func(func) or func diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 6ef2e67030881..27dbdc2f3e75f 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1056,6 +1056,35 @@ def _cython_agg_general( return self._wrap_aggregated_output(output, index=self.grouper.result_index) + def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs): + """""" + group_keys = self.grouper._get_group_keys() + labels, _, n_groups = self.grouper.group_info + sorted_index = get_group_index_sorter(labels, n_groups) + sorted_labels = algorithms.take_nd(labels, sorted_index, allow_fill=False) + sorted_data = data.take(sorted_index, axis=self.axis).to_numpy() + starts, ends = lib.generate_slices(sorted_labels, n_groups) + cache_key = (func, "groupby_agg") + if cache_key in NUMBA_FUNC_CACHE: + # Return an already compiled version of roll_apply if available + numba_transform_func = NUMBA_FUNC_CACHE[cache_key] + else: + numba_transform_func = numba_.generate_numba_transform_func( + tuple(args), kwargs, func, engine_kwargs + ) + breakpoint() + result = numba_transform_func( + sorted_data, sorted_index, starts, ends, len(group_keys), len(data.columns) + ) + if cache_key not in NUMBA_FUNC_CACHE: + NUMBA_FUNC_CACHE[cache_key] = numba_transform_func + + if self.grouper.nkeys > 1: + index = MultiIndex.from_tuples(group_keys, names=self.grouper.names) + else: + index = Index(group_keys, name=self.grouper.names[0]) + return result, index + def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs): """ Perform groupby aggregation routine with the numba engine. 
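The _transform_with_numba routine added above follows the same splitting scheme as the existing _aggregate_with_numba: rows are sorted by group label, contiguous [start, end) slices are computed for each group, the jitted kernel is applied slice by slice, and (as the later patches in this series arrange) the output is put back into the caller's original row order. The following is a minimal NumPy sketch of that bookkeeping, not the pandas implementation; the helpers used in the diff (get_group_index_sorter, algorithms.take_nd, lib.generate_slices) are replaced with plain NumPy stand-ins, and the "* 5" kernel simply mirrors the ASV benchmark further down.

    import numpy as np

    # Illustrative stand-ins only; the real code uses pandas' internal helpers.
    labels = np.array([1, 0, 1, 0, 0])                 # group label of each row
    values = np.array([10.0, 20.0, 30.0, 40.0, 50.0])  # one column of data
    n_groups = 2

    # Sort rows so that each group's rows are contiguous.
    sorted_index = np.argsort(labels, kind="stable")
    sorted_labels = labels[sorted_index]
    sorted_values = values[sorted_index]

    # [start, end) offsets of each group within the sorted data
    # (stand-in for lib.generate_slices).
    starts = np.searchsorted(sorted_labels, np.arange(n_groups), side="left")
    ends = np.searchsorted(sorted_labels, np.arange(n_groups), side="right")

    # Apply the kernel slice by slice; transform writes one output value per input row.
    out = np.empty_like(sorted_values)
    for i in range(n_groups):
        group = sorted_values[starts[i]:ends[i]]
        out[starts[i]:ends[i]] = group * 5

    # Undo the sort so the output lines up with the original row order.
    result = out[np.argsort(sorted_index)]

In the actual engine the per-slice loop runs inside the numba-jitted group_transform closure generated below, and the user-supplied function receives each group's values and index as NumPy arrays (the (values, index) signature exercised in the tests and benchmarks later in this series); whatever it returns is assigned back into that group's slice of the result.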
diff --git a/pandas/core/groupby/numba_.py b/pandas/core/groupby/numba_.py index aebe60f797fcd..9b4faf8dcf577 100644 --- a/pandas/core/groupby/numba_.py +++ b/pandas/core/groupby/numba_.py @@ -153,7 +153,7 @@ def generate_numba_agg_func( loop_range = range @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) - def group_apply( + def group_agg( values: np.ndarray, index: np.ndarray, begin: np.ndarray, @@ -169,4 +169,69 @@ def group_apply( result[i, j] = numba_func(group, group_index, *args) return result - return group_apply + return group_agg + + +def generate_numba_transform_func( + args: Tuple, + kwargs: Dict[str, Any], + func: Callable[..., Scalar], + engine_kwargs: Optional[Dict[str, bool]], +) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int], np.ndarray]: + """ + Generate a numba jitted agg function specified by values from engine_kwargs. + + 1. jit the user's function + 2. Return a groupby agg function with the jitted function inline + + Configurations specified in engine_kwargs apply to both the user's + function _AND_ the rolling apply function. + + Parameters + ---------- + args : tuple + *args to be passed into the function + kwargs : dict + **kwargs to be passed into the function + func : function + function to be applied to each window and will be JITed + engine_kwargs : dict + dictionary of arguments to be passed into numba.jit + + Returns + ------- + Numba function + """ + nopython, nogil, parallel = get_jit_arguments(engine_kwargs) + + check_kwargs_and_nopython(kwargs, nopython) + + validate_udf(func) + + numba_func = jit_user_function(func, nopython, nogil, parallel) + + numba = import_optional_dependency("numba") + + if parallel: + loop_range = numba.prange + else: + loop_range = range + + @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) + def group_transform( + values: np.ndarray, + index: np.ndarray, + begin: np.ndarray, + end: np.ndarray, + num_groups: int, + num_columns: int, + ) -> np.ndarray: + result = np.empty((num_groups, num_columns)) + for i in loop_range(num_groups): + group_index = index[begin[i] : end[i]] + for j in loop_range(num_columns): + group = values[begin[i] : end[i], j] + result[begin[i]: end[i], j] = numba_func(group, group_index, *args) + return result + + return group_transform \ No newline at end of file From edda97d456610787fb7b4d3d3dc7cea87438d5b0 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Mon, 7 Sep 2020 22:35:05 -0700 Subject: [PATCH 2/9] Add path for groupby transform series --- pandas/core/groupby/generic.py | 13 +++++++++++++ pandas/core/groupby/groupby.py | 17 +++++++++-------- pandas/core/groupby/numba_.py | 6 +++--- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index be5d3d648807f..a16af1dfa3bf5 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -489,6 +489,19 @@ def _aggregate_named(self, func, *args, **kwargs): @Substitution(klass="Series") @Appender(_transform_template) def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): + + if maybe_use_numba(engine): + if not callable(func): + raise NotImplementedError( + "Numba engine can only be used with a single function." 
+ ) + with _group_selection_context(self): + data = self._selected_obj + result, index = self._aggregate_with_numba( + data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs + ) + return self.obj._constructor(result.ravel(), index=index, name=data.name) + func = self._get_cython_func(func) or func if not isinstance(func, str): diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 27dbdc2f3e75f..030caf60f9bb4 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1057,14 +1057,20 @@ def _cython_agg_general( return self._wrap_aggregated_output(output, index=self.grouper.result_index) def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs): - """""" + """ + Perform groupby transform routine with the numba engine. + + This routine mimics the data splitting routine of the DataSplitter class + to generate the indices of each group in the sorted data and then passes the + data and indices into a Numba jitted function. + """ group_keys = self.grouper._get_group_keys() labels, _, n_groups = self.grouper.group_info sorted_index = get_group_index_sorter(labels, n_groups) sorted_labels = algorithms.take_nd(labels, sorted_index, allow_fill=False) sorted_data = data.take(sorted_index, axis=self.axis).to_numpy() starts, ends = lib.generate_slices(sorted_labels, n_groups) - cache_key = (func, "groupby_agg") + cache_key = (func, "groupby_transform") if cache_key in NUMBA_FUNC_CACHE: # Return an already compiled version of roll_apply if available numba_transform_func = NUMBA_FUNC_CACHE[cache_key] @@ -1072,18 +1078,13 @@ def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs) numba_transform_func = numba_.generate_numba_transform_func( tuple(args), kwargs, func, engine_kwargs ) - breakpoint() result = numba_transform_func( sorted_data, sorted_index, starts, ends, len(group_keys), len(data.columns) ) if cache_key not in NUMBA_FUNC_CACHE: NUMBA_FUNC_CACHE[cache_key] = numba_transform_func - if self.grouper.nkeys > 1: - index = MultiIndex.from_tuples(group_keys, names=self.grouper.names) - else: - index = Index(group_keys, name=self.grouper.names[0]) - return result, index + return result, sorted_index def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs): """ diff --git a/pandas/core/groupby/numba_.py b/pandas/core/groupby/numba_.py index 9b4faf8dcf577..7727a826da9f2 100644 --- a/pandas/core/groupby/numba_.py +++ b/pandas/core/groupby/numba_.py @@ -226,12 +226,12 @@ def group_transform( num_groups: int, num_columns: int, ) -> np.ndarray: - result = np.empty((num_groups, num_columns)) + result = np.empty((len(values), num_columns)) for i in loop_range(num_groups): group_index = index[begin[i] : end[i]] for j in loop_range(num_columns): group = values[begin[i] : end[i], j] - result[begin[i]: end[i], j] = numba_func(group, group_index, *args) + result[begin[i] : end[i], j] = numba_func(group, group_index, *args) return result - return group_transform \ No newline at end of file + return group_transform From b1084808f9b47b247a9122fea607e2637f3f895c Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Mon, 7 Sep 2020 22:40:28 -0700 Subject: [PATCH 3/9] Roll back old groupby transform implementation --- pandas/core/groupby/generic.py | 51 ++++++++++------------------------ 1 file changed, 14 insertions(+), 37 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index a16af1dfa3bf5..4ea6f7841b179 100644 --- 
a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -497,7 +497,7 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): ) with _group_selection_context(self): data = self._selected_obj - result, index = self._aggregate_with_numba( + result, index = self._transform_with_numba( data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs ) return self.obj._constructor(result.ravel(), index=index, name=data.name) @@ -505,9 +505,7 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): func = self._get_cython_func(func) or func if not isinstance(func, str): - return self._transform_general( - func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs - ) + return self._transform_general(func, *args, **kwargs) elif func not in base.transform_kernel_allowlist: msg = f"'{func}' is not a valid function name for transform(name)" @@ -1304,42 +1302,25 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False): return self._reindex_output(result) - def _transform_general( - self, func, *args, engine="cython", engine_kwargs=None, **kwargs - ): + def _transform_general(self, func, *args, **kwargs): from pandas.core.reshape.concat import concat applied = [] obj = self._obj_with_exclusions gen = self.grouper.get_iterator(obj, axis=self.axis) - if maybe_use_numba(engine): - numba_func, cache_key = generate_numba_func( - func, engine_kwargs, kwargs, "groupby_transform" - ) - else: - fast_path, slow_path = self._define_paths(func, *args, **kwargs) + fast_path, slow_path = self._define_paths(func, *args, **kwargs) for name, group in gen: object.__setattr__(group, "name", name) - if maybe_use_numba(engine): - values, index = split_for_numba(group) - res = numba_func(values, index, *args) - if cache_key not in NUMBA_FUNC_CACHE: - NUMBA_FUNC_CACHE[cache_key] = numba_func - # Return the result as a DataFrame for concatenation later - res = self.obj._constructor( - res, index=group.index, columns=group.columns - ) - else: - # Try slow path and fast path. - try: - path, res = self._choose_path(fast_path, slow_path, group) - except TypeError: - return self._transform_item_by_item(obj, fast_path) - except ValueError as err: - msg = "transform must return a scalar value for each group" - raise ValueError(msg) from err + # Try slow path and fast path. 
+ try: + path, res = self._choose_path(fast_path, slow_path, group) + except TypeError: + return self._transform_item_by_item(obj, fast_path) + except ValueError as err: + msg = "transform must return a scalar value for each group" + raise ValueError(msg) from err if isinstance(res, Series): @@ -1391,9 +1372,7 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): func = self._get_cython_func(func) or func if not isinstance(func, str): - return self._transform_general( - func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs - ) + return self._transform_general(func, *args, **kwargs) elif func not in base.transform_kernel_allowlist: msg = f"'{func}' is not a valid function name for transform(name)" @@ -1419,9 +1398,7 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): ): return self._transform_fast(result) - return self._transform_general( - func, engine=engine, engine_kwargs=engine_kwargs, *args, **kwargs - ) + return self._transform_general(func, *args, **kwargs) def _transform_fast(self, result: DataFrame) -> DataFrame: """ From d92d2ad6c56285258e48a67ead57db01c33db35c Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Mon, 7 Sep 2020 23:59:32 -0700 Subject: [PATCH 4/9] Fix docstring and add whatsnew --- doc/source/whatsnew/v1.2.0.rst | 2 +- pandas/core/groupby/numba_.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/source/whatsnew/v1.2.0.rst b/doc/source/whatsnew/v1.2.0.rst index 2afa1f1a6199e..fcb5125881c8c 100644 --- a/doc/source/whatsnew/v1.2.0.rst +++ b/doc/source/whatsnew/v1.2.0.rst @@ -206,7 +206,7 @@ Performance improvements ~~~~~~~~~~~~~~~~~~~~~~~~ - Performance improvement in :meth:`GroupBy.agg` with the ``numba`` engine (:issue:`35759`) -- +- Performance improvement in :meth:`GroupBy.transform` with the ``numba`` engine (:issue:``) .. --------------------------------------------------------------------------- diff --git a/pandas/core/groupby/numba_.py b/pandas/core/groupby/numba_.py index 7727a826da9f2..a2dfcd7bddd53 100644 --- a/pandas/core/groupby/numba_.py +++ b/pandas/core/groupby/numba_.py @@ -179,7 +179,7 @@ def generate_numba_transform_func( engine_kwargs: Optional[Dict[str, bool]], ) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int], np.ndarray]: """ - Generate a numba jitted agg function specified by values from engine_kwargs. + Generate a numba jitted transform function specified by values from engine_kwargs. 1. jit the user's function 2. 
Return a groupby agg function with the jitted function inline From 4679501330a3b7100992e2076efcb1354635c312 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Tue, 8 Sep 2020 20:59:42 -0700 Subject: [PATCH 5/9] Fix resulting data, add test for multiple functions --- pandas/core/groupby/generic.py | 10 ++++++---- pandas/core/groupby/groupby.py | 6 +++--- pandas/tests/groupby/transform/test_numba.py | 16 ++++++++++++++++ 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 4ea6f7841b179..073f1f09d2edb 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -497,10 +497,12 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): ) with _group_selection_context(self): data = self._selected_obj - result, index = self._transform_with_numba( + result = self._transform_with_numba( data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs ) - return self.obj._constructor(result.ravel(), index=index, name=data.name) + return self.obj._constructor( + result.ravel(), index=data.index, name=data.name + ) func = self._get_cython_func(func) or func @@ -1363,10 +1365,10 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): ) with _group_selection_context(self): data = self._selected_obj - result, index = self._transform_with_numba( + result = self._transform_with_numba( data, func, *args, engine_kwargs=engine_kwargs, **kwargs ) - return self.obj._constructor(result, index=index, columns=data.columns) + return self.obj._constructor(result, index=data.index, columns=data.columns) # optimized transforms func = self._get_cython_func(func) or func diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 030caf60f9bb4..b537fc9842ac4 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1072,7 +1072,6 @@ def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs) starts, ends = lib.generate_slices(sorted_labels, n_groups) cache_key = (func, "groupby_transform") if cache_key in NUMBA_FUNC_CACHE: - # Return an already compiled version of roll_apply if available numba_transform_func = NUMBA_FUNC_CACHE[cache_key] else: numba_transform_func = numba_.generate_numba_transform_func( @@ -1084,7 +1083,9 @@ def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs) if cache_key not in NUMBA_FUNC_CACHE: NUMBA_FUNC_CACHE[cache_key] = numba_transform_func - return result, sorted_index + # result values needs to be resorted to their original positions since we + # evaluated the data sorted by group + return result.take(np.argsort(sorted_index), axis=0) def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs): """ @@ -1102,7 +1103,6 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs) starts, ends = lib.generate_slices(sorted_labels, n_groups) cache_key = (func, "groupby_agg") if cache_key in NUMBA_FUNC_CACHE: - # Return an already compiled version of roll_apply if available numba_agg_func = NUMBA_FUNC_CACHE[cache_key] else: numba_agg_func = numba_.generate_numba_agg_func( diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py index 87723cd7c8f50..fcaa5ab13599a 100644 --- a/pandas/tests/groupby/transform/test_numba.py +++ b/pandas/tests/groupby/transform/test_numba.py @@ -127,3 +127,19 @@ def func_1(values, index): with option_context("compute.use_numba", True): 
result = grouped.transform(func_1, engine=None) tm.assert_frame_equal(expected, result) + + +@td.skip_if_no("numba", "0.46.0") +@pytest.mark.parametrize( + "agg_func", [["min", "max"], "min", {"B": ["min", "max"], "C": "sum"}], +) +def test_multifunc_notimplimented(agg_func): + data = DataFrame( + {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1] + ) + grouped = data.groupby(0) + with pytest.raises(NotImplementedError, match="Numba engine can"): + grouped.transform(agg_func, engine="numba") + + with pytest.raises(NotImplementedError, match="Numba engine can"): + grouped[1].transform(agg_func, engine="numba") From 944e8e04a6796e227b8353cb04ef38d52537ca19 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Tue, 8 Sep 2020 21:04:11 -0700 Subject: [PATCH 6/9] Update issue number --- doc/source/whatsnew/v1.2.0.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v1.2.0.rst b/doc/source/whatsnew/v1.2.0.rst index fcb5125881c8c..781833fd775dc 100644 --- a/doc/source/whatsnew/v1.2.0.rst +++ b/doc/source/whatsnew/v1.2.0.rst @@ -206,7 +206,7 @@ Performance improvements ~~~~~~~~~~~~~~~~~~~~~~~~ - Performance improvement in :meth:`GroupBy.agg` with the ``numba`` engine (:issue:`35759`) -- Performance improvement in :meth:`GroupBy.transform` with the ``numba`` engine (:issue:``) +- Performance improvement in :meth:`GroupBy.transform` with the ``numba`` engine (:issue:`36240`) .. --------------------------------------------------------------------------- From f8e6fdbada9e01c77ccae5f8c78bae90d3b41904 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Wed, 9 Sep 2020 09:25:55 -0700 Subject: [PATCH 7/9] Fix linting error --- pandas/core/groupby/generic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 14ed20b51da0a..cb1a0593bc206 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -495,7 +495,7 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): raise NotImplementedError( "Numba engine can only be used with a single function." ) - with _group_selection_context(self): + with group_selection_context(self): data = self._selected_obj result = self._transform_with_numba( data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs @@ -1363,7 +1363,7 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): raise NotImplementedError( "Numba engine can only be used with a single function." 
) - with _group_selection_context(self): + with group_selection_context(self): data = self._selected_obj result = self._transform_with_numba( data, func, *args, engine_kwargs=engine_kwargs, **kwargs From 97fc5c9ef9971a75116de1784f321d8a6ce1cc96 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Sat, 12 Sep 2020 18:18:13 -0700 Subject: [PATCH 8/9] Move callable check to private method --- pandas/core/groupby/generic.py | 16 ---------------- pandas/core/groupby/groupby.py | 8 ++++++++ 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index cc6eb91e69179..ffd756bed43b6 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -226,10 +226,6 @@ def apply(self, func, *args, **kwargs): def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs): if maybe_use_numba(engine): - if not callable(func): - raise NotImplementedError( - "Numba engine can only be used with a single function." - ) with group_selection_context(self): data = self._selected_obj result, index = self._aggregate_with_numba( @@ -491,10 +487,6 @@ def _aggregate_named(self, func, *args, **kwargs): def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): if maybe_use_numba(engine): - if not callable(func): - raise NotImplementedError( - "Numba engine can only be used with a single function." - ) with group_selection_context(self): data = self._selected_obj result = self._transform_with_numba( @@ -951,10 +943,6 @@ class DataFrameGroupBy(GroupBy[DataFrame]): def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs): if maybe_use_numba(engine): - if not callable(func): - raise NotImplementedError( - "Numba engine can only be used with a single function." - ) with group_selection_context(self): data = self._selected_obj result, index = self._aggregate_with_numba( @@ -1358,10 +1346,6 @@ def _transform_general(self, func, *args, **kwargs): def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): if maybe_use_numba(engine): - if not callable(func): - raise NotImplementedError( - "Numba engine can only be used with a single function." - ) with group_selection_context(self): data = self._selected_obj result = self._transform_with_numba( diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index e34cd8614904c..30bd53a3ddff1 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1064,6 +1064,10 @@ def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs) to generate the indices of each group in the sorted data and then passes the data and indices into a Numba jitted function. """ + if not callable(func): + raise NotImplementedError( + "Numba engine can only be used with a single function." + ) group_keys = self.grouper._get_group_keys() labels, _, n_groups = self.grouper.group_info sorted_index = get_group_index_sorter(labels, n_groups) @@ -1095,6 +1099,10 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs) to generate the indices of each group in the sorted data and then passes the data and indices into a Numba jitted function. """ + if not callable(func): + raise NotImplementedError( + "Numba engine can only be used with a single function." 
+ ) group_keys = self.grouper._get_group_keys() labels, _, n_groups = self.grouper.group_info sorted_index = get_group_index_sorter(labels, n_groups) From 0ae610155afb28ccf740128fcbc9aa3860993bc7 Mon Sep 17 00:00:00 2001 From: Matt Roeschke Date: Sat, 12 Sep 2020 19:06:05 -0700 Subject: [PATCH 9/9] enhance benchmarks for cython engine --- asv_bench/benchmarks/groupby.py | 46 +++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/asv_bench/benchmarks/groupby.py b/asv_bench/benchmarks/groupby.py index 5ffda03fad80f..bda3ab71d1a00 100644 --- a/asv_bench/benchmarks/groupby.py +++ b/asv_bench/benchmarks/groupby.py @@ -627,33 +627,42 @@ def time_first(self): class TransformEngine: - def setup(self): + + param_names = ["parallel"] + params = [[True, False]] + + def setup(self, parallel): N = 10 ** 3 data = DataFrame( {0: [str(i) for i in range(100)] * N, 1: list(range(100)) * N}, columns=[0, 1], ) + self.parallel = parallel self.grouper = data.groupby(0) - def time_series_numba(self): + def time_series_numba(self, parallel): def function(values, index): return values * 5 - self.grouper[1].transform(function, engine="numba") + self.grouper[1].transform( + function, engine="numba", engine_kwargs={"parallel": self.parallel} + ) - def time_series_cython(self): + def time_series_cython(self, parallel): def function(values): return values * 5 self.grouper[1].transform(function, engine="cython") - def time_dataframe_numba(self): + def time_dataframe_numba(self, parallel): def function(values, index): return values * 5 - self.grouper.transform(function, engine="numba") + self.grouper.transform( + function, engine="numba", engine_kwargs={"parallel": self.parallel} + ) - def time_dataframe_cython(self): + def time_dataframe_cython(self, parallel): def function(values): return values * 5 @@ -661,15 +670,20 @@ def function(values): class AggEngine: - def setup(self): + + param_names = ["parallel"] + params = [[True, False]] + + def setup(self, parallel): N = 10 ** 3 data = DataFrame( {0: [str(i) for i in range(100)] * N, 1: list(range(100)) * N}, columns=[0, 1], ) + self.parallel = parallel self.grouper = data.groupby(0) - def time_series_numba(self): + def time_series_numba(self, parallel): def function(values, index): total = 0 for i, value in enumerate(values): @@ -679,9 +693,11 @@ def function(values, index): total += value * 2 return total - self.grouper[1].agg(function, engine="numba") + self.grouper[1].agg( + function, engine="numba", engine_kwargs={"parallel": self.parallel} + ) - def time_series_cython(self): + def time_series_cython(self, parallel): def function(values): total = 0 for i, value in enumerate(values): @@ -693,7 +709,7 @@ def function(values): self.grouper[1].agg(function, engine="cython") - def time_dataframe_numba(self): + def time_dataframe_numba(self, parallel): def function(values, index): total = 0 for i, value in enumerate(values): @@ -703,9 +719,11 @@ def function(values, index): total += value * 2 return total - self.grouper.agg(function, engine="numba") + self.grouper.agg( + function, engine="numba", engine_kwargs={"parallel": self.parallel} + ) - def time_dataframe_cython(self): + def time_dataframe_cython(self, parallel): def function(values): total = 0 for i, value in enumerate(values):