Skip to content

Commit 3a5fc90

Browse files
committed
add *args for raw numba apply
1 parent a3e751c commit 3a5fc90

File tree

3 files changed

+23
-8
lines changed

3 files changed

+23
-8
lines changed

pandas/core/_numba/executor.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def generate_apply_looper(func, nopython=True, nogil=True, parallel=False):
2424
nb_compat_func = numba.extending.register_jitable(func)
2525

2626
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
27-
def nb_looper(values, axis):
27+
def nb_looper(values, axis, *args):
2828
# Operate on the first row/col in order to get
2929
# the output shape
3030
if axis == 0:
@@ -33,7 +33,7 @@ def nb_looper(values, axis):
3333
else:
3434
first_elem = values[0]
3535
dim0 = values.shape[0]
36-
res0 = nb_compat_func(first_elem)
36+
res0 = nb_compat_func(first_elem, *args)
3737
# Use np.asarray to get shape for
3838
# https://github.com/numba/numba/issues/4202#issuecomment-1185981507
3939
buf_shape = (dim0,) + np.atleast_1d(np.asarray(res0)).shape
@@ -44,11 +44,11 @@ def nb_looper(values, axis):
4444
if axis == 1:
4545
buff[0] = res0
4646
for i in numba.prange(1, values.shape[0]):
47-
buff[i] = nb_compat_func(values[i])
47+
buff[i] = nb_compat_func(values[i], *args)
4848
else:
4949
buff[:, 0] = res0
5050
for j in numba.prange(1, values.shape[1]):
51-
buff[:, j] = nb_compat_func(values[:, j])
51+
buff[:, j] = nb_compat_func(values[:, j], *args)
5252
return buff
5353

5454
return nb_looper

pandas/core/apply.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from pandas.core._numba.executor import generate_apply_looper
5252
import pandas.core.common as com
5353
from pandas.core.construction import ensure_wrapped_if_datetimelike
54+
from pandas.core.util.numba_ import get_jit_arguments
5455

5556
if TYPE_CHECKING:
5657
from collections.abc import (
@@ -972,17 +973,15 @@ def wrapper(*args, **kwargs):
972973
return wrapper
973974

974975
if engine == "numba":
975-
engine_kwargs = {} if engine_kwargs is None else engine_kwargs
976-
977976
# error: Argument 1 to "__call__" of "_lru_cache_wrapper" has
978977
# incompatible type "Callable[..., Any] | str | list[Callable
979978
# [..., Any] | str] | dict[Hashable,Callable[..., Any] | str |
980979
# list[Callable[..., Any] | str]]"; expected "Hashable"
981980
nb_looper = generate_apply_looper(
982981
self.func, # type: ignore[arg-type]
983-
**engine_kwargs,
982+
**get_jit_arguments(engine_kwargs, self.kwargs),
984983
)
985-
result = nb_looper(self.values, self.axis)
984+
result = nb_looper(self.values, self.axis, *self.args)
986985
# If we made the result 2-D, squeeze it back to 1-D
987986
result = np.squeeze(result)
988987
else:

pandas/tests/apply/test_frame_apply.py

+16
Original file line numberDiff line numberDiff line change
@@ -1718,3 +1718,19 @@ def test_agg_dist_like_and_nonunique_columns():
17181718
result = df.agg({"A": "count"})
17191719
expected = df["A"].count()
17201720
tm.assert_series_equal(result, expected)
1721+
1722+
1723+
def test_numba_raw_apply_with_args():
1724+
# GH 58712
1725+
df = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
1726+
result = df.apply(lambda x, a, b: x + a + b, args=(1, 2), engine="numba", raw=True)
1727+
# note:
1728+
# result is always float dtype, see generate_apply_looper in pandas/core/_numba/executor.py
1729+
expected = df + 3.0
1730+
tm.assert_frame_equal(result, expected)
1731+
1732+
with pytest.raises(
1733+
pd.errors.NumbaUtilError,
1734+
match="numba does not support kwargs with nopython=True",
1735+
):
1736+
df.apply(lambda x, a, b: x + a + b, args=(1,), b=2, engine="numba", raw=True)

0 commit comments

Comments
 (0)