Skip to content

Commit 8be2f8b

Browse files
auderson and mroeschke authored
ENH: numba apply supports positional arguments passed as **kwargs (#58995)
* add *args for raw numba apply
* add whatsnew
* fix test_case
* fix pre-commit
* fix test case
* add *args for raw=False as well; merge tests together
* add prepare_function_arguments
* fix mypy
* update get_jit_arguments
* add nopython test in `test_apply_args`
* fix test
* fix pre-commit
* modify prepare_function_arguments
* add tests
* add tests
* add whatsnew
* compat for python 3.12
* pre-commit
* compat for python 3.12
* update doc; use kw-only
* add more tests
* update whatsnew
* pre-commit
* move the tests to test_numba.py
* Update doc/source/whatsnew/v3.0.0.rst

  Co-authored-by: Matthew Roeschke <[email protected]>
* Update doc/source/whatsnew/v3.0.0.rst

  Co-authored-by: Matthew Roeschke <[email protected]>

---------

Co-authored-by: Matthew Roeschke <[email protected]>
1 parent 13926e5 commit 8be2f8b

File tree

9 files changed

+187
-38
lines changed

9 files changed

+187
-38
lines changed

doc/source/whatsnew/v3.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ Other enhancements
5454
- :meth:`Series.cummin` and :meth:`Series.cummax` now support :class:`CategoricalDtype` (:issue:`52335`)
5555
- :meth:`Series.plot` now correctly handles the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`)
5656
- :meth:`DataFrame.plot.scatter` argument ``c`` now accepts a column of strings, where rows with the same string are colored identically (:issue:`16827` and :issue:`16485`)
57+
- :meth:`DataFrameGroupBy.transform`, :meth:`SeriesGroupBy.transform`, :meth:`DataFrameGroupBy.agg`, :meth:`SeriesGroupBy.agg`, :meth:`RollingGroupby.apply`, :meth:`ExpandingGroupby.apply`, :meth:`Rolling.apply`, :meth:`Expanding.apply`, :meth:`DataFrame.apply` with ``engine="numba"`` now support positional arguments passed as kwargs (:issue:`58995`)
5758
- :meth:`Series.map` can now accept kwargs to pass on to func (:issue:`59814`)
5859
- :meth:`pandas.concat` will raise a ``ValueError`` when ``ignore_index=True`` and ``keys`` is not ``None`` (:issue:`59274`)
5960
- :meth:`str.get_dummies` now accepts a ``dtype`` parameter to specify the dtype of the resulting DataFrame (:issue:`47872`)

pandas/core/apply.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -994,14 +994,15 @@ def wrapper(*args, **kwargs):
994994
self.func, # type: ignore[arg-type]
995995
self.args,
996996
self.kwargs,
997+
num_required_args=1,
997998
)
998999
# error: Argument 1 to "__call__" of "_lru_cache_wrapper" has
9991000
# incompatible type "Callable[..., Any] | str | list[Callable
10001001
# [..., Any] | str] | dict[Hashable,Callable[..., Any] | str |
10011002
# list[Callable[..., Any] | str]]"; expected "Hashable"
10021003
nb_looper = generate_apply_looper(
10031004
self.func, # type: ignore[arg-type]
1004-
**get_jit_arguments(engine_kwargs, kwargs),
1005+
**get_jit_arguments(engine_kwargs),
10051006
)
10061007
result = nb_looper(self.values, self.axis, *args)
10071008
# If we made the result 2-D, squeeze it back to 1-D
@@ -1158,9 +1159,11 @@ def numba_func(values, col_names, df_index, *args):
11581159

11591160
def apply_with_numba(self) -> dict[int, Any]:
11601161
func = cast(Callable, self.func)
1161-
args, kwargs = prepare_function_arguments(func, self.args, self.kwargs)
1162+
args, kwargs = prepare_function_arguments(
1163+
func, self.args, self.kwargs, num_required_args=1
1164+
)
11621165
nb_func = self.generate_numba_apply_func(
1163-
func, **get_jit_arguments(self.engine_kwargs, kwargs)
1166+
func, **get_jit_arguments(self.engine_kwargs)
11641167
)
11651168
from pandas.core._numba.extensions import set_numba_data
11661169

@@ -1298,9 +1301,11 @@ def numba_func(values, col_names_index, index, *args):
12981301

12991302
def apply_with_numba(self) -> dict[int, Any]:
13001303
func = cast(Callable, self.func)
1301-
args, kwargs = prepare_function_arguments(func, self.args, self.kwargs)
1304+
args, kwargs = prepare_function_arguments(
1305+
func, self.args, self.kwargs, num_required_args=1
1306+
)
13021307
nb_func = self.generate_numba_apply_func(
1303-
func, **get_jit_arguments(self.engine_kwargs, kwargs)
1308+
func, **get_jit_arguments(self.engine_kwargs)
13041309
)
13051310

13061311
from pandas.core._numba.extensions import set_numba_data

pandas/core/groupby/groupby.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ class providing the base-class of operations.
136136
from pandas.core.util.numba_ import (
137137
get_jit_arguments,
138138
maybe_use_numba,
139+
prepare_function_arguments,
139140
)
140141

141142
if TYPE_CHECKING:
@@ -1289,8 +1290,11 @@ def _transform_with_numba(self, func, *args, engine_kwargs=None, **kwargs):
12891290

12901291
starts, ends, sorted_index, sorted_data = self._numba_prep(df)
12911292
numba_.validate_udf(func)
1293+
args, kwargs = prepare_function_arguments(
1294+
func, args, kwargs, num_required_args=2
1295+
)
12921296
numba_transform_func = numba_.generate_numba_transform_func(
1293-
func, **get_jit_arguments(engine_kwargs, kwargs)
1297+
func, **get_jit_arguments(engine_kwargs)
12941298
)
12951299
result = numba_transform_func(
12961300
sorted_data,
@@ -1325,8 +1329,11 @@ def _aggregate_with_numba(self, func, *args, engine_kwargs=None, **kwargs):
13251329

13261330
starts, ends, sorted_index, sorted_data = self._numba_prep(df)
13271331
numba_.validate_udf(func)
1332+
args, kwargs = prepare_function_arguments(
1333+
func, args, kwargs, num_required_args=2
1334+
)
13281335
numba_agg_func = numba_.generate_numba_agg_func(
1329-
func, **get_jit_arguments(engine_kwargs, kwargs)
1336+
func, **get_jit_arguments(engine_kwargs)
13301337
)
13311338
result = numba_agg_func(
13321339
sorted_data,

pandas/core/util/numba_.py

+24-23
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,14 @@ def set_use_numba(enable: bool = False) -> None:
2929
GLOBAL_USE_NUMBA = enable
3030

3131

32-
def get_jit_arguments(
33-
engine_kwargs: dict[str, bool] | None = None, kwargs: dict | None = None
34-
) -> dict[str, bool]:
32+
def get_jit_arguments(engine_kwargs: dict[str, bool] | None = None) -> dict[str, bool]:
3533
"""
3634
Return arguments to pass to numba.JIT, falling back on pandas default JIT settings.
3735
3836
Parameters
3937
----------
4038
engine_kwargs : dict, default None
4139
user passed keyword arguments for numba.JIT
42-
kwargs : dict, default None
43-
user passed keyword arguments to pass into the JITed function
4440
4541
Returns
4642
-------
@@ -55,16 +51,6 @@ def get_jit_arguments(
5551
engine_kwargs = {}
5652

5753
nopython = engine_kwargs.get("nopython", True)
58-
if kwargs:
59-
# Note: in case numba supports keyword-only arguments in
60-
# a future version, we should remove this check. But this
61-
# seems unlikely to happen soon.
62-
63-
raise NumbaUtilError(
64-
"numba does not support keyword-only arguments"
65-
"https://github.com/numba/numba/issues/2916, "
66-
"https://github.com/numba/numba/issues/6846"
67-
)
6854
nogil = engine_kwargs.get("nogil", False)
6955
parallel = engine_kwargs.get("parallel", False)
7056
return {"nopython": nopython, "nogil": nogil, "parallel": parallel}
@@ -109,7 +95,7 @@ def jit_user_function(func: Callable) -> Callable:
10995

11096

11197
def prepare_function_arguments(
112-
func: Callable, args: tuple, kwargs: dict
98+
func: Callable, args: tuple, kwargs: dict, *, num_required_args: int
11399
) -> tuple[tuple, dict]:
114100
"""
115101
Prepare arguments for jitted function. As numba functions do not support kwargs,
@@ -118,11 +104,17 @@ def prepare_function_arguments(
118104
Parameters
119105
----------
120106
func : function
121-
user defined function
107+
User defined function
122108
args : tuple
123-
user input positional arguments
109+
User input positional arguments
124110
kwargs : dict
125-
user input keyword arguments
111+
User input keyword arguments
112+
num_required_args : int
113+
The number of leading positional arguments we will pass to udf.
114+
These are not supplied by the user.
115+
e.g. for groupby we require "values", "index" as the first two arguments:
116+
`numba_func(group, group_index, *args)`, in this case num_required_args=2.
117+
See :func:`pandas.core.groupby.numba_.generate_numba_agg_func`
126118
127119
Returns
128120
-------
@@ -133,17 +125,26 @@ def prepare_function_arguments(
133125
if not kwargs:
134126
return args, kwargs
135127

136-
# the udf should have this pattern: def udf(value, *args, **kwargs):...
128+
# the udf should have this pattern: def udf(arg1, arg2, ..., *args, **kwargs):...
137129
signature = inspect.signature(func)
138-
arguments = signature.bind(_sentinel, *args, **kwargs)
130+
arguments = signature.bind(*[_sentinel] * num_required_args, *args, **kwargs)
139131
arguments.apply_defaults()
140132
# Ref: https://peps.python.org/pep-0362/
141133
# Arguments which could be passed as part of either *args or **kwargs
142134
# will be included only in the BoundArguments.args attribute.
143135
args = arguments.args
144136
kwargs = arguments.kwargs
145137

146-
assert args[0] is _sentinel
147-
args = args[1:]
138+
if kwargs:
139+
# Note: in case numba supports keyword-only arguments in
140+
# a future version, we should remove this check. But this
141+
# seems unlikely to happen soon.
142+
143+
raise NumbaUtilError(
144+
"numba does not support keyword-only arguments"
145+
"https://github.com/numba/numba/issues/2916, "
146+
"https://github.com/numba/numba/issues/6846"
147+
)
148148

149+
args = args[num_required_args:]
149150
return args, kwargs

pandas/core/window/rolling.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from pandas.core.util.numba_ import (
6666
get_jit_arguments,
6767
maybe_use_numba,
68+
prepare_function_arguments,
6869
)
6970
from pandas.core.window.common import (
7071
flex_binary_moment,
@@ -1472,14 +1473,16 @@ def apply(
14721473
if maybe_use_numba(engine):
14731474
if raw is False:
14741475
raise ValueError("raw must be `True` when using the numba engine")
1475-
numba_args = args
1476+
numba_args, kwargs = prepare_function_arguments(
1477+
func, args, kwargs, num_required_args=1
1478+
)
14761479
if self.method == "single":
14771480
apply_func = generate_numba_apply_func(
1478-
func, **get_jit_arguments(engine_kwargs, kwargs)
1481+
func, **get_jit_arguments(engine_kwargs)
14791482
)
14801483
else:
14811484
apply_func = generate_numba_table_func(
1482-
func, **get_jit_arguments(engine_kwargs, kwargs)
1485+
func, **get_jit_arguments(engine_kwargs)
14831486
)
14841487
elif engine in ("cython", None):
14851488
if engine_kwargs is not None:

pandas/tests/apply/test_frame_apply.py

+10
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,16 @@ def test_apply_args(float_frame, axis, raw, engine, nopython):
9090
tm.assert_frame_equal(result, expected)
9191

9292
if engine == "numba":
93+
# py signature binding
94+
with pytest.raises(TypeError, match="missing a required argument: 'a'"):
95+
float_frame.apply(
96+
lambda x, a: x + a,
97+
b=2,
98+
raw=raw,
99+
engine=engine,
100+
engine_kwargs=engine_kwargs,
101+
)
102+
93103
# keyword-only arguments are not supported in numba
94104
with pytest.raises(
95105
pd.errors.NumbaUtilError,

pandas/tests/groupby/aggregate/test_numba.py

+27-2
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,43 @@ def incorrect_function(x):
3535
def test_check_nopython_kwargs():
3636
pytest.importorskip("numba")
3737

38-
def incorrect_function(values, index):
39-
return sum(values) * 2.7
38+
def incorrect_function(values, index, *, a):
39+
return sum(values) * 2.7 + a
40+
41+
def correct_function(values, index, a):
42+
return sum(values) * 2.7 + a
4043

4144
data = DataFrame(
4245
{"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]},
4346
columns=["key", "data"],
4447
)
48+
expected = data.groupby("key").sum() * 2.7
49+
50+
# py signature binding
51+
with pytest.raises(
52+
TypeError, match="missing a required (keyword-only argument|argument): 'a'"
53+
):
54+
data.groupby("key").agg(incorrect_function, engine="numba", b=1)
55+
with pytest.raises(TypeError, match="missing a required argument: 'a'"):
56+
data.groupby("key").agg(correct_function, engine="numba", b=1)
57+
58+
with pytest.raises(
59+
TypeError, match="missing a required (keyword-only argument|argument): 'a'"
60+
):
61+
data.groupby("key")["data"].agg(incorrect_function, engine="numba", b=1)
62+
with pytest.raises(TypeError, match="missing a required argument: 'a'"):
63+
data.groupby("key")["data"].agg(correct_function, engine="numba", b=1)
64+
65+
# numba signature check after binding
4566
with pytest.raises(NumbaUtilError, match="numba does not support"):
4667
data.groupby("key").agg(incorrect_function, engine="numba", a=1)
68+
actual = data.groupby("key").agg(correct_function, engine="numba", a=1)
69+
tm.assert_frame_equal(expected + 1, actual)
4770

4871
with pytest.raises(NumbaUtilError, match="numba does not support"):
4972
data.groupby("key")["data"].agg(incorrect_function, engine="numba", a=1)
73+
actual = data.groupby("key")["data"].agg(correct_function, engine="numba", a=1)
74+
tm.assert_series_equal(expected["data"] + 1, actual)
5075

5176

5277
@pytest.mark.filterwarnings("ignore")

pandas/tests/groupby/transform/test_numba.py

+27-2
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,43 @@ def incorrect_function(x):
3333
def test_check_nopython_kwargs():
3434
pytest.importorskip("numba")
3535

36-
def incorrect_function(values, index):
37-
return values + 1
36+
def incorrect_function(values, index, *, a):
37+
return values + a
38+
39+
def correct_function(values, index, a):
40+
return values + a
3841

3942
data = DataFrame(
4043
{"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]},
4144
columns=["key", "data"],
4245
)
46+
# py signature binding
47+
with pytest.raises(
48+
TypeError, match="missing a required (keyword-only argument|argument): 'a'"
49+
):
50+
data.groupby("key").transform(incorrect_function, engine="numba", b=1)
51+
with pytest.raises(TypeError, match="missing a required argument: 'a'"):
52+
data.groupby("key").transform(correct_function, engine="numba", b=1)
53+
54+
with pytest.raises(
55+
TypeError, match="missing a required (keyword-only argument|argument): 'a'"
56+
):
57+
data.groupby("key")["data"].transform(incorrect_function, engine="numba", b=1)
58+
with pytest.raises(TypeError, match="missing a required argument: 'a'"):
59+
data.groupby("key")["data"].transform(correct_function, engine="numba", b=1)
60+
61+
# numba signature check after binding
4362
with pytest.raises(NumbaUtilError, match="numba does not support"):
4463
data.groupby("key").transform(incorrect_function, engine="numba", a=1)
64+
actual = data.groupby("key").transform(correct_function, engine="numba", a=1)
65+
tm.assert_frame_equal(data[["data"]] + 1, actual)
4566

4667
with pytest.raises(NumbaUtilError, match="numba does not support"):
4768
data.groupby("key")["data"].transform(incorrect_function, engine="numba", a=1)
69+
actual = data.groupby("key")["data"].transform(
70+
correct_function, engine="numba", a=1
71+
)
72+
tm.assert_series_equal(data["data"] + 1, actual)
4873

4974

5075
@pytest.mark.filterwarnings("ignore")

0 commit comments

Comments
 (0)