Skip to content

Commit 4c0a490

Browse files
mroeschke authored and TLouf committed
BUG: groupby.transform/agg caching *args with numba engine (pandas-dev#41656)
1 parent d39fc3b commit 4c0a490

File tree

5 files changed

+62
-14
lines changed

5 files changed

+62
-14
lines changed

doc/source/whatsnew/v1.3.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -994,6 +994,7 @@ Groupby/resample/rolling
994994
- Bug in :meth:`DataFrameGroupBy.__getitem__` with non-unique columns incorrectly returning a malformed :class:`SeriesGroupBy` instead of :class:`DataFrameGroupBy` (:issue:`41427`)
995995
- Bug in :meth:`DataFrameGroupBy.transform` with non-unique columns incorrectly raising ``AttributeError`` (:issue:`41427`)
996996
- Bug in :meth:`Resampler.apply` with non-unique columns incorrectly dropping duplicated columns (:issue:`41445`)
997+
- Bug in :meth:`DataFrameGroupBy.transform` and :meth:`DataFrameGroupBy.agg` with ``engine="numba"`` where ``*args`` were being cached with the user passed function (:issue:`41647`)
997998

998999
Reshaping
9991000
^^^^^^^^^

pandas/core/groupby/groupby.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -1131,10 +1131,16 @@ def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
11311131
group_keys = self.grouper._get_group_keys()
11321132

11331133
numba_transform_func = numba_.generate_numba_transform_func(
1134-
tuple(args), kwargs, func, engine_kwargs
1134+
kwargs, func, engine_kwargs
11351135
)
11361136
result = numba_transform_func(
1137-
sorted_data, sorted_index, starts, ends, len(group_keys), len(data.columns)
1137+
sorted_data,
1138+
sorted_index,
1139+
starts,
1140+
ends,
1141+
len(group_keys),
1142+
len(data.columns),
1143+
*args,
11381144
)
11391145

11401146
cache_key = (func, "groupby_transform")
@@ -1157,11 +1163,15 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
11571163
starts, ends, sorted_index, sorted_data = self._numba_prep(func, data)
11581164
group_keys = self.grouper._get_group_keys()
11591165

1160-
numba_agg_func = numba_.generate_numba_agg_func(
1161-
tuple(args), kwargs, func, engine_kwargs
1162-
)
1166+
numba_agg_func = numba_.generate_numba_agg_func(kwargs, func, engine_kwargs)
11631167
result = numba_agg_func(
1164-
sorted_data, sorted_index, starts, ends, len(group_keys), len(data.columns)
1168+
sorted_data,
1169+
sorted_index,
1170+
starts,
1171+
ends,
1172+
len(group_keys),
1173+
len(data.columns),
1174+
*args,
11651175
)
11661176

11671177
cache_key = (func, "groupby_agg")

pandas/core/groupby/numba_.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,12 @@ def f(values, index, ...):
5656

5757

5858
def generate_numba_agg_func(
59-
args: tuple,
6059
kwargs: dict[str, Any],
6160
func: Callable[..., Scalar],
6261
engine_kwargs: dict[str, bool] | None,
63-
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int], np.ndarray]:
62+
) -> Callable[
63+
[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int, Any], np.ndarray
64+
]:
6465
"""
6566
Generate a numba jitted agg function specified by values from engine_kwargs.
6667
@@ -72,8 +73,6 @@ def generate_numba_agg_func(
7273
7374
Parameters
7475
----------
75-
args : tuple
76-
*args to be passed into the function
7776
kwargs : dict
7877
**kwargs to be passed into the function
7978
func : function
@@ -103,6 +102,7 @@ def group_agg(
103102
end: np.ndarray,
104103
num_groups: int,
105104
num_columns: int,
105+
*args: Any,
106106
) -> np.ndarray:
107107
result = np.empty((num_groups, num_columns))
108108
for i in numba.prange(num_groups):
@@ -116,11 +116,12 @@ def group_agg(
116116

117117

118118
def generate_numba_transform_func(
119-
args: tuple,
120119
kwargs: dict[str, Any],
121120
func: Callable[..., np.ndarray],
122121
engine_kwargs: dict[str, bool] | None,
123-
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int], np.ndarray]:
122+
) -> Callable[
123+
[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int, Any], np.ndarray
124+
]:
124125
"""
125126
Generate a numba jitted transform function specified by values from engine_kwargs.
126127
@@ -132,8 +133,6 @@ def generate_numba_transform_func(
132133
133134
Parameters
134135
----------
135-
args : tuple
136-
*args to be passed into the function
137136
kwargs : dict
138137
**kwargs to be passed into the function
139138
func : function
@@ -163,6 +162,7 @@ def group_transform(
163162
end: np.ndarray,
164163
num_groups: int,
165164
num_columns: int,
165+
*args: Any,
166166
) -> np.ndarray:
167167
result = np.empty((len(values), num_columns))
168168
for i in numba.prange(num_groups):

pandas/tests/groupby/aggregate/test_numba.py

+19
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66

77
from pandas import (
88
DataFrame,
9+
Index,
910
NamedAgg,
11+
Series,
1012
option_context,
1113
)
1214
import pandas._testing as tm
@@ -154,3 +156,20 @@ def test_multifunc_notimplimented(agg_func):
154156

155157
with pytest.raises(NotImplementedError, match="Numba engine can"):
156158
grouped[1].agg(agg_func, engine="numba")
159+
160+
161+
@td.skip_if_no("numba", "0.46.0")
162+
def test_args_not_cached():
163+
# GH 41647
164+
def sum_last(values, index, n):
165+
return values[-n:].sum()
166+
167+
df = DataFrame({"id": [0, 0, 1, 1], "x": [1, 1, 1, 1]})
168+
grouped_x = df.groupby("id")["x"]
169+
result = grouped_x.agg(sum_last, 1, engine="numba")
170+
expected = Series([1.0] * 2, name="x", index=Index([0, 1], name="id"))
171+
tm.assert_series_equal(result, expected)
172+
173+
result = grouped_x.agg(sum_last, 2, engine="numba")
174+
expected = Series([2.0] * 2, name="x", index=Index([0, 1], name="id"))
175+
tm.assert_series_equal(result, expected)

pandas/tests/groupby/transform/test_numba.py

+18
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from pandas import (
77
DataFrame,
8+
Series,
89
option_context,
910
)
1011
import pandas._testing as tm
@@ -146,3 +147,20 @@ def test_multifunc_notimplimented(agg_func):
146147

147148
with pytest.raises(NotImplementedError, match="Numba engine can"):
148149
grouped[1].transform(agg_func, engine="numba")
150+
151+
152+
@td.skip_if_no("numba", "0.46.0")
153+
def test_args_not_cached():
154+
# GH 41647
155+
def sum_last(values, index, n):
156+
return values[-n:].sum()
157+
158+
df = DataFrame({"id": [0, 0, 1, 1], "x": [1, 1, 1, 1]})
159+
grouped_x = df.groupby("id")["x"]
160+
result = grouped_x.transform(sum_last, 1, engine="numba")
161+
expected = Series([1.0] * 4, name="x")
162+
tm.assert_series_equal(result, expected)
163+
164+
result = grouped_x.transform(sum_last, 2, engine="numba")
165+
expected = Series([2.0] * 4, name="x")
166+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)