Skip to content

ENH add *args support for numba apply #58767

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jun 11, 2024
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,7 @@ Other
- Bug in :class:`DataFrame` when passing a ``dict`` with a NA scalar and ``columns`` that would always return ``np.nan`` (:issue:`57205`)
- Bug in :func:`eval` where the names of the :class:`Series` were not preserved when using ``engine="numexpr"``. (:issue:`10239`)
- Bug in :func:`unique` on :class:`Index` not always returning :class:`Index` (:issue:`57043`)
- Bug in :meth:`DataFrame.apply` where passing ``engine="numba"`` ignored ``args`` passed to the applied function (:issue:`58712`)
- Bug in :meth:`DataFrame.eval` and :meth:`DataFrame.query` which caused an exception when using NumPy attributes via ``@`` notation, e.g., ``df.eval("@np.floor(a)")``. (:issue:`58041`)
- Bug in :meth:`DataFrame.eval` and :meth:`DataFrame.query` which did not allow the use of the ``tan`` function. (:issue:`55091`)
- Bug in :meth:`DataFrame.sort_index` when passing ``axis="columns"`` and ``ignore_index=True`` and ``ascending=False`` not returning :class:`RangeIndex` columns (:issue:`57293`)
Expand Down
12 changes: 7 additions & 5 deletions pandas/core/_numba/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,19 @@

from pandas.compat._optional import import_optional_dependency

from pandas.core.util.numba_ import jit_user_function


@functools.cache
def generate_apply_looper(func, nopython=True, nogil=True, parallel=False):
if TYPE_CHECKING:
import numba
else:
numba = import_optional_dependency("numba")
nb_compat_func = numba.extending.register_jitable(func)
nb_compat_func = jit_user_function(func)

@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
def nb_looper(values, axis):
def nb_looper(values, axis, *args):
# Operate on the first row/col in order to get
# the output shape
if axis == 0:
Expand All @@ -33,7 +35,7 @@ def nb_looper(values, axis):
else:
first_elem = values[0]
dim0 = values.shape[0]
res0 = nb_compat_func(first_elem)
res0 = nb_compat_func(first_elem, *args)
# Use np.asarray to get shape for
# https://github.com/numba/numba/issues/4202#issuecomment-1185981507
buf_shape = (dim0,) + np.atleast_1d(np.asarray(res0)).shape
Expand All @@ -44,11 +46,11 @@ def nb_looper(values, axis):
if axis == 1:
buff[0] = res0
for i in numba.prange(1, values.shape[0]):
buff[i] = nb_compat_func(values[i])
buff[i] = nb_compat_func(values[i], *args)
else:
buff[:, 0] = res0
for j in numba.prange(1, values.shape[1]):
buff[:, j] = nb_compat_func(values[:, j])
buff[:, j] = nb_compat_func(values[:, j], *args)
return buff

return nb_looper
Expand Down
36 changes: 23 additions & 13 deletions pandas/core/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@
from pandas.core._numba.executor import generate_apply_looper
import pandas.core.common as com
from pandas.core.construction import ensure_wrapped_if_datetimelike
from pandas.core.util.numba_ import (
get_jit_arguments,
prepare_function_arguments,
)

if TYPE_CHECKING:
from collections.abc import (
Expand All @@ -70,7 +74,6 @@
from pandas.core.resample import Resampler
from pandas.core.window.rolling import BaseWindow


ResType = dict[int, Any]


Expand Down Expand Up @@ -972,17 +975,20 @@ def wrapper(*args, **kwargs):
return wrapper

if engine == "numba":
engine_kwargs = {} if engine_kwargs is None else engine_kwargs

args, kwargs = prepare_function_arguments(
self.func, # type: ignore[arg-type]
self.args,
self.kwargs,
)
# error: Argument 1 to "__call__" of "_lru_cache_wrapper" has
# incompatible type "Callable[..., Any] | str | list[Callable
# [..., Any] | str] | dict[Hashable,Callable[..., Any] | str |
# list[Callable[..., Any] | str]]"; expected "Hashable"
nb_looper = generate_apply_looper(
self.func, # type: ignore[arg-type]
**engine_kwargs,
**get_jit_arguments(engine_kwargs, kwargs),
)
result = nb_looper(self.values, self.axis)
result = nb_looper(self.values, self.axis, *args)
# If we made the result 2-D, squeeze it back to 1-D
result = np.squeeze(result)
else:
Expand Down Expand Up @@ -1123,21 +1129,23 @@ def generate_numba_apply_func(
# Currently the parallel argument doesn't get passed through here
# (it's disabled) since the dicts in numba aren't thread-safe.
@numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
def numba_func(values, col_names, df_index):
def numba_func(values, col_names, df_index, *args):
results = {}
for j in range(values.shape[1]):
# Create the series
ser = Series(
values[:, j], index=df_index, name=maybe_cast_str(col_names[j])
)
results[j] = jitted_udf(ser)
results[j] = jitted_udf(ser, *args)
return results

return numba_func

def apply_with_numba(self) -> dict[int, Any]:
func = cast(Callable, self.func)
args, kwargs = prepare_function_arguments(func, self.args, self.kwargs)
nb_func = self.generate_numba_apply_func(
cast(Callable, self.func), **self.engine_kwargs
func, **get_jit_arguments(self.engine_kwargs, kwargs)
)
from pandas.core._numba.extensions import set_numba_data

Expand All @@ -1152,7 +1160,7 @@ def apply_with_numba(self) -> dict[int, Any]:
# Convert from numba dict to regular dict
# Our isinstance checks in the df constructor don't pass for numbas typed dict
with set_numba_data(index) as index, set_numba_data(columns) as columns:
res = dict(nb_func(self.values, columns, index))
res = dict(nb_func(self.values, columns, index, *args))
return res

@property
Expand Down Expand Up @@ -1260,7 +1268,7 @@ def generate_numba_apply_func(
jitted_udf = numba.extending.register_jitable(func)

@numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
def numba_func(values, col_names_index, index):
def numba_func(values, col_names_index, index, *args):
results = {}
# Currently the parallel argument doesn't get passed through here
# (it's disabled) since the dicts in numba aren't thread-safe.
Expand All @@ -1272,15 +1280,17 @@ def numba_func(values, col_names_index, index):
index=col_names_index,
name=maybe_cast_str(index[i]),
)
results[i] = jitted_udf(ser)
results[i] = jitted_udf(ser, *args)

return results

return numba_func

def apply_with_numba(self) -> dict[int, Any]:
func = cast(Callable, self.func)
args, kwargs = prepare_function_arguments(func, self.args, self.kwargs)
nb_func = self.generate_numba_apply_func(
cast(Callable, self.func), **self.engine_kwargs
func, **get_jit_arguments(self.engine_kwargs, kwargs)
)

from pandas.core._numba.extensions import set_numba_data
Expand All @@ -1291,7 +1301,7 @@ def apply_with_numba(self) -> dict[int, Any]:
set_numba_data(self.obj.index) as index,
set_numba_data(self.columns) as columns,
):
res = dict(nb_func(self.values, columns, index))
res = dict(nb_func(self.values, columns, index, *args))

return res

Expand Down
56 changes: 53 additions & 3 deletions pandas/core/util/numba_.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import inspect
import types
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -54,10 +55,15 @@ def get_jit_arguments(
engine_kwargs = {}

nopython = engine_kwargs.get("nopython", True)
if kwargs and nopython:
if kwargs:
# Note: in case numba supports keyword-only arguments in
# a future version, we should remove this check. But this
# seems unlikely to happen soon.

raise NumbaUtilError(
"numba does not support kwargs with nopython=True: "
"https://github.com/numba/numba/issues/2916"
"numba does not support keyword-only arguments"
"https://github.com/numba/numba/issues/2916, "
"https://github.com/numba/numba/issues/6846"
)
nogil = engine_kwargs.get("nogil", False)
parallel = engine_kwargs.get("parallel", False)
Expand Down Expand Up @@ -97,3 +103,47 @@ def jit_user_function(func: Callable) -> Callable:
numba_func = numba.extending.register_jitable(func)

return numba_func


# Placeholder bound to the leading positional parameter(s) of the user function,
# i.e. the slot(s) that the caller fills with data at call time. Binding a
# placeholder there lets the user's extra arguments be matched against the
# remainder of the signature.
_sentinel = object()


def prepare_function_arguments(
    func: Callable,
    args: tuple,
    kwargs: dict,
    *,
    num_required_args: int = 1,
) -> tuple[tuple, dict]:
    """
    Prepare arguments for jitted function. As numba functions do not support kwargs,
    we try to move kwargs into args if possible.

    Parameters
    ----------
    func : function
        user defined function
    args : tuple
        user input positional arguments
    kwargs : dict
        user input keyword arguments
    num_required_args : int, default 1
        Number of leading positional parameters of ``func`` that are supplied
        by the caller at call time (for example the values being operated on)
        rather than by the user. They are reserved while binding and are not
        part of the returned ``args``.

    Returns
    -------
    tuple[tuple, dict]
        args, kwargs. ``args`` holds the user arguments (with declared
        defaults filled in) that can be passed positionally; ``kwargs`` holds
        whatever could only be passed by keyword (keyword-only parameters or
        entries collected by ``**kwargs``).

    Raises
    ------
    ValueError
        If ``num_required_args`` is negative.
    TypeError
        Raised by :meth:`inspect.Signature.bind` when the supplied arguments
        do not match the signature of ``func``.
    """
    if num_required_args < 0:
        raise ValueError("num_required_args must be non-negative")

    if not kwargs:
        # Nothing to relocate; hand the inputs back untouched.
        return args, kwargs

    # the udf should have this pattern:
    #   def udf(value1, ..., valueN, *args, **kwargs): ...
    signature = inspect.signature(func)
    placeholders = (_sentinel,) * num_required_args
    arguments = signature.bind(*placeholders, *args, **kwargs)
    arguments.apply_defaults()

    # Ref: https://peps.python.org/pep-0362/
    # Arguments which could be passed as part of either *args or **kwargs
    # will be included only in the BoundArguments.args attribute.
    bound_args = arguments.args
    bound_kwargs = arguments.kwargs

    # Positional arguments are bound in order, so the placeholders always
    # occupy the first ``num_required_args`` slots; drop them again.
    return bound_args[num_required_args:], bound_kwargs
54 changes: 49 additions & 5 deletions pandas/tests/apply/test_frame_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,60 @@ def test_apply(float_frame, engine, request):

@pytest.mark.parametrize("axis", [0, 1])
@pytest.mark.parametrize("raw", [True, False])
def test_apply_args(float_frame, axis, raw, engine, request):
if engine == "numba":
mark = pytest.mark.xfail(reason="numba engine doesn't support args")
request.node.add_marker(mark)
@pytest.mark.parametrize("nopython", [True, False])
def test_apply_args(float_frame, axis, raw, engine, nopython):
engine_kwargs = {"nopython": nopython}
result = float_frame.apply(
lambda x, y: x + y, axis, args=(1,), raw=raw, engine=engine
lambda x, y: x + y,
axis,
args=(1,),
raw=raw,
engine=engine,
engine_kwargs=engine_kwargs,
)
expected = float_frame + 1
tm.assert_frame_equal(result, expected)

# GH:58712
result = float_frame.apply(
lambda x, a, b: x + a + b,
args=(1,),
b=2,
raw=raw,
engine=engine,
engine_kwargs=engine_kwargs,
)
expected = float_frame + 3
tm.assert_frame_equal(result, expected)

if engine == "numba":
# keyword-only arguments are not supported in numba
with pytest.raises(
pd.errors.NumbaUtilError,
match="numba does not support keyword-only arguments",
):
float_frame.apply(
lambda x, a, *, b: x + a + b,
args=(1,),
b=2,
raw=raw,
engine=engine,
engine_kwargs=engine_kwargs,
)

with pytest.raises(
pd.errors.NumbaUtilError,
match="numba does not support keyword-only arguments",
):
float_frame.apply(
lambda *x, b: x[0] + x[1] + b,
args=(1,),
b=2,
raw=raw,
engine=engine,
engine_kwargs=engine_kwargs,
)


def test_apply_categorical_func():
# GH 9573
Expand Down
4 changes: 3 additions & 1 deletion pandas/tests/window/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,9 @@ def f(x):

@td.skip_if_no("numba")
def test_invalid_kwargs_nopython():
with pytest.raises(NumbaUtilError, match="numba does not support kwargs with"):
with pytest.raises(
NumbaUtilError, match="numba does not support keyword-only arguments"
):
Series(range(1)).rolling(1).apply(
lambda x: x, kwargs={"a": 1}, engine="numba", raw=True
)
Expand Down