Skip to content

Commit bbe0e53

Browse files
authored
ENH add *args support for numba apply (#58767)
* add *args for raw numba apply * add whatsnew * fix test_case * fix pre-commit * fix test case * add *args for raw=False as well; merge tests together * add prepare_function_arguments * fix mypy * update get_jit_arguments * add nopython test in `test_apply_args` * fix test * fix pre-commit
1 parent 42f785f commit bbe0e53

File tree

6 files changed

+136
-27
lines changed

6 files changed

+136
-27
lines changed

doc/source/whatsnew/v3.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,7 @@ Other
598598
- Bug in :class:`DataFrame` when passing a ``dict`` with a NA scalar and ``columns`` that would always return ``np.nan`` (:issue:`57205`)
599599
- Bug in :func:`eval` where the names of the :class:`Series` were not preserved when using ``engine="numexpr"``. (:issue:`10239`)
600600
- Bug in :func:`unique` on :class:`Index` not always returning :class:`Index` (:issue:`57043`)
601+
- Bug in :meth:`DataFrame.apply` where passing ``engine="numba"`` ignored ``args`` passed to the applied function (:issue:`58712`)
601602
- Bug in :meth:`DataFrame.eval` and :meth:`DataFrame.query` which caused an exception when using NumPy attributes via ``@`` notation, e.g., ``df.eval("@np.floor(a)")``. (:issue:`58041`)
602603
- Bug in :meth:`DataFrame.eval` and :meth:`DataFrame.query` which did not allow to use ``tan`` function. (:issue:`55091`)
603604
- Bug in :meth:`DataFrame.sort_index` when passing ``axis="columns"`` and ``ignore_index=True`` and ``ascending=False`` not returning a :class:`RangeIndex` columns (:issue:`57293`)

pandas/core/_numba/executor.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,19 @@
1414

1515
from pandas.compat._optional import import_optional_dependency
1616

17+
from pandas.core.util.numba_ import jit_user_function
18+
1719

1820
@functools.cache
1921
def generate_apply_looper(func, nopython=True, nogil=True, parallel=False):
2022
if TYPE_CHECKING:
2123
import numba
2224
else:
2325
numba = import_optional_dependency("numba")
24-
nb_compat_func = numba.extending.register_jitable(func)
26+
nb_compat_func = jit_user_function(func)
2527

2628
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
27-
def nb_looper(values, axis):
29+
def nb_looper(values, axis, *args):
2830
# Operate on the first row/col in order to get
2931
# the output shape
3032
if axis == 0:
@@ -33,7 +35,7 @@ def nb_looper(values, axis):
3335
else:
3436
first_elem = values[0]
3537
dim0 = values.shape[0]
36-
res0 = nb_compat_func(first_elem)
38+
res0 = nb_compat_func(first_elem, *args)
3739
# Use np.asarray to get shape for
3840
# https://github.com/numba/numba/issues/4202#issuecomment-1185981507
3941
buf_shape = (dim0,) + np.atleast_1d(np.asarray(res0)).shape
@@ -44,11 +46,11 @@ def nb_looper(values, axis):
4446
if axis == 1:
4547
buff[0] = res0
4648
for i in numba.prange(1, values.shape[0]):
47-
buff[i] = nb_compat_func(values[i])
49+
buff[i] = nb_compat_func(values[i], *args)
4850
else:
4951
buff[:, 0] = res0
5052
for j in numba.prange(1, values.shape[1]):
51-
buff[:, j] = nb_compat_func(values[:, j])
53+
buff[:, j] = nb_compat_func(values[:, j], *args)
5254
return buff
5355

5456
return nb_looper

pandas/core/apply.py

+23-13
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@
5151
from pandas.core._numba.executor import generate_apply_looper
5252
import pandas.core.common as com
5353
from pandas.core.construction import ensure_wrapped_if_datetimelike
54+
from pandas.core.util.numba_ import (
55+
get_jit_arguments,
56+
prepare_function_arguments,
57+
)
5458

5559
if TYPE_CHECKING:
5660
from collections.abc import (
@@ -70,7 +74,6 @@
7074
from pandas.core.resample import Resampler
7175
from pandas.core.window.rolling import BaseWindow
7276

73-
7477
ResType = dict[int, Any]
7578

7679

@@ -997,17 +1000,20 @@ def wrapper(*args, **kwargs):
9971000
return wrapper
9981001

9991002
if engine == "numba":
1000-
engine_kwargs = {} if engine_kwargs is None else engine_kwargs
1001-
1003+
args, kwargs = prepare_function_arguments(
1004+
self.func, # type: ignore[arg-type]
1005+
self.args,
1006+
self.kwargs,
1007+
)
10021008
# error: Argument 1 to "__call__" of "_lru_cache_wrapper" has
10031009
# incompatible type "Callable[..., Any] | str | list[Callable
10041010
# [..., Any] | str] | dict[Hashable,Callable[..., Any] | str |
10051011
# list[Callable[..., Any] | str]]"; expected "Hashable"
10061012
nb_looper = generate_apply_looper(
10071013
self.func, # type: ignore[arg-type]
1008-
**engine_kwargs,
1014+
**get_jit_arguments(engine_kwargs, kwargs),
10091015
)
1010-
result = nb_looper(self.values, self.axis)
1016+
result = nb_looper(self.values, self.axis, *args)
10111017
# If we made the result 2-D, squeeze it back to 1-D
10121018
result = np.squeeze(result)
10131019
else:
@@ -1148,21 +1154,23 @@ def generate_numba_apply_func(
11481154
# Currently the parallel argument doesn't get passed through here
11491155
# (it's disabled) since the dicts in numba aren't thread-safe.
11501156
@numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
1151-
def numba_func(values, col_names, df_index):
1157+
def numba_func(values, col_names, df_index, *args):
11521158
results = {}
11531159
for j in range(values.shape[1]):
11541160
# Create the series
11551161
ser = Series(
11561162
values[:, j], index=df_index, name=maybe_cast_str(col_names[j])
11571163
)
1158-
results[j] = jitted_udf(ser)
1164+
results[j] = jitted_udf(ser, *args)
11591165
return results
11601166

11611167
return numba_func
11621168

11631169
def apply_with_numba(self) -> dict[int, Any]:
1170+
func = cast(Callable, self.func)
1171+
args, kwargs = prepare_function_arguments(func, self.args, self.kwargs)
11641172
nb_func = self.generate_numba_apply_func(
1165-
cast(Callable, self.func), **self.engine_kwargs
1173+
func, **get_jit_arguments(self.engine_kwargs, kwargs)
11661174
)
11671175
from pandas.core._numba.extensions import set_numba_data
11681176

@@ -1177,7 +1185,7 @@ def apply_with_numba(self) -> dict[int, Any]:
11771185
# Convert from numba dict to regular dict
11781186
# Our isinstance checks in the df constructor don't pass for numbas typed dict
11791187
with set_numba_data(index) as index, set_numba_data(columns) as columns:
1180-
res = dict(nb_func(self.values, columns, index))
1188+
res = dict(nb_func(self.values, columns, index, *args))
11811189
return res
11821190

11831191
@property
@@ -1285,7 +1293,7 @@ def generate_numba_apply_func(
12851293
jitted_udf = numba.extending.register_jitable(func)
12861294

12871295
@numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
1288-
def numba_func(values, col_names_index, index):
1296+
def numba_func(values, col_names_index, index, *args):
12891297
results = {}
12901298
# Currently the parallel argument doesn't get passed through here
12911299
# (it's disabled) since the dicts in numba aren't thread-safe.
@@ -1297,15 +1305,17 @@ def numba_func(values, col_names_index, index):
12971305
index=col_names_index,
12981306
name=maybe_cast_str(index[i]),
12991307
)
1300-
results[i] = jitted_udf(ser)
1308+
results[i] = jitted_udf(ser, *args)
13011309

13021310
return results
13031311

13041312
return numba_func
13051313

13061314
def apply_with_numba(self) -> dict[int, Any]:
1315+
func = cast(Callable, self.func)
1316+
args, kwargs = prepare_function_arguments(func, self.args, self.kwargs)
13071317
nb_func = self.generate_numba_apply_func(
1308-
cast(Callable, self.func), **self.engine_kwargs
1318+
func, **get_jit_arguments(self.engine_kwargs, kwargs)
13091319
)
13101320

13111321
from pandas.core._numba.extensions import set_numba_data
@@ -1316,7 +1326,7 @@ def apply_with_numba(self) -> dict[int, Any]:
13161326
set_numba_data(self.obj.index) as index,
13171327
set_numba_data(self.columns) as columns,
13181328
):
1319-
res = dict(nb_func(self.values, columns, index))
1329+
res = dict(nb_func(self.values, columns, index, *args))
13201330

13211331
return res
13221332

pandas/core/util/numba_.py

+53-3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import inspect
56
import types
67
from typing import (
78
TYPE_CHECKING,
@@ -54,10 +55,15 @@ def get_jit_arguments(
5455
engine_kwargs = {}
5556

5657
nopython = engine_kwargs.get("nopython", True)
57-
if kwargs and nopython:
58+
if kwargs:
59+
# Note: in case numba supports keyword-only arguments in
60+
# a future version, we should remove this check. But this
61+
# seems unlikely to happen soon.
62+
5863
raise NumbaUtilError(
59-
"numba does not support kwargs with nopython=True: "
60-
"https://github.com/numba/numba/issues/2916"
64+
"numba does not support keyword-only arguments"
65+
"https://github.com/numba/numba/issues/2916, "
66+
"https://github.com/numba/numba/issues/6846"
6167
)
6268
nogil = engine_kwargs.get("nogil", False)
6369
parallel = engine_kwargs.get("parallel", False)
@@ -97,3 +103,47 @@ def jit_user_function(func: Callable) -> Callable:
97103
numba_func = numba.extending.register_jitable(func)
98104

99105
return numba_func
106+
107+
108+
_sentinel = object()
109+
110+
111+
def prepare_function_arguments(
112+
func: Callable, args: tuple, kwargs: dict
113+
) -> tuple[tuple, dict]:
114+
"""
115+
Prepare arguments for jitted function. As numba functions do not support kwargs,
116+
we try to move kwargs into args if possible.
117+
118+
Parameters
119+
----------
120+
func : function
121+
user defined function
122+
args : tuple
123+
user input positional arguments
124+
kwargs : dict
125+
user input keyword arguments
126+
127+
Returns
128+
-------
129+
tuple[tuple, dict]
130+
args, kwargs
131+
132+
"""
133+
if not kwargs:
134+
return args, kwargs
135+
136+
# the udf should have this pattern: def udf(value, *args, **kwargs):...
137+
signature = inspect.signature(func)
138+
arguments = signature.bind(_sentinel, *args, **kwargs)
139+
arguments.apply_defaults()
140+
# Ref: https://peps.python.org/pep-0362/
141+
# Arguments which could be passed as part of either *args or **kwargs
142+
# will be included only in the BoundArguments.args attribute.
143+
args = arguments.args
144+
kwargs = arguments.kwargs
145+
146+
assert args[0] is _sentinel
147+
args = args[1:]
148+
149+
return args, kwargs

pandas/tests/apply/test_frame_apply.py

+49-5
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,60 @@ def test_apply(float_frame, engine, request):
6363

6464
@pytest.mark.parametrize("axis", [0, 1])
6565
@pytest.mark.parametrize("raw", [True, False])
66-
def test_apply_args(float_frame, axis, raw, engine, request):
67-
if engine == "numba":
68-
mark = pytest.mark.xfail(reason="numba engine doesn't support args")
69-
request.node.add_marker(mark)
66+
@pytest.mark.parametrize("nopython", [True, False])
67+
def test_apply_args(float_frame, axis, raw, engine, nopython):
68+
engine_kwargs = {"nopython": nopython}
7069
result = float_frame.apply(
71-
lambda x, y: x + y, axis, args=(1,), raw=raw, engine=engine
70+
lambda x, y: x + y,
71+
axis,
72+
args=(1,),
73+
raw=raw,
74+
engine=engine,
75+
engine_kwargs=engine_kwargs,
7276
)
7377
expected = float_frame + 1
7478
tm.assert_frame_equal(result, expected)
7579

80+
# GH:58712
81+
result = float_frame.apply(
82+
lambda x, a, b: x + a + b,
83+
args=(1,),
84+
b=2,
85+
raw=raw,
86+
engine=engine,
87+
engine_kwargs=engine_kwargs,
88+
)
89+
expected = float_frame + 3
90+
tm.assert_frame_equal(result, expected)
91+
92+
if engine == "numba":
93+
# keyword-only arguments are not supported in numba
94+
with pytest.raises(
95+
pd.errors.NumbaUtilError,
96+
match="numba does not support keyword-only arguments",
97+
):
98+
float_frame.apply(
99+
lambda x, a, *, b: x + a + b,
100+
args=(1,),
101+
b=2,
102+
raw=raw,
103+
engine=engine,
104+
engine_kwargs=engine_kwargs,
105+
)
106+
107+
with pytest.raises(
108+
pd.errors.NumbaUtilError,
109+
match="numba does not support keyword-only arguments",
110+
):
111+
float_frame.apply(
112+
lambda *x, b: x[0] + x[1] + b,
113+
args=(1,),
114+
b=2,
115+
raw=raw,
116+
engine=engine,
117+
engine_kwargs=engine_kwargs,
118+
)
119+
76120

77121
def test_apply_categorical_func():
78122
# GH 9573

pandas/tests/window/test_numba.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,9 @@ def f(x):
319319

320320
@td.skip_if_no("numba")
321321
def test_invalid_kwargs_nopython():
322-
with pytest.raises(NumbaUtilError, match="numba does not support kwargs with"):
322+
with pytest.raises(
323+
NumbaUtilError, match="numba does not support keyword-only arguments"
324+
):
323325
Series(range(1)).rolling(1).apply(
324326
lambda x: x, kwargs={"a": 1}, engine="numba", raw=True
325327
)

0 commit comments

Comments
 (0)