Skip to content

Commit aa91722

Browse files
committed
modify prepare_function_arguments
1 parent e191be9 commit aa91722

File tree

4 files changed

+32
-30
lines changed

4 files changed

+32
-30
lines changed

pandas/core/apply.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -1004,14 +1004,15 @@ def wrapper(*args, **kwargs):
10041004
self.func, # type: ignore[arg-type]
10051005
self.args,
10061006
self.kwargs,
1007+
1,
10071008
)
10081009
# error: Argument 1 to "__call__" of "_lru_cache_wrapper" has
10091010
# incompatible type "Callable[..., Any] | str | list[Callable
10101011
# [..., Any] | str] | dict[Hashable,Callable[..., Any] | str |
10111012
# list[Callable[..., Any] | str]]"; expected "Hashable"
10121013
nb_looper = generate_apply_looper(
10131014
self.func, # type: ignore[arg-type]
1014-
**get_jit_arguments(engine_kwargs, kwargs),
1015+
**get_jit_arguments(engine_kwargs),
10151016
)
10161017
result = nb_looper(self.values, self.axis, *args)
10171018
# If we made the result 2-D, squeeze it back to 1-D
@@ -1168,9 +1169,9 @@ def numba_func(values, col_names, df_index, *args):
11681169

11691170
def apply_with_numba(self) -> dict[int, Any]:
11701171
func = cast(Callable, self.func)
1171-
args, kwargs = prepare_function_arguments(func, self.args, self.kwargs)
1172+
args, kwargs = prepare_function_arguments(func, self.args, self.kwargs, 1)
11721173
nb_func = self.generate_numba_apply_func(
1173-
func, **get_jit_arguments(self.engine_kwargs, kwargs)
1174+
func, **get_jit_arguments(self.engine_kwargs)
11741175
)
11751176
from pandas.core._numba.extensions import set_numba_data
11761177

@@ -1313,9 +1314,9 @@ def numba_func(values, col_names_index, index, *args):
13131314

13141315
def apply_with_numba(self) -> dict[int, Any]:
13151316
func = cast(Callable, self.func)
1316-
args, kwargs = prepare_function_arguments(func, self.args, self.kwargs)
1317+
args, kwargs = prepare_function_arguments(func, self.args, self.kwargs, 1)
13171318
nb_func = self.generate_numba_apply_func(
1318-
func, **get_jit_arguments(self.engine_kwargs, kwargs)
1319+
func, **get_jit_arguments(self.engine_kwargs)
13191320
)
13201321

13211322
from pandas.core._numba.extensions import set_numba_data

pandas/core/groupby/groupby.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ class providing the base-class of operations.
137137
from pandas.core.util.numba_ import (
138138
get_jit_arguments,
139139
maybe_use_numba,
140+
prepare_function_arguments,
140141
)
141142

142143
if TYPE_CHECKING:
@@ -1443,8 +1444,9 @@ def _transform_with_numba(self, func, *args, engine_kwargs=None, **kwargs):
14431444

14441445
starts, ends, sorted_index, sorted_data = self._numba_prep(df)
14451446
numba_.validate_udf(func)
1447+
args, kwargs = prepare_function_arguments(func, args, kwargs, 2)
14461448
numba_transform_func = numba_.generate_numba_transform_func(
1447-
func, **get_jit_arguments(engine_kwargs, kwargs)
1449+
func, **get_jit_arguments(engine_kwargs)
14481450
)
14491451
result = numba_transform_func(
14501452
sorted_data,
@@ -1479,8 +1481,9 @@ def _aggregate_with_numba(self, func, *args, engine_kwargs=None, **kwargs):
14791481

14801482
starts, ends, sorted_index, sorted_data = self._numba_prep(df)
14811483
numba_.validate_udf(func)
1484+
args, kwargs = prepare_function_arguments(func, args, kwargs, 2)
14821485
numba_agg_func = numba_.generate_numba_agg_func(
1483-
func, **get_jit_arguments(engine_kwargs, kwargs)
1486+
func, **get_jit_arguments(engine_kwargs)
14841487
)
14851488
result = numba_agg_func(
14861489
sorted_data,

pandas/core/util/numba_.py

+17-20
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,14 @@ def set_use_numba(enable: bool = False) -> None:
2929
GLOBAL_USE_NUMBA = enable
3030

3131

32-
def get_jit_arguments(
33-
engine_kwargs: dict[str, bool] | None = None, kwargs: dict | None = None
34-
) -> dict[str, bool]:
32+
def get_jit_arguments(engine_kwargs: dict[str, bool] | None = None) -> dict[str, bool]:
3533
"""
3634
Return arguments to pass to numba.JIT, falling back on pandas default JIT settings.
3735
3836
Parameters
3937
----------
4038
engine_kwargs : dict, default None
4139
user passed keyword arguments for numba.JIT
42-
kwargs : dict, default None
43-
user passed keyword arguments to pass into the JITed function
4440
4541
Returns
4642
-------
@@ -55,16 +51,6 @@ def get_jit_arguments(
5551
engine_kwargs = {}
5652

5753
nopython = engine_kwargs.get("nopython", True)
58-
if kwargs:
59-
# Note: in case numba supports keyword-only arguments in
60-
# a future version, we should remove this check. But this
61-
# seems unlikely to happen soon.
62-
63-
raise NumbaUtilError(
64-
"numba does not support keyword-only arguments"
65-
"https://github.com/numba/numba/issues/2916, "
66-
"https://github.com/numba/numba/issues/6846"
67-
)
6854
nogil = engine_kwargs.get("nogil", False)
6955
parallel = engine_kwargs.get("parallel", False)
7056
return {"nopython": nopython, "nogil": nogil, "parallel": parallel}
@@ -109,7 +95,7 @@ def jit_user_function(func: Callable) -> Callable:
10995

11096

11197
def prepare_function_arguments(
112-
func: Callable, args: tuple, kwargs: dict
98+
func: Callable, args: tuple, kwargs: dict, num_required_args: int
11399
) -> tuple[tuple, dict]:
114100
"""
115101
Prepare arguments for jitted function. As numba functions do not support kwargs,
@@ -123,6 +109,8 @@ def prepare_function_arguments(
123109
user input positional arguments
124110
kwargs : dict
125111
user input keyword arguments
112+
num_required_args : int
113+
the number of required leading positional arguments for udf.
126114
127115
Returns
128116
-------
@@ -133,17 +121,26 @@ def prepare_function_arguments(
133121
if not kwargs:
134122
return args, kwargs
135123

136-
# the udf should have this pattern: def udf(value, *args, **kwargs):...
124+
# the udf should have this pattern: def udf(arg1, arg2, ..., *args, **kwargs):...
137125
signature = inspect.signature(func)
138-
arguments = signature.bind(_sentinel, *args, **kwargs)
126+
arguments = signature.bind(*[_sentinel] * num_required_args, *args, **kwargs)
139127
arguments.apply_defaults()
140128
# Ref: https://peps.python.org/pep-0362/
141129
# Arguments which could be passed as part of either *args or **kwargs
142130
# will be included only in the BoundArguments.args attribute.
143131
args = arguments.args
144132
kwargs = arguments.kwargs
145133

146-
assert args[0] is _sentinel
147-
args = args[1:]
134+
if kwargs:
135+
# Note: in case numba supports keyword-only arguments in
136+
# a future version, we should remove this check. But this
137+
# seems unlikely to happen soon.
138+
139+
raise NumbaUtilError(
140+
"numba does not support keyword-only arguments"
141+
"https://github.com/numba/numba/issues/2916, "
142+
"https://github.com/numba/numba/issues/6846"
143+
)
148144

145+
args = args[num_required_args:]
149146
return args, kwargs

pandas/core/window/rolling.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
from pandas.core.util.numba_ import (
6767
get_jit_arguments,
6868
maybe_use_numba,
69+
prepare_function_arguments,
6970
)
7071
from pandas.core.window.common import (
7172
flex_binary_moment,
@@ -1458,14 +1459,14 @@ def apply(
14581459
if maybe_use_numba(engine):
14591460
if raw is False:
14601461
raise ValueError("raw must be `True` when using the numba engine")
1461-
numba_args = args
1462+
numba_args, kwargs = prepare_function_arguments(func, args, kwargs, 1)
14621463
if self.method == "single":
14631464
apply_func = generate_numba_apply_func(
1464-
func, **get_jit_arguments(engine_kwargs, kwargs)
1465+
func, **get_jit_arguments(engine_kwargs)
14651466
)
14661467
else:
14671468
apply_func = generate_numba_table_func(
1468-
func, **get_jit_arguments(engine_kwargs, kwargs)
1469+
func, **get_jit_arguments(engine_kwargs)
14691470
)
14701471
elif engine in ("cython", None):
14711472
if engine_kwargs is not None:

0 commit comments

Comments
 (0)