From efb840e4038e08bb0a602bf34b20a989aebfa644 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Mon, 24 May 2021 21:56:11 -0700 Subject: [PATCH] BUG: groupby.transform/agg caching *args with numba engine --- doc/source/whatsnew/v1.3.0.rst | 1 + pandas/core/groupby/groupby.py | 22 ++++++++++++++------ pandas/core/groupby/numba_.py | 16 +++++++------- pandas/tests/groupby/aggregate/test_numba.py | 19 +++++++++++++++++ pandas/tests/groupby/transform/test_numba.py | 18 ++++++++++++++++ 5 files changed, 62 insertions(+), 14 deletions(-) diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index 258e391b9220c..6d1a6a4e96b33 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -988,6 +988,7 @@ Groupby/resample/rolling - Bug in :meth:`DataFrameGroupBy.__getitem__` with non-unique columns incorrectly returning a malformed :class:`SeriesGroupBy` instead of :class:`DataFrameGroupBy` (:issue:`41427`) - Bug in :meth:`DataFrameGroupBy.transform` with non-unique columns incorrectly raising ``AttributeError`` (:issue:`41427`) - Bug in :meth:`Resampler.apply` with non-unique columns incorrectly dropping duplicated columns (:issue:`41445`) +- Bug in :meth:`DataFrameGroupBy.transform` and :meth:`DataFrameGroupBy.agg` with ``engine="numba"`` where ``*args`` were being cached with the user passed function (:issue:`41647`) Reshaping ^^^^^^^^^ diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index b27eb4bb8f325..1c0a3dcc1e1db 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1131,10 +1131,16 @@ def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs) group_keys = self.grouper._get_group_keys() numba_transform_func = numba_.generate_numba_transform_func( - tuple(args), kwargs, func, engine_kwargs + kwargs, func, engine_kwargs ) result = numba_transform_func( - sorted_data, sorted_index, starts, ends, len(group_keys), len(data.columns) + sorted_data, + sorted_index, + starts, + ends, + len(group_keys), + len(data.columns), + *args, ) cache_key = (func, "groupby_transform") @@ -1157,11 +1163,15 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs) starts, ends, sorted_index, sorted_data = self._numba_prep(func, data) group_keys = self.grouper._get_group_keys() - numba_agg_func = numba_.generate_numba_agg_func( - tuple(args), kwargs, func, engine_kwargs - ) + numba_agg_func = numba_.generate_numba_agg_func(kwargs, func, engine_kwargs) result = numba_agg_func( - sorted_data, sorted_index, starts, ends, len(group_keys), len(data.columns) + sorted_data, + sorted_index, + starts, + ends, + len(group_keys), + len(data.columns), + *args, ) cache_key = (func, "groupby_agg") diff --git a/pandas/core/groupby/numba_.py b/pandas/core/groupby/numba_.py index 26070fcb5e89c..ad78280c5d835 100644 --- a/pandas/core/groupby/numba_.py +++ b/pandas/core/groupby/numba_.py @@ -56,11 +56,12 @@ def f(values, index, ...): def generate_numba_agg_func( - args: tuple, kwargs: dict[str, Any], func: Callable[..., Scalar], engine_kwargs: dict[str, bool] | None, -) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int], np.ndarray]: +) -> Callable[ + [np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int, Any], np.ndarray +]: """ Generate a numba jitted agg function specified by values from engine_kwargs. @@ -72,8 +73,6 @@ def generate_numba_agg_func( Parameters ---------- - args : tuple - *args to be passed into the function kwargs : dict **kwargs to be passed into the function func : function @@ -103,6 +102,7 @@ def group_agg( end: np.ndarray, num_groups: int, num_columns: int, + *args: Any, ) -> np.ndarray: result = np.empty((num_groups, num_columns)) for i in numba.prange(num_groups): @@ -116,11 +116,12 @@ def group_agg( def generate_numba_transform_func( - args: tuple, kwargs: dict[str, Any], func: Callable[..., np.ndarray], engine_kwargs: dict[str, bool] | None, -) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int], np.ndarray]: +) -> Callable[ + [np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int, Any], np.ndarray +]: """ Generate a numba jitted transform function specified by values from engine_kwargs. @@ -132,8 +133,6 @@ def generate_numba_transform_func( Parameters ---------- - args : tuple - *args to be passed into the function kwargs : dict **kwargs to be passed into the function func : function @@ -163,6 +162,7 @@ def group_transform( end: np.ndarray, num_groups: int, num_columns: int, + *args: Any, ) -> np.ndarray: result = np.empty((len(values), num_columns)) for i in numba.prange(num_groups): diff --git a/pandas/tests/groupby/aggregate/test_numba.py b/pandas/tests/groupby/aggregate/test_numba.py index 6de81d03ca418..ba2d6eeb287c0 100644 --- a/pandas/tests/groupby/aggregate/test_numba.py +++ b/pandas/tests/groupby/aggregate/test_numba.py @@ -6,7 +6,9 @@ from pandas import ( DataFrame, + Index, NamedAgg, + Series, option_context, ) import pandas._testing as tm @@ -154,3 +156,20 @@ def test_multifunc_notimplimented(agg_func): with pytest.raises(NotImplementedError, match="Numba engine can"): grouped[1].agg(agg_func, engine="numba") + + +@td.skip_if_no("numba", "0.46.0") +def test_args_not_cached(): + # GH 41647 + def sum_last(values, index, n): + return values[-n:].sum() + + df = DataFrame({"id": [0, 0, 1, 1], "x": [1, 1, 1, 1]}) + grouped_x = df.groupby("id")["x"] + result = grouped_x.agg(sum_last, 1, engine="numba") + expected = Series([1.0] * 2, name="x", index=Index([0, 1], name="id")) + tm.assert_series_equal(result, expected) + + result = grouped_x.agg(sum_last, 2, engine="numba") + expected = Series([2.0] * 2, name="x", index=Index([0, 1], name="id")) + tm.assert_series_equal(result, expected) diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py index fbee2361b9b45..8019071be72f3 100644 --- a/pandas/tests/groupby/transform/test_numba.py +++ b/pandas/tests/groupby/transform/test_numba.py @@ -5,6 +5,7 @@ from pandas import ( DataFrame, + Series, option_context, ) import pandas._testing as tm @@ -146,3 +147,20 @@ def test_multifunc_notimplimented(agg_func): with pytest.raises(NotImplementedError, match="Numba engine can"): grouped[1].transform(agg_func, engine="numba") + + +@td.skip_if_no("numba", "0.46.0") +def test_args_not_cached(): + # GH 41647 + def sum_last(values, index, n): + return values[-n:].sum() + + df = DataFrame({"id": [0, 0, 1, 1], "x": [1, 1, 1, 1]}) + grouped_x = df.groupby("id")["x"] + result = grouped_x.transform(sum_last, 1, engine="numba") + expected = Series([1.0] * 4, name="x") + tm.assert_series_equal(result, expected) + + result = grouped_x.transform(sum_last, 2, engine="numba") + expected = Series([2.0] * 4, name="x") + tm.assert_series_equal(result, expected)