Skip to content

ENH add *args support for numba apply #58767

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jun 11, 2024
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,7 @@ Other
- Bug in :class:`DataFrame` when passing a ``dict`` with a NA scalar and ``columns`` that would always return ``np.nan`` (:issue:`57205`)
- Bug in :func:`eval` where the names of the :class:`Series` were not preserved when using ``engine="numexpr"``. (:issue:`10239`)
- Bug in :func:`unique` on :class:`Index` not always returning :class:`Index` (:issue:`57043`)
- Bug in :meth:`DataFrame.apply` where passing ``engine="numba"`` ignored ``args`` passed to the applied function (:issue:`58712`)
- Bug in :meth:`DataFrame.eval` and :meth:`DataFrame.query` which caused an exception when using NumPy attributes via ``@`` notation, e.g., ``df.eval("@np.floor(a)")``. (:issue:`58041`)
- Bug in :meth:`DataFrame.eval` and :meth:`DataFrame.query` which did not allow to use ``tan`` function. (:issue:`55091`)
- Bug in :meth:`DataFrame.sort_index` when passing ``axis="columns"`` and ``ignore_index=True`` and ``ascending=False`` not returning a :class:`RangeIndex` columns (:issue:`57293`)
Expand Down
12 changes: 7 additions & 5 deletions pandas/core/_numba/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,19 @@

from pandas.compat._optional import import_optional_dependency

from pandas.core.util.numba_ import jit_user_function


@functools.cache
def generate_apply_looper(func, nopython=True, nogil=True, parallel=False):
if TYPE_CHECKING:
import numba
else:
numba = import_optional_dependency("numba")
nb_compat_func = numba.extending.register_jitable(func)
nb_compat_func = jit_user_function(func)

@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
def nb_looper(values, axis):
def nb_looper(values, axis, *args):
# Operate on the first row/col in order to get
# the output shape
if axis == 0:
Expand All @@ -33,7 +35,7 @@ def nb_looper(values, axis):
else:
first_elem = values[0]
dim0 = values.shape[0]
res0 = nb_compat_func(first_elem)
res0 = nb_compat_func(first_elem, *args)
# Use np.asarray to get shape for
# https://github.com/numba/numba/issues/4202#issuecomment-1185981507
buf_shape = (dim0,) + np.atleast_1d(np.asarray(res0)).shape
Expand All @@ -44,11 +46,11 @@ def nb_looper(values, axis):
if axis == 1:
buff[0] = res0
for i in numba.prange(1, values.shape[0]):
buff[i] = nb_compat_func(values[i])
buff[i] = nb_compat_func(values[i], *args)
else:
buff[:, 0] = res0
for j in numba.prange(1, values.shape[1]):
buff[:, j] = nb_compat_func(values[:, j])
buff[:, j] = nb_compat_func(values[:, j], *args)
return buff

return nb_looper
Expand Down
36 changes: 23 additions & 13 deletions pandas/core/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@
from pandas.core._numba.executor import generate_apply_looper
import pandas.core.common as com
from pandas.core.construction import ensure_wrapped_if_datetimelike
from pandas.core.util.numba_ import (
get_jit_arguments,
prepare_function_arguments,
)

if TYPE_CHECKING:
from collections.abc import (
Expand All @@ -70,7 +74,6 @@
from pandas.core.resample import Resampler
from pandas.core.window.rolling import BaseWindow


ResType = dict[int, Any]


Expand Down Expand Up @@ -972,17 +975,20 @@ def wrapper(*args, **kwargs):
return wrapper

if engine == "numba":
engine_kwargs = {} if engine_kwargs is None else engine_kwargs

args, kwargs = prepare_function_arguments(
self.func, # type: ignore[arg-type]
self.args,
self.kwargs,
)
# error: Argument 1 to "__call__" of "_lru_cache_wrapper" has
# incompatible type "Callable[..., Any] | str | list[Callable
# [..., Any] | str] | dict[Hashable,Callable[..., Any] | str |
# list[Callable[..., Any] | str]]"; expected "Hashable"
nb_looper = generate_apply_looper(
self.func, # type: ignore[arg-type]
**engine_kwargs,
**get_jit_arguments(engine_kwargs, kwargs),
)
result = nb_looper(self.values, self.axis)
result = nb_looper(self.values, self.axis, *args)
# If we made the result 2-D, squeeze it back to 1-D
result = np.squeeze(result)
else:
Expand Down Expand Up @@ -1123,21 +1129,23 @@ def generate_numba_apply_func(
# Currently the parallel argument doesn't get passed through here
# (it's disabled) since the dicts in numba aren't thread-safe.
@numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
def numba_func(values, col_names, df_index):
def numba_func(values, col_names, df_index, *args):
results = {}
for j in range(values.shape[1]):
# Create the series
ser = Series(
values[:, j], index=df_index, name=maybe_cast_str(col_names[j])
)
results[j] = jitted_udf(ser)
results[j] = jitted_udf(ser, *args)
return results

return numba_func

def apply_with_numba(self) -> dict[int, Any]:
func = cast(Callable, self.func)
args, kwargs = prepare_function_arguments(func, self.args, self.kwargs)
nb_func = self.generate_numba_apply_func(
cast(Callable, self.func), **self.engine_kwargs
func, **get_jit_arguments(self.engine_kwargs, kwargs)
)
from pandas.core._numba.extensions import set_numba_data

Expand All @@ -1152,7 +1160,7 @@ def apply_with_numba(self) -> dict[int, Any]:
# Convert from numba dict to regular dict
# Our isinstance checks in the df constructor don't pass for numbas typed dict
with set_numba_data(index) as index, set_numba_data(columns) as columns:
res = dict(nb_func(self.values, columns, index))
res = dict(nb_func(self.values, columns, index, *args))
return res

@property
Expand Down Expand Up @@ -1260,7 +1268,7 @@ def generate_numba_apply_func(
jitted_udf = numba.extending.register_jitable(func)

@numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
def numba_func(values, col_names_index, index):
def numba_func(values, col_names_index, index, *args):
results = {}
# Currently the parallel argument doesn't get passed through here
# (it's disabled) since the dicts in numba aren't thread-safe.
Expand All @@ -1272,15 +1280,17 @@ def numba_func(values, col_names_index, index):
index=col_names_index,
name=maybe_cast_str(index[i]),
)
results[i] = jitted_udf(ser)
results[i] = jitted_udf(ser, *args)

return results

return numba_func

def apply_with_numba(self) -> dict[int, Any]:
func = cast(Callable, self.func)
args, kwargs = prepare_function_arguments(func, self.args, self.kwargs)
nb_func = self.generate_numba_apply_func(
cast(Callable, self.func), **self.engine_kwargs
func, **get_jit_arguments(self.engine_kwargs, kwargs)
)

from pandas.core._numba.extensions import set_numba_data
Expand All @@ -1291,7 +1301,7 @@ def apply_with_numba(self) -> dict[int, Any]:
set_numba_data(self.obj.index) as index,
set_numba_data(self.columns) as columns,
):
res = dict(nb_func(self.values, columns, index))
res = dict(nb_func(self.values, columns, index, *args))

return res

Expand Down
45 changes: 45 additions & 0 deletions pandas/core/util/numba_.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import inspect
import types
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -97,3 +98,47 @@ def jit_user_function(func: Callable) -> Callable:
numba_func = numba.extending.register_jitable(func)

return numba_func


_sentinel = object()


def prepare_function_arguments(
func: Callable, args: tuple, kwargs: dict
) -> tuple[tuple, dict]:
"""
Prepare arguments for jitted function. As numba functions do not support kwargs,
we try to move kwargs into args if possible.

Parameters
----------
func : function
user defined function
args : tuple
user input positional arguments
kwargs : dict
user input keyword arguments

Returns
-------
tuple[tuple, dict]
args, kwargs

"""
if not kwargs:
return args, kwargs

# the udf should have this pattern: def udf(value, *args, **kwargs):...
signature = inspect.signature(func)
arguments = signature.bind(_sentinel, *args, **kwargs)
arguments.apply_defaults()
# Ref: https://peps.python.org/pep-0362/
# Arguments which could be passed as part of either *args or **kwargs
# will be included only in the BoundArguments.args attribute.
args = arguments.args
kwargs = arguments.kwargs

assert args[0] is _sentinel
args = args[1:]

return args, kwargs
30 changes: 26 additions & 4 deletions pandas/tests/apply/test_frame_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,38 @@ def test_apply(float_frame, engine, request):

@pytest.mark.parametrize("axis", [0, 1])
@pytest.mark.parametrize("raw", [True, False])
def test_apply_args(float_frame, axis, raw, engine, request):
if engine == "numba":
mark = pytest.mark.xfail(reason="numba engine doesn't support args")
request.node.add_marker(mark)
def test_apply_args(float_frame, axis, raw, engine):
result = float_frame.apply(
lambda x, y: x + y, axis, args=(1,), raw=raw, engine=engine
)
expected = float_frame + 1
tm.assert_frame_equal(result, expected)

# GH:58712
result = float_frame.apply(
lambda x, a, b: x + a + b, args=(1,), b=2, engine=engine, raw=raw
)
expected = float_frame + 3
tm.assert_frame_equal(result, expected)

if engine == "numba":
# keyword-only arguments are not supported in numba
with pytest.raises(
pd.errors.NumbaUtilError,
match="numba does not support kwargs with nopython=True",
):
float_frame.apply(
lambda x, a, *, b: x + a + b, args=(1,), b=2, engine=engine, raw=raw
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work if you pass nopython=False in the engine_kwargs?

Copy link
Contributor Author

@auderson auderson May 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pandas test raises missing a required argument error when nopython=False.

FAILED
pandas/tests/apply/test_frame_apply.py:63 (test_apply_args[numba-False-True-0])
float_frame =                A         B         C         D
foo_0   0.189053 -0.522748 -0.413064 -2.441467
foo_1   1.799707  1.1441...47218  0.968478 -0.955145
foo_28  0.354112 -1.968397  0.899274 -0.158248
foo_29 -0.967681  1.678419  0.765355  0.045808
axis = 0, raw = True, engine = 'numba', nopython = False

    @pytest.mark.parametrize("axis", [0, 1])
    @pytest.mark.parametrize("raw", [True, False])
    @pytest.mark.parametrize("nopython", [True, False])
    def test_apply_args(float_frame, axis, raw, engine, nopython):
        engine_kwargs = {"nopython": nopython}
        result = float_frame.apply(
            lambda x, y: x + y, axis, args=(1,), raw=raw, engine=engine, engine_kwargs=engine_kwargs
        )
        expected = float_frame + 1
        tm.assert_frame_equal(result, expected)
    
        # GH:58712
        result = float_frame.apply(
            lambda x, a, b: x + a + b, args=(1,), b=2, raw=raw, engine=engine, engine_kwargs=engine_kwargs
        )
        expected = float_frame + 3
        tm.assert_frame_equal(result, expected)
    
        if engine == "numba":
            # keyword-only arguments are not supported in numba
            with pytest.raises(
                pd.errors.NumbaUtilError,
                match="numba does not support kwargs with nopython=True",
            ):
>               float_frame.apply(
                    lambda x, a, *, b: x + a + b, args=(1,), b=2, raw=raw, engine=engine, engine_kwargs=engine_kwargs
                )

test_frame_apply.py:88: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../../core/frame.py:10353: in apply
    return op.apply().__finalize__(self, method="apply")
../../core/apply.py:886: in apply
    return self.apply_raw(engine=self.engine, engine_kwargs=self.engine_kwargs)
../../core/apply.py:991: in apply_raw
    result = nb_looper(self.values, self.axis, *args)
/home/auderson/miniconda3/envs/pandas-dev/lib/python3.10/site-packages/numba/core/dispatcher.py:468: in _compile_for_args
    error_rewrite(e, 'typing')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

e = TypingError('Failed in nopython mode pipeline (step: nopython frontend)\n\x1b[1m\x1b[1m\x1b[1mNo implementation of fun...0 = values.shape[0]\n\x1b[1m        res0 = nb_compat_func(first_elem, *args)\n\x1b[0m        \x1b[1m^\x1b[0m\x1b[0m\n')
issue_type = 'typing'

    def error_rewrite(e, issue_type):
        """
        Rewrite and raise Exception `e` with help supplied based on the
        specified issue_type.
        """
        if config.SHOW_HELP:
            help_msg = errors.error_extras[issue_type]
            e.patch_message('\n'.join((str(e).rstrip(), help_msg)))
        if config.FULL_TRACEBACKS:
            raise e
        else:
>           raise e.with_traceback(None)
E           numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
E           No implementation of function Function(<function test_apply_args.<locals>.<lambda> at 0x7f0d2f6081f0>) found for signature:
E            
E            >>> <lambda>(readonly array(float64, 1d, A), int64)
E            
E           There are 2 candidate implementations:
E               - Of which 2 did not match due to:
E               Overload in function 'register_jitable.<locals>.wrap.<locals>.ov_wrap': File: numba/core/extending.py: Line 161.
E                 With argument(s): '(readonly array(float64, 1d, A), int64)':
E                Rejected as the implementation raised a specific error:
E                  TypeError: missing a required argument: 'a'
E             raised from /home/auderson/miniconda3/envs/pandas-dev/lib/python3.10/inspect.py:3101
E           
E           During: resolving callee type: Function(<function test_apply_args.<locals>.<lambda> at 0x7f0d2f6081f0>)
E           During: typing of call at /mnt/c/Users/auderson/Documents/Works/OpenSourcePackages/pandas/pandas/core/_numba/executor.py (38)
E           
E           
E           File "../../core/_numba/executor.py", line 38:
E               def nb_looper(values, axis, *args):
E                   <source elided>
E                       dim0 = values.shape[0]
E                   res0 = nb_compat_func(first_elem, *args)
E                   ^

/home/auderson/miniconda3/envs/pandas-dev/lib/python3.10/site-packages/numba/core/dispatcher.py:409: TypingError

An equivalent reproducer:

from numba import jit

@jit(nopython=False)
def foo(*args, kwarg=None):
    print(args)
    print(kwarg)

@jit(nopython=False)
def bar(a, *args):
    foo(a, *args, kwarg='foobar')

bar(1, 2, 3)
UnsupportedError: Failed in object mode pipeline (step: analyzing bytecode)

CALL_FUNCTION_EX with **kwargs not supported.
If you are not using **kwargs this may indicate that
you have a large number of kwargs and are using inlined control
flow. You can resolve this issue by moving the control flow out of
the function call. For example, if you have

    f(a=1 if flag else 0, ...)

Replace that with:

    a_val = 1 if flag else 0
    f(a=a_val, ...)

Copy link
Contributor Author

@auderson auderson May 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe here need to be modified:

if kwargs and nopython:
raise NumbaUtilError(
"numba does not support kwargs with nopython=True: "
"https://github.com/numba/numba/issues/2916"
)

Regardless nopython is True or False, the "numba does not support kwargs" should be raised:

 if kwargs: 
     raise NumbaUtilError( 
         "numba does not support keyword-only arguments" 
         "https://github.com/numba/numba/issues/2916" 
     ) 

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, updating the error message is probably the right move here.

)

with pytest.raises(
pd.errors.NumbaUtilError,
match="numba does not support kwargs with nopython=True",
):
float_frame.apply(
lambda *x, b: x[0] + x[1] + b, args=(1,), b=2, engine=engine, raw=raw
)


def test_apply_categorical_func():
# GH 9573
Expand Down