diff --git a/asv_bench/benchmarks/groupby.py b/asv_bench/benchmarks/groupby.py index 5ffda03fad80f..bda3ab71d1a00 100644 --- a/asv_bench/benchmarks/groupby.py +++ b/asv_bench/benchmarks/groupby.py @@ -627,33 +627,42 @@ def time_first(self): class TransformEngine: - def setup(self): + + param_names = ["parallel"] + params = [[True, False]] + + def setup(self, parallel): N = 10 ** 3 data = DataFrame( {0: [str(i) for i in range(100)] * N, 1: list(range(100)) * N}, columns=[0, 1], ) + self.parallel = parallel self.grouper = data.groupby(0) - def time_series_numba(self): + def time_series_numba(self, parallel): def function(values, index): return values * 5 - self.grouper[1].transform(function, engine="numba") + self.grouper[1].transform( + function, engine="numba", engine_kwargs={"parallel": self.parallel} + ) - def time_series_cython(self): + def time_series_cython(self, parallel): def function(values): return values * 5 self.grouper[1].transform(function, engine="cython") - def time_dataframe_numba(self): + def time_dataframe_numba(self, parallel): def function(values, index): return values * 5 - self.grouper.transform(function, engine="numba") + self.grouper.transform( + function, engine="numba", engine_kwargs={"parallel": self.parallel} + ) - def time_dataframe_cython(self): + def time_dataframe_cython(self, parallel): def function(values): return values * 5 @@ -661,15 +670,20 @@ def function(values): class AggEngine: - def setup(self): + + param_names = ["parallel"] + params = [[True, False]] + + def setup(self, parallel): N = 10 ** 3 data = DataFrame( {0: [str(i) for i in range(100)] * N, 1: list(range(100)) * N}, columns=[0, 1], ) + self.parallel = parallel self.grouper = data.groupby(0) - def time_series_numba(self): + def time_series_numba(self, parallel): def function(values, index): total = 0 for i, value in enumerate(values): @@ -679,9 +693,11 @@ def function(values, index): total += value * 2 return total - self.grouper[1].agg(function, engine="numba") + self.grouper[1].agg( + function, engine="numba", engine_kwargs={"parallel": self.parallel} + ) - def time_series_cython(self): + def time_series_cython(self, parallel): def function(values): total = 0 for i, value in enumerate(values): @@ -693,7 +709,7 @@ def function(values): self.grouper[1].agg(function, engine="cython") - def time_dataframe_numba(self): + def time_dataframe_numba(self, parallel): def function(values, index): total = 0 for i, value in enumerate(values): @@ -703,9 +719,11 @@ def function(values, index): total += value * 2 return total - self.grouper.agg(function, engine="numba") + self.grouper.agg( + function, engine="numba", engine_kwargs={"parallel": self.parallel} + ) - def time_dataframe_cython(self): + def time_dataframe_cython(self, parallel): def function(values): total = 0 for i, value in enumerate(values): diff --git a/doc/source/whatsnew/v1.2.0.rst b/doc/source/whatsnew/v1.2.0.rst index e577a8f26bd12..c2bfb2f72326a 100644 --- a/doc/source/whatsnew/v1.2.0.rst +++ b/doc/source/whatsnew/v1.2.0.rst @@ -207,7 +207,7 @@ Performance improvements - Performance improvements when creating Series with dtype `str` or :class:`StringDtype` from array with many string elements (:issue:`36304`) - Performance improvement in :meth:`GroupBy.agg` with the ``numba`` engine (:issue:`35759`) -- +- Performance improvement in :meth:`GroupBy.transform` with the ``numba`` engine (:issue:`36240`) .. --------------------------------------------------------------------------- diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 1552256468ad2..ffd756bed43b6 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -226,10 +226,6 @@ def apply(self, func, *args, **kwargs): def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs): if maybe_use_numba(engine): - if not callable(func): - raise NotImplementedError( - "Numba engine can only be used with a single function." - ) with group_selection_context(self): data = self._selected_obj result, index = self._aggregate_with_numba( @@ -489,12 +485,21 @@ def _aggregate_named(self, func, *args, **kwargs): @Substitution(klass="Series") @Appender(_transform_template) def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): + + if maybe_use_numba(engine): + with group_selection_context(self): + data = self._selected_obj + result = self._transform_with_numba( + data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs + ) + return self.obj._constructor( + result.ravel(), index=data.index, name=data.name + ) + func = self._get_cython_func(func) or func if not isinstance(func, str): - return self._transform_general( - func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs - ) + return self._transform_general(func, *args, **kwargs) elif func not in base.transform_kernel_allowlist: msg = f"'{func}' is not a valid function name for transform(name)" @@ -938,10 +943,6 @@ class DataFrameGroupBy(GroupBy[DataFrame]): def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs): if maybe_use_numba(engine): - if not callable(func): - raise NotImplementedError( - "Numba engine can only be used with a single function." - ) with group_selection_context(self): data = self._selected_obj result, index = self._aggregate_with_numba( @@ -1290,42 +1291,25 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False): return self._reindex_output(result) - def _transform_general( - self, func, *args, engine="cython", engine_kwargs=None, **kwargs - ): + def _transform_general(self, func, *args, **kwargs): from pandas.core.reshape.concat import concat applied = [] obj = self._obj_with_exclusions gen = self.grouper.get_iterator(obj, axis=self.axis) - if maybe_use_numba(engine): - numba_func, cache_key = generate_numba_func( - func, engine_kwargs, kwargs, "groupby_transform" - ) - else: - fast_path, slow_path = self._define_paths(func, *args, **kwargs) + fast_path, slow_path = self._define_paths(func, *args, **kwargs) for name, group in gen: object.__setattr__(group, "name", name) - if maybe_use_numba(engine): - values, index = split_for_numba(group) - res = numba_func(values, index, *args) - if cache_key not in NUMBA_FUNC_CACHE: - NUMBA_FUNC_CACHE[cache_key] = numba_func - # Return the result as a DataFrame for concatenation later - res = self.obj._constructor( - res, index=group.index, columns=group.columns - ) - else: - # Try slow path and fast path. - try: - path, res = self._choose_path(fast_path, slow_path, group) - except TypeError: - return self._transform_item_by_item(obj, fast_path) - except ValueError as err: - msg = "transform must return a scalar value for each group" - raise ValueError(msg) from err + # Try slow path and fast path. + try: + path, res = self._choose_path(fast_path, slow_path, group) + except TypeError: + return self._transform_item_by_item(obj, fast_path) + except ValueError as err: + msg = "transform must return a scalar value for each group" + raise ValueError(msg) from err if isinstance(res, Series): @@ -1361,13 +1345,19 @@ def _transform_general( @Appender(_transform_template) def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): + if maybe_use_numba(engine): + with group_selection_context(self): + data = self._selected_obj + result = self._transform_with_numba( + data, func, *args, engine_kwargs=engine_kwargs, **kwargs + ) + return self.obj._constructor(result, index=data.index, columns=data.columns) + # optimized transforms func = self._get_cython_func(func) or func if not isinstance(func, str): - return self._transform_general( - func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs - ) + return self._transform_general(func, *args, **kwargs) elif func not in base.transform_kernel_allowlist: msg = f"'{func}' is not a valid function name for transform(name)" @@ -1393,9 +1383,7 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs): ): return self._transform_fast(result) - return self._transform_general( - func, engine=engine, engine_kwargs=engine_kwargs, *args, **kwargs - ) + return self._transform_general(func, *args, **kwargs) def _transform_fast(self, result: DataFrame) -> DataFrame: """ diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 8a55d438cf8d4..30bd53a3ddff1 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1056,6 +1056,41 @@ def _cython_agg_general( return self._wrap_aggregated_output(output, index=self.grouper.result_index) + def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs): + """ + Perform groupby transform routine with the numba engine. + + This routine mimics the data splitting routine of the DataSplitter class + to generate the indices of each group in the sorted data and then passes the + data and indices into a Numba jitted function. + """ + if not callable(func): + raise NotImplementedError( + "Numba engine can only be used with a single function." + ) + group_keys = self.grouper._get_group_keys() + labels, _, n_groups = self.grouper.group_info + sorted_index = get_group_index_sorter(labels, n_groups) + sorted_labels = algorithms.take_nd(labels, sorted_index, allow_fill=False) + sorted_data = data.take(sorted_index, axis=self.axis).to_numpy() + starts, ends = lib.generate_slices(sorted_labels, n_groups) + cache_key = (func, "groupby_transform") + if cache_key in NUMBA_FUNC_CACHE: + numba_transform_func = NUMBA_FUNC_CACHE[cache_key] + else: + numba_transform_func = numba_.generate_numba_transform_func( + tuple(args), kwargs, func, engine_kwargs + ) + result = numba_transform_func( + sorted_data, sorted_index, starts, ends, len(group_keys), len(data.columns) + ) + if cache_key not in NUMBA_FUNC_CACHE: + NUMBA_FUNC_CACHE[cache_key] = numba_transform_func + + # result values needs to be resorted to their original positions since we + # evaluated the data sorted by group + return result.take(np.argsort(sorted_index), axis=0) + def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs): """ Perform groupby aggregation routine with the numba engine. @@ -1064,6 +1099,10 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs) to generate the indices of each group in the sorted data and then passes the data and indices into a Numba jitted function. """ + if not callable(func): + raise NotImplementedError( + "Numba engine can only be used with a single function." + ) group_keys = self.grouper._get_group_keys() labels, _, n_groups = self.grouper.group_info sorted_index = get_group_index_sorter(labels, n_groups) @@ -1072,7 +1111,6 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs) starts, ends = lib.generate_slices(sorted_labels, n_groups) cache_key = (func, "groupby_agg") if cache_key in NUMBA_FUNC_CACHE: - # Return an already compiled version of roll_apply if available numba_agg_func = NUMBA_FUNC_CACHE[cache_key] else: numba_agg_func = numba_.generate_numba_agg_func( diff --git a/pandas/core/groupby/numba_.py b/pandas/core/groupby/numba_.py index aebe60f797fcd..a2dfcd7bddd53 100644 --- a/pandas/core/groupby/numba_.py +++ b/pandas/core/groupby/numba_.py @@ -153,7 +153,7 @@ def generate_numba_agg_func( loop_range = range @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) - def group_apply( + def group_agg( values: np.ndarray, index: np.ndarray, begin: np.ndarray, @@ -169,4 +169,69 @@ def group_apply( result[i, j] = numba_func(group, group_index, *args) return result - return group_apply + return group_agg + + +def generate_numba_transform_func( + args: Tuple, + kwargs: Dict[str, Any], + func: Callable[..., Scalar], + engine_kwargs: Optional[Dict[str, bool]], +) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int], np.ndarray]: + """ + Generate a numba jitted transform function specified by values from engine_kwargs. + + 1. jit the user's function + 2. Return a groupby agg function with the jitted function inline + + Configurations specified in engine_kwargs apply to both the user's + function _AND_ the rolling apply function. + + Parameters + ---------- + args : tuple + *args to be passed into the function + kwargs : dict + **kwargs to be passed into the function + func : function + function to be applied to each window and will be JITed + engine_kwargs : dict + dictionary of arguments to be passed into numba.jit + + Returns + ------- + Numba function + """ + nopython, nogil, parallel = get_jit_arguments(engine_kwargs) + + check_kwargs_and_nopython(kwargs, nopython) + + validate_udf(func) + + numba_func = jit_user_function(func, nopython, nogil, parallel) + + numba = import_optional_dependency("numba") + + if parallel: + loop_range = numba.prange + else: + loop_range = range + + @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) + def group_transform( + values: np.ndarray, + index: np.ndarray, + begin: np.ndarray, + end: np.ndarray, + num_groups: int, + num_columns: int, + ) -> np.ndarray: + result = np.empty((len(values), num_columns)) + for i in loop_range(num_groups): + group_index = index[begin[i] : end[i]] + for j in loop_range(num_columns): + group = values[begin[i] : end[i], j] + result[begin[i] : end[i], j] = numba_func(group, group_index, *args) + return result + + return group_transform diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py index 87723cd7c8f50..fcaa5ab13599a 100644 --- a/pandas/tests/groupby/transform/test_numba.py +++ b/pandas/tests/groupby/transform/test_numba.py @@ -127,3 +127,19 @@ def func_1(values, index): with option_context("compute.use_numba", True): result = grouped.transform(func_1, engine=None) tm.assert_frame_equal(expected, result) + + +@td.skip_if_no("numba", "0.46.0") +@pytest.mark.parametrize( + "agg_func", [["min", "max"], "min", {"B": ["min", "max"], "C": "sum"}], +) +def test_multifunc_notimplimented(agg_func): + data = DataFrame( + {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1] + ) + grouped = data.groupby(0) + with pytest.raises(NotImplementedError, match="Numba engine can"): + grouped.transform(agg_func, engine="numba") + + with pytest.raises(NotImplementedError, match="Numba engine can"): + grouped[1].transform(agg_func, engine="numba")