From 3a5fc90e8e2e7fa590d081f076b4df1676d6dd06 Mon Sep 17 00:00:00 2001 From: auderson Date: Sat, 18 May 2024 12:30:35 +0800 Subject: [PATCH 01/26] add *args for raw numba apply --- pandas/core/_numba/executor.py | 8 ++++---- pandas/core/apply.py | 7 +++---- pandas/tests/apply/test_frame_apply.py | 16 ++++++++++++++++ 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/pandas/core/_numba/executor.py b/pandas/core/_numba/executor.py index 0a26acb7df60a..9935ed5df5afa 100644 --- a/pandas/core/_numba/executor.py +++ b/pandas/core/_numba/executor.py @@ -24,7 +24,7 @@ def generate_apply_looper(func, nopython=True, nogil=True, parallel=False): nb_compat_func = numba.extending.register_jitable(func) @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) - def nb_looper(values, axis): + def nb_looper(values, axis, *args): # Operate on the first row/col in order to get # the output shape if axis == 0: @@ -33,7 +33,7 @@ def nb_looper(values, axis): else: first_elem = values[0] dim0 = values.shape[0] - res0 = nb_compat_func(first_elem) + res0 = nb_compat_func(first_elem, *args) # Use np.asarray to get shape for # https://github.com/numba/numba/issues/4202#issuecomment-1185981507 buf_shape = (dim0,) + np.atleast_1d(np.asarray(res0)).shape @@ -44,11 +44,11 @@ def nb_looper(values, axis): if axis == 1: buff[0] = res0 for i in numba.prange(1, values.shape[0]): - buff[i] = nb_compat_func(values[i]) + buff[i] = nb_compat_func(values[i], *args) else: buff[:, 0] = res0 for j in numba.prange(1, values.shape[1]): - buff[:, j] = nb_compat_func(values[:, j]) + buff[:, j] = nb_compat_func(values[:, j], *args) return buff return nb_looper diff --git a/pandas/core/apply.py b/pandas/core/apply.py index 32e8aea7ea8ab..24da3850152b0 100644 --- a/pandas/core/apply.py +++ b/pandas/core/apply.py @@ -51,6 +51,7 @@ from pandas.core._numba.executor import generate_apply_looper import pandas.core.common as com from pandas.core.construction import ensure_wrapped_if_datetimelike +from pandas.core.util.numba_ import get_jit_arguments if TYPE_CHECKING: from collections.abc import ( @@ -972,17 +973,15 @@ def wrapper(*args, **kwargs): return wrapper if engine == "numba": - engine_kwargs = {} if engine_kwargs is None else engine_kwargs - # error: Argument 1 to "__call__" of "_lru_cache_wrapper" has # incompatible type "Callable[..., Any] | str | list[Callable # [..., Any] | str] | dict[Hashable,Callable[..., Any] | str | # list[Callable[..., Any] | str]]"; expected "Hashable" nb_looper = generate_apply_looper( self.func, # type: ignore[arg-type] - **engine_kwargs, + **get_jit_arguments(engine_kwargs, self.kwargs), ) - result = nb_looper(self.values, self.axis) + result = nb_looper(self.values, self.axis, *self.args) # If we made the result 2-D, squeeze it back to 1-D result = np.squeeze(result) else: diff --git a/pandas/tests/apply/test_frame_apply.py b/pandas/tests/apply/test_frame_apply.py index cbc68265a1cc1..27f0779ebef5f 100644 --- a/pandas/tests/apply/test_frame_apply.py +++ b/pandas/tests/apply/test_frame_apply.py @@ -1718,3 +1718,19 @@ def test_agg_dist_like_and_nonunique_columns(): result = df.agg({"A": "count"}) expected = df["A"].count() tm.assert_series_equal(result, expected) + + +def test_numba_raw_apply_with_args(): + # GH:58712 + df = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + result = df.apply(lambda x, a, b: x + a + b, args=(1, 2), engine="numba", raw=True) + # note: + # result is always float dtype, see core._numba.executor.py:generate_apply_looper + expected = df + 3.0 + tm.assert_frame_equal(result, expected) + + with pytest.raises( + pd.errors.NumbaUtilError, + match="numba does not support kwargs with nopython=True", + ): + df.apply(lambda x, a, b: x + a + b, args=(1,), b=2, engine="numba", raw=True) From 3165efe07f265af5a5f38b0f6e7808167e7cffd9 Mon Sep 17 00:00:00 2001 From: auderson Date: Sat, 18 May 2024 13:52:55 +0800 Subject: [PATCH 02/26] add whatsnew --- doc/source/whatsnew/v3.0.0.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index 731406394ed46..7b39104bb2c66 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -498,6 +498,7 @@ Other - Bug in :class:`DataFrame` when passing a ``dict`` with a NA scalar and ``columns`` that would always return ``np.nan`` (:issue:`57205`) - Bug in :func:`eval` where the names of the :class:`Series` were not preserved when using ``engine="numexpr"``. (:issue:`10239`) - Bug in :func:`unique` on :class:`Index` not always returning :class:`Index` (:issue:`57043`) +- Bug in :meth:`DataFrame.apply` where passing ``raw=True`` and ``engine="numba"`` ignored ``args`` passed to the applied function (:issue:`58712`) - Bug in :meth:`DataFrame.eval` and :meth:`DataFrame.query` which caused an exception when using NumPy attributes via ``@`` notation, e.g., ``df.eval("@np.floor(a)")``. (:issue:`58041`) - Bug in :meth:`DataFrame.eval` and :meth:`DataFrame.query` which did not allow to use ``tan`` function. (:issue:`55091`) - Bug in :meth:`DataFrame.sort_index` when passing ``axis="columns"`` and ``ignore_index=True`` and ``ascending=False`` not returning a :class:`RangeIndex` columns (:issue:`57293`) From de8957467eb3ec077a320f294f36fff9d2359a34 Mon Sep 17 00:00:00 2001 From: auderson Date: Sat, 18 May 2024 14:16:16 +0800 Subject: [PATCH 03/26] fix test_case --- pandas/tests/apply/test_frame_apply.py | 27 +++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/pandas/tests/apply/test_frame_apply.py b/pandas/tests/apply/test_frame_apply.py index 27f0779ebef5f..ca921d98088d2 100644 --- a/pandas/tests/apply/test_frame_apply.py +++ b/pandas/tests/apply/test_frame_apply.py @@ -1720,17 +1720,18 @@ def test_agg_dist_like_and_nonunique_columns(): tm.assert_series_equal(result, expected) -def test_numba_raw_apply_with_args(): - # GH:58712 - df = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) - result = df.apply(lambda x, a, b: x + a + b, args=(1, 2), engine="numba", raw=True) - # note: - # result is always float dtype, see core._numba.executor.py:generate_apply_looper - expected = df + 3.0 - tm.assert_frame_equal(result, expected) +def test_numba_raw_apply_with_args(engine): + if engine == "numba": + # GH:58712 + df = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + result = df.apply(lambda x, a, b: x + a + b, args=(1, 2), engine=engine, raw=True) + # note: + # result is always float dtype, see core._numba.executor.py:generate_apply_looper + expected = df + 3.0 + tm.assert_frame_equal(result, expected) - with pytest.raises( - pd.errors.NumbaUtilError, - match="numba does not support kwargs with nopython=True", - ): - df.apply(lambda x, a, b: x + a + b, args=(1,), b=2, engine="numba", raw=True) + with pytest.raises( + pd.errors.NumbaUtilError, + match="numba does not support kwargs with nopython=True", + ): + df.apply(lambda x, a, b: x + a + b, args=(1,), b=2, engine=engine, raw=True) From 3f13b303114b7a1fe2671b58327e1cf0a29fe2a3 Mon Sep 17 00:00:00 2001 From: auderson Date: Sat, 18 May 2024 14:19:50 +0800 Subject: [PATCH 04/26] fix pre-commit --- pandas/tests/apply/test_frame_apply.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pandas/tests/apply/test_frame_apply.py b/pandas/tests/apply/test_frame_apply.py index ca921d98088d2..de2d15c3760de 100644 --- a/pandas/tests/apply/test_frame_apply.py +++ b/pandas/tests/apply/test_frame_apply.py @@ -1724,9 +1724,11 @@ def test_numba_raw_apply_with_args(engine): if engine == "numba": # GH:58712 df = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) - result = df.apply(lambda x, a, b: x + a + b, args=(1, 2), engine=engine, raw=True) - # note: - # result is always float dtype, see core._numba.executor.py:generate_apply_looper + result = df.apply( + lambda x, a, b: x + a + b, args=(1, 2), engine=engine, raw=True + ) + # note: result is always float dtype, + # see core._numba.executor.py:generate_apply_looper expected = df + 3.0 tm.assert_frame_equal(result, expected) From c0268459694c02662f1e263625e0197212b486f5 Mon Sep 17 00:00:00 2001 From: auderson Date: Sat, 18 May 2024 22:11:57 +0800 Subject: [PATCH 05/26] fix test case --- pandas/tests/apply/test_frame_apply.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/tests/apply/test_frame_apply.py b/pandas/tests/apply/test_frame_apply.py index de2d15c3760de..3ff156cac3e35 100644 --- a/pandas/tests/apply/test_frame_apply.py +++ b/pandas/tests/apply/test_frame_apply.py @@ -64,7 +64,7 @@ def test_apply(float_frame, engine, request): @pytest.mark.parametrize("axis", [0, 1]) @pytest.mark.parametrize("raw", [True, False]) def test_apply_args(float_frame, axis, raw, engine, request): - if engine == "numba": + if engine == "numba" and raw is False: mark = pytest.mark.xfail(reason="numba engine doesn't support args") request.node.add_marker(mark) result = float_frame.apply( From 96581a316ad9e3d6acb62e724af9cb2c273d9fb2 Mon Sep 17 00:00:00 2001 From: auderson Date: Sun, 19 May 2024 11:53:23 +0800 Subject: [PATCH 06/26] add *args for raw=False as well; merge tests together --- doc/source/whatsnew/v3.0.0.rst | 2 +- pandas/core/apply.py | 18 ++++++++------ pandas/tests/apply/test_frame_apply.py | 34 +++++++++----------------- 3 files changed, 22 insertions(+), 32 deletions(-) diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index 7b39104bb2c66..7cb6af877be42 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -498,7 +498,7 @@ Other - Bug in :class:`DataFrame` when passing a ``dict`` with a NA scalar and ``columns`` that would always return ``np.nan`` (:issue:`57205`) - Bug in :func:`eval` where the names of the :class:`Series` were not preserved when using ``engine="numexpr"``. (:issue:`10239`) - Bug in :func:`unique` on :class:`Index` not always returning :class:`Index` (:issue:`57043`) -- Bug in :meth:`DataFrame.apply` where passing ``raw=True`` and ``engine="numba"`` ignored ``args`` passed to the applied function (:issue:`58712`) +- Bug in :meth:`DataFrame.apply` where passing ``engine="numba"`` ignored ``args`` passed to the applied function (:issue:`58712`) - Bug in :meth:`DataFrame.eval` and :meth:`DataFrame.query` which caused an exception when using NumPy attributes via ``@`` notation, e.g., ``df.eval("@np.floor(a)")``. (:issue:`58041`) - Bug in :meth:`DataFrame.eval` and :meth:`DataFrame.query` which did not allow to use ``tan`` function. (:issue:`55091`) - Bug in :meth:`DataFrame.sort_index` when passing ``axis="columns"`` and ``ignore_index=True`` and ``ascending=False`` not returning a :class:`RangeIndex` columns (:issue:`57293`) diff --git a/pandas/core/apply.py b/pandas/core/apply.py index 24da3850152b0..bfc268047b8da 100644 --- a/pandas/core/apply.py +++ b/pandas/core/apply.py @@ -1122,21 +1122,22 @@ def generate_numba_apply_func( # Currently the parallel argument doesn't get passed through here # (it's disabled) since the dicts in numba aren't thread-safe. @numba.jit(nogil=nogil, nopython=nopython, parallel=parallel) - def numba_func(values, col_names, df_index): + def numba_func(values, col_names, df_index, *args): results = {} for j in range(values.shape[1]): # Create the series ser = Series( values[:, j], index=df_index, name=maybe_cast_str(col_names[j]) ) - results[j] = jitted_udf(ser) + results[j] = jitted_udf(ser, *args) return results return numba_func def apply_with_numba(self) -> dict[int, Any]: nb_func = self.generate_numba_apply_func( - cast(Callable, self.func), **self.engine_kwargs + cast(Callable, self.func), + **get_jit_arguments(self.engine_kwargs, self.kwargs), ) from pandas.core._numba.extensions import set_numba_data @@ -1151,7 +1152,7 @@ def apply_with_numba(self) -> dict[int, Any]: # Convert from numba dict to regular dict # Our isinstance checks in the df constructor don't pass for numbas typed dict with set_numba_data(index) as index, set_numba_data(columns) as columns: - res = dict(nb_func(self.values, columns, index)) + res = dict(nb_func(self.values, columns, index, *self.args)) return res @property @@ -1259,7 +1260,7 @@ def generate_numba_apply_func( jitted_udf = numba.extending.register_jitable(func) @numba.jit(nogil=nogil, nopython=nopython, parallel=parallel) - def numba_func(values, col_names_index, index): + def numba_func(values, col_names_index, index, *args): results = {} # Currently the parallel argument doesn't get passed through here # (it's disabled) since the dicts in numba aren't thread-safe. @@ -1271,7 +1272,7 @@ def numba_func(values, col_names_index, index): index=col_names_index, name=maybe_cast_str(index[i]), ) - results[i] = jitted_udf(ser) + results[i] = jitted_udf(ser, *args) return results @@ -1279,7 +1280,8 @@ def numba_func(values, col_names_index, index): def apply_with_numba(self) -> dict[int, Any]: nb_func = self.generate_numba_apply_func( - cast(Callable, self.func), **self.engine_kwargs + cast(Callable, self.func), + **get_jit_arguments(self.engine_kwargs, self.kwargs), ) from pandas.core._numba.extensions import set_numba_data @@ -1290,7 +1292,7 @@ def apply_with_numba(self) -> dict[int, Any]: set_numba_data(self.obj.index) as index, set_numba_data(self.columns) as columns, ): - res = dict(nb_func(self.values, columns, index)) + res = dict(nb_func(self.values, columns, index, *self.args)) return res diff --git a/pandas/tests/apply/test_frame_apply.py b/pandas/tests/apply/test_frame_apply.py index 3ff156cac3e35..32a1ed596380a 100644 --- a/pandas/tests/apply/test_frame_apply.py +++ b/pandas/tests/apply/test_frame_apply.py @@ -63,16 +63,23 @@ def test_apply(float_frame, engine, request): @pytest.mark.parametrize("axis", [0, 1]) @pytest.mark.parametrize("raw", [True, False]) -def test_apply_args(float_frame, axis, raw, engine, request): - if engine == "numba" and raw is False: - mark = pytest.mark.xfail(reason="numba engine doesn't support args") - request.node.add_marker(mark) +def test_apply_args(float_frame, axis, raw, engine): + # GH:58712 result = float_frame.apply( lambda x, y: x + y, axis, args=(1,), raw=raw, engine=engine ) expected = float_frame + 1 tm.assert_frame_equal(result, expected) + if engine == "numba": + with pytest.raises( + pd.errors.NumbaUtilError, + match="numba does not support kwargs with nopython=True", + ): + float_frame.apply( + lambda x, a, b: x + a + b, args=(1,), b=2, engine=engine, raw=raw + ) + def test_apply_categorical_func(): # GH 9573 @@ -1718,22 +1725,3 @@ def test_agg_dist_like_and_nonunique_columns(): result = df.agg({"A": "count"}) expected = df["A"].count() tm.assert_series_equal(result, expected) - - -def test_numba_raw_apply_with_args(engine): - if engine == "numba": - # GH:58712 - df = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) - result = df.apply( - lambda x, a, b: x + a + b, args=(1, 2), engine=engine, raw=True - ) - # note: result is always float dtype, - # see core._numba.executor.py:generate_apply_looper - expected = df + 3.0 - tm.assert_frame_equal(result, expected) - - with pytest.raises( - pd.errors.NumbaUtilError, - match="numba does not support kwargs with nopython=True", - ): - df.apply(lambda x, a, b: x + a + b, args=(1,), b=2, engine=engine, raw=True) From 2aae933f96a9e11c10f8a98f71f49d9b565fd9be Mon Sep 17 00:00:00 2001 From: auderson Date: Tue, 21 May 2024 09:21:25 +0800 Subject: [PATCH 07/26] add prepare_function_arguments --- pandas/core/_numba/executor.py | 4 ++- pandas/core/apply.py | 20 ++++++++---- pandas/core/util/numba_.py | 45 ++++++++++++++++++++++++++ pandas/tests/apply/test_frame_apply.py | 19 +++++++++-- 4 files changed, 78 insertions(+), 10 deletions(-) diff --git a/pandas/core/_numba/executor.py b/pandas/core/_numba/executor.py index 9935ed5df5afa..82fd4e34ac67b 100644 --- a/pandas/core/_numba/executor.py +++ b/pandas/core/_numba/executor.py @@ -14,6 +14,8 @@ from pandas.compat._optional import import_optional_dependency +from pandas.core.util.numba_ import jit_user_function + @functools.cache def generate_apply_looper(func, nopython=True, nogil=True, parallel=False): @@ -21,7 +23,7 @@ def generate_apply_looper(func, nopython=True, nogil=True, parallel=False): import numba else: numba = import_optional_dependency("numba") - nb_compat_func = numba.extending.register_jitable(func) + nb_compat_func = jit_user_function(func) @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) def nb_looper(values, axis, *args): diff --git a/pandas/core/apply.py b/pandas/core/apply.py index bfc268047b8da..7137fc8d71c71 100644 --- a/pandas/core/apply.py +++ b/pandas/core/apply.py @@ -51,7 +51,10 @@ from pandas.core._numba.executor import generate_apply_looper import pandas.core.common as com from pandas.core.construction import ensure_wrapped_if_datetimelike -from pandas.core.util.numba_ import get_jit_arguments +from pandas.core.util.numba_ import ( + get_jit_arguments, + prepare_function_arguments, +) if TYPE_CHECKING: from collections.abc import ( @@ -973,15 +976,16 @@ def wrapper(*args, **kwargs): return wrapper if engine == "numba": + args, kwargs = prepare_function_arguments(self.func, self.args, self.kwargs) # error: Argument 1 to "__call__" of "_lru_cache_wrapper" has # incompatible type "Callable[..., Any] | str | list[Callable # [..., Any] | str] | dict[Hashable,Callable[..., Any] | str | # list[Callable[..., Any] | str]]"; expected "Hashable" nb_looper = generate_apply_looper( self.func, # type: ignore[arg-type] - **get_jit_arguments(engine_kwargs, self.kwargs), + **get_jit_arguments(engine_kwargs, kwargs), ) - result = nb_looper(self.values, self.axis, *self.args) + result = nb_looper(self.values, self.axis, *args) # If we made the result 2-D, squeeze it back to 1-D result = np.squeeze(result) else: @@ -1135,9 +1139,10 @@ def numba_func(values, col_names, df_index, *args): return numba_func def apply_with_numba(self) -> dict[int, Any]: + args, kwargs = prepare_function_arguments(self.func, self.args, self.kwargs) nb_func = self.generate_numba_apply_func( cast(Callable, self.func), - **get_jit_arguments(self.engine_kwargs, self.kwargs), + **get_jit_arguments(self.engine_kwargs, kwargs), ) from pandas.core._numba.extensions import set_numba_data @@ -1152,7 +1157,7 @@ def apply_with_numba(self) -> dict[int, Any]: # Convert from numba dict to regular dict # Our isinstance checks in the df constructor don't pass for numbas typed dict with set_numba_data(index) as index, set_numba_data(columns) as columns: - res = dict(nb_func(self.values, columns, index, *self.args)) + res = dict(nb_func(self.values, columns, index, *args)) return res @property @@ -1279,9 +1284,10 @@ def numba_func(values, col_names_index, index, *args): return numba_func def apply_with_numba(self) -> dict[int, Any]: + args, kwargs = prepare_function_arguments(self.func, self.args, self.kwargs) nb_func = self.generate_numba_apply_func( cast(Callable, self.func), - **get_jit_arguments(self.engine_kwargs, self.kwargs), + **get_jit_arguments(self.engine_kwargs, kwargs), ) from pandas.core._numba.extensions import set_numba_data @@ -1292,7 +1298,7 @@ def apply_with_numba(self) -> dict[int, Any]: set_numba_data(self.obj.index) as index, set_numba_data(self.columns) as columns, ): - res = dict(nb_func(self.values, columns, index, *self.args)) + res = dict(nb_func(self.values, columns, index, *args)) return res diff --git a/pandas/core/util/numba_.py b/pandas/core/util/numba_.py index a6079785e7475..da02c4b5ccf34 100644 --- a/pandas/core/util/numba_.py +++ b/pandas/core/util/numba_.py @@ -2,6 +2,7 @@ from __future__ import annotations +import inspect import types from typing import ( TYPE_CHECKING, @@ -97,3 +98,47 @@ def jit_user_function(func: Callable) -> Callable: numba_func = numba.extending.register_jitable(func) return numba_func + + +_sentinel = object() + + +def prepare_function_arguments( + func: Callable, args: tuple, kwargs: dict +) -> tuple[tuple, dict]: + """ + Prepare arguments for jitted function. As numba functions do not support kwargs, + we try to move kwargs into args if possible. + + Parameters + ---------- + func : function + user defined function + args : tuple + user input positional arguments + kwargs : dict + user input keyword arguments + + Returns + ------- + tuple[tuple, dict] + args, kwargs + + """ + if not kwargs: + return args, kwargs + + # the udf should have this pattern: def udf(value, *args, **kwargs):... + signature = inspect.signature(func) + arguments = signature.bind(_sentinel, *args, **kwargs) + arguments.apply_defaults() + # Ref: https://peps.python.org/pep-0362/ + # Arguments which could be passed as part of either *args or **kwargs + # will be included only in the BoundArguments.args attribute. + args = arguments.args + kwargs = arguments.kwargs + + assert args[0] is _sentinel + args = args[1:] + + return args, kwargs diff --git a/pandas/tests/apply/test_frame_apply.py b/pandas/tests/apply/test_frame_apply.py index 32a1ed596380a..003c3594a9441 100644 --- a/pandas/tests/apply/test_frame_apply.py +++ b/pandas/tests/apply/test_frame_apply.py @@ -64,20 +64,35 @@ def test_apply(float_frame, engine, request): @pytest.mark.parametrize("axis", [0, 1]) @pytest.mark.parametrize("raw", [True, False]) def test_apply_args(float_frame, axis, raw, engine): - # GH:58712 result = float_frame.apply( lambda x, y: x + y, axis, args=(1,), raw=raw, engine=engine ) expected = float_frame + 1 tm.assert_frame_equal(result, expected) + # GH:58712 + result = float_frame.apply( + lambda x, a, b: x + a + b, args=(1,), b=2, engine=engine, raw=raw + ) + expected = float_frame + 3 + tm.assert_frame_equal(result, expected) + if engine == "numba": + # keyword-only arguments are not supported in numba + with pytest.raises( + pd.errors.NumbaUtilError, + match="numba does not support kwargs with nopython=True", + ): + float_frame.apply( + lambda x, a, *, b: x + a + b, args=(1,), b=2, engine=engine, raw=raw + ) + with pytest.raises( pd.errors.NumbaUtilError, match="numba does not support kwargs with nopython=True", ): float_frame.apply( - lambda x, a, b: x + a + b, args=(1,), b=2, engine=engine, raw=raw + lambda *x, b: x[0] + x[1] + b, args=(1,), b=2, engine=engine, raw=raw ) From 1a6f1aeeba772e4fe6382df2306e4b5d93bf21a3 Mon Sep 17 00:00:00 2001 From: auderson Date: Tue, 21 May 2024 10:45:39 +0800 Subject: [PATCH 08/26] fix mypy --- pandas/core/apply.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/pandas/core/apply.py b/pandas/core/apply.py index 7137fc8d71c71..4500d4e8a50be 100644 --- a/pandas/core/apply.py +++ b/pandas/core/apply.py @@ -74,7 +74,6 @@ from pandas.core.resample import Resampler from pandas.core.window.rolling import BaseWindow - ResType = dict[int, Any] @@ -976,7 +975,11 @@ def wrapper(*args, **kwargs): return wrapper if engine == "numba": - args, kwargs = prepare_function_arguments(self.func, self.args, self.kwargs) + args, kwargs = prepare_function_arguments( + self.func, # type: ignore[arg-type] + self.args, + self.kwargs, + ) # error: Argument 1 to "__call__" of "_lru_cache_wrapper" has # incompatible type "Callable[..., Any] | str | list[Callable # [..., Any] | str] | dict[Hashable,Callable[..., Any] | str | @@ -1139,10 +1142,10 @@ def numba_func(values, col_names, df_index, *args): return numba_func def apply_with_numba(self) -> dict[int, Any]: - args, kwargs = prepare_function_arguments(self.func, self.args, self.kwargs) + func = cast(Callable, self.func) + args, kwargs = prepare_function_arguments(func, self.args, self.kwargs) nb_func = self.generate_numba_apply_func( - cast(Callable, self.func), - **get_jit_arguments(self.engine_kwargs, kwargs), + func, **get_jit_arguments(self.engine_kwargs, kwargs) ) from pandas.core._numba.extensions import set_numba_data @@ -1284,10 +1287,10 @@ def numba_func(values, col_names_index, index, *args): return numba_func def apply_with_numba(self) -> dict[int, Any]: - args, kwargs = prepare_function_arguments(self.func, self.args, self.kwargs) + func = cast(Callable, self.func) + args, kwargs = prepare_function_arguments(func, self.args, self.kwargs) nb_func = self.generate_numba_apply_func( - cast(Callable, self.func), - **get_jit_arguments(self.engine_kwargs, kwargs), + func, **get_jit_arguments(self.engine_kwargs, kwargs) ) from pandas.core._numba.extensions import set_numba_data From 8925b3a307d0d5667a10b5b10b4f205afca11896 Mon Sep 17 00:00:00 2001 From: auderson Date: Sun, 26 May 2024 10:31:21 +0800 Subject: [PATCH 09/26] update get_jit_arguments --- pandas/core/util/numba_.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/pandas/core/util/numba_.py b/pandas/core/util/numba_.py index da02c4b5ccf34..d93984d210cb4 100644 --- a/pandas/core/util/numba_.py +++ b/pandas/core/util/numba_.py @@ -55,10 +55,15 @@ def get_jit_arguments( engine_kwargs = {} nopython = engine_kwargs.get("nopython", True) - if kwargs and nopython: + if kwargs: + # Note: in case numba supports keyword-only arguments in + # a future version, we should remove this check. But this + # seems unlikely to happen soon. + raise NumbaUtilError( - "numba does not support kwargs with nopython=True: " - "https://github.com/numba/numba/issues/2916" + "numba does not support keyword-only arguments" + "https://github.com/numba/numba/issues/2916, " + "https://github.com/numba/numba/issues/6846" ) nogil = engine_kwargs.get("nogil", False) parallel = engine_kwargs.get("parallel", False) From 085ae73f68a4c5e4597f542a70bd132af39103f1 Mon Sep 17 00:00:00 2001 From: auderson Date: Sun, 26 May 2024 10:31:44 +0800 Subject: [PATCH 10/26] add nopython test in `test_apply_args` --- pandas/tests/apply/test_frame_apply.py | 36 +++++++++++++++++++++----- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/pandas/tests/apply/test_frame_apply.py b/pandas/tests/apply/test_frame_apply.py index 003c3594a9441..939997f44c1a9 100644 --- a/pandas/tests/apply/test_frame_apply.py +++ b/pandas/tests/apply/test_frame_apply.py @@ -63,16 +63,28 @@ def test_apply(float_frame, engine, request): @pytest.mark.parametrize("axis", [0, 1]) @pytest.mark.parametrize("raw", [True, False]) -def test_apply_args(float_frame, axis, raw, engine): +@pytest.mark.parametrize("nopython", [True, False]) +def test_apply_args(float_frame, axis, raw, engine, nopython): + engine_kwargs = {"nopython": nopython} result = float_frame.apply( - lambda x, y: x + y, axis, args=(1,), raw=raw, engine=engine + lambda x, y: x + y, + axis, + args=(1,), + raw=raw, + engine=engine, + engine_kwargs=engine_kwargs, ) expected = float_frame + 1 tm.assert_frame_equal(result, expected) # GH:58712 result = float_frame.apply( - lambda x, a, b: x + a + b, args=(1,), b=2, engine=engine, raw=raw + lambda x, a, b: x + a + b, + args=(1,), + b=2, + raw=raw, + engine=engine, + engine_kwargs=engine_kwargs, ) expected = float_frame + 3 tm.assert_frame_equal(result, expected) @@ -81,18 +93,28 @@ def test_apply_args(float_frame, axis, raw, engine): # keyword-only arguments are not supported in numba with pytest.raises( pd.errors.NumbaUtilError, - match="numba does not support kwargs with nopython=True", + match="numba does not support keyword-only arguments", ): float_frame.apply( - lambda x, a, *, b: x + a + b, args=(1,), b=2, engine=engine, raw=raw + lambda x, a, *, b: x + a + b, + args=(1,), + b=2, + raw=raw, + engine=engine, + engine_kwargs=engine_kwargs, ) with pytest.raises( pd.errors.NumbaUtilError, - match="numba does not support kwargs with nopython=True", + match="numba does not support keyword-only arguments", ): float_frame.apply( - lambda *x, b: x[0] + x[1] + b, args=(1,), b=2, engine=engine, raw=raw + lambda *x, b: x[0] + x[1] + b, + args=(1,), + b=2, + raw=raw, + engine=engine, + engine_kwargs=engine_kwargs, ) From c75e0b72143b96cf18096ef4c828025bc5b4544f Mon Sep 17 00:00:00 2001 From: auderson Date: Sun, 26 May 2024 11:27:13 +0800 Subject: [PATCH 11/26] fix test --- pandas/tests/window/test_numba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/tests/window/test_numba.py b/pandas/tests/window/test_numba.py index 650eb911e410b..62e17db595985 100644 --- a/pandas/tests/window/test_numba.py +++ b/pandas/tests/window/test_numba.py @@ -304,7 +304,7 @@ def f(x): @td.skip_if_no("numba") def test_invalid_kwargs_nopython(): - with pytest.raises(NumbaUtilError, match="numba does not support kwargs with"): + with pytest.raises(NumbaUtilError, match="numba does not support keyword-only arguments"): Series(range(1)).rolling(1).apply( lambda x: x, kwargs={"a": 1}, engine="numba", raw=True ) From ceb817836778842d4317ab8969dab4609825d377 Mon Sep 17 00:00:00 2001 From: auderson Date: Sun, 26 May 2024 11:31:38 +0800 Subject: [PATCH 12/26] fix pre-commit --- pandas/tests/window/test_numba.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pandas/tests/window/test_numba.py b/pandas/tests/window/test_numba.py index 62e17db595985..c743df859337b 100644 --- a/pandas/tests/window/test_numba.py +++ b/pandas/tests/window/test_numba.py @@ -304,7 +304,9 @@ def f(x): @td.skip_if_no("numba") def test_invalid_kwargs_nopython(): - with pytest.raises(NumbaUtilError, match="numba does not support keyword-only arguments"): + with pytest.raises( + NumbaUtilError, match="numba does not support keyword-only arguments" + ): Series(range(1)).rolling(1).apply( lambda x: x, kwargs={"a": 1}, engine="numba", raw=True ) From aa91722f843dd3a33e083dbc2151efc59085e0be Mon Sep 17 00:00:00 2001 From: auderson Date: Thu, 13 Jun 2024 10:42:21 +0800 Subject: [PATCH 13/26] modify prepare_function_arguments --- pandas/core/apply.py | 11 +++++----- pandas/core/groupby/groupby.py | 7 +++++-- pandas/core/util/numba_.py | 37 ++++++++++++++++------------------ pandas/core/window/rolling.py | 7 ++++--- 4 files changed, 32 insertions(+), 30 deletions(-) diff --git a/pandas/core/apply.py b/pandas/core/apply.py index 75ad17b59bf88..2b8a4a3f10402 100644 --- a/pandas/core/apply.py +++ b/pandas/core/apply.py @@ -1004,6 +1004,7 @@ def wrapper(*args, **kwargs): self.func, # type: ignore[arg-type] self.args, self.kwargs, + 1, ) # error: Argument 1 to "__call__" of "_lru_cache_wrapper" has # incompatible type "Callable[..., Any] | str | list[Callable @@ -1011,7 +1012,7 @@ def wrapper(*args, **kwargs): # list[Callable[..., Any] | str]]"; expected "Hashable" nb_looper = generate_apply_looper( self.func, # type: ignore[arg-type] - **get_jit_arguments(engine_kwargs, kwargs), + **get_jit_arguments(engine_kwargs), ) result = nb_looper(self.values, self.axis, *args) # If we made the result 2-D, squeeze it back to 1-D @@ -1168,9 +1169,9 @@ def numba_func(values, col_names, df_index, *args): def apply_with_numba(self) -> dict[int, Any]: func = cast(Callable, self.func) - args, kwargs = prepare_function_arguments(func, self.args, self.kwargs) + args, kwargs = prepare_function_arguments(func, self.args, self.kwargs, 1) nb_func = self.generate_numba_apply_func( - func, **get_jit_arguments(self.engine_kwargs, kwargs) + func, **get_jit_arguments(self.engine_kwargs) ) from pandas.core._numba.extensions import set_numba_data @@ -1313,9 +1314,9 @@ def numba_func(values, col_names_index, index, *args): def apply_with_numba(self) -> dict[int, Any]: func = cast(Callable, self.func) - args, kwargs = prepare_function_arguments(func, self.args, self.kwargs) + args, kwargs = prepare_function_arguments(func, self.args, self.kwargs, 1) nb_func = self.generate_numba_apply_func( - func, **get_jit_arguments(self.engine_kwargs, kwargs) + func, **get_jit_arguments(self.engine_kwargs) ) from pandas.core._numba.extensions import set_numba_data diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 1b58317c08736..700e5e4ce5529 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -137,6 +137,7 @@ class providing the base-class of operations. from pandas.core.util.numba_ import ( get_jit_arguments, maybe_use_numba, + prepare_function_arguments, ) if TYPE_CHECKING: @@ -1443,8 +1444,9 @@ def _transform_with_numba(self, func, *args, engine_kwargs=None, **kwargs): starts, ends, sorted_index, sorted_data = self._numba_prep(df) numba_.validate_udf(func) + args, kwargs = prepare_function_arguments(func, args, kwargs, 2) numba_transform_func = numba_.generate_numba_transform_func( - func, **get_jit_arguments(engine_kwargs, kwargs) + func, **get_jit_arguments(engine_kwargs) ) result = numba_transform_func( sorted_data, @@ -1479,8 +1481,9 @@ def _aggregate_with_numba(self, func, *args, engine_kwargs=None, **kwargs): starts, ends, sorted_index, sorted_data = self._numba_prep(df) numba_.validate_udf(func) + args, kwargs = prepare_function_arguments(func, args, kwargs, 2) numba_agg_func = numba_.generate_numba_agg_func( - func, **get_jit_arguments(engine_kwargs, kwargs) + func, **get_jit_arguments(engine_kwargs) ) result = numba_agg_func( sorted_data, diff --git a/pandas/core/util/numba_.py b/pandas/core/util/numba_.py index d93984d210cb4..e3b9778ce1aa1 100644 --- a/pandas/core/util/numba_.py +++ b/pandas/core/util/numba_.py @@ -29,9 +29,7 @@ def set_use_numba(enable: bool = False) -> None: GLOBAL_USE_NUMBA = enable -def get_jit_arguments( - engine_kwargs: dict[str, bool] | None = None, kwargs: dict | None = None -) -> dict[str, bool]: +def get_jit_arguments(engine_kwargs: dict[str, bool] | None = None) -> dict[str, bool]: """ Return arguments to pass to numba.JIT, falling back on pandas default JIT settings. @@ -39,8 +37,6 @@ def get_jit_arguments( ---------- engine_kwargs : dict, default None user passed keyword arguments for numba.JIT - kwargs : dict, default None - user passed keyword arguments to pass into the JITed function Returns ------- @@ -55,16 +51,6 @@ def get_jit_arguments( engine_kwargs = {} nopython = engine_kwargs.get("nopython", True) - if kwargs: - # Note: in case numba supports keyword-only arguments in - # a future version, we should remove this check. But this - # seems unlikely to happen soon. - - raise NumbaUtilError( - "numba does not support keyword-only arguments" - "https://github.com/numba/numba/issues/2916, " - "https://github.com/numba/numba/issues/6846" - ) nogil = engine_kwargs.get("nogil", False) parallel = engine_kwargs.get("parallel", False) return {"nopython": nopython, "nogil": nogil, "parallel": parallel} @@ -109,7 +95,7 @@ def jit_user_function(func: Callable) -> Callable: def prepare_function_arguments( - func: Callable, args: tuple, kwargs: dict + func: Callable, args: tuple, kwargs: dict, num_required_args: int ) -> tuple[tuple, dict]: """ Prepare arguments for jitted function. As numba functions do not support kwargs, @@ -123,6 +109,8 @@ def prepare_function_arguments( user input positional arguments kwargs : dict user input keyword arguments + num_required_args : int + the number of required leading positional arguments for udf. Returns ------- @@ -133,9 +121,9 @@ def prepare_function_arguments( if not kwargs: return args, kwargs - # the udf should have this pattern: def udf(value, *args, **kwargs):... + # the udf should have this pattern: def udf(arg1, arg2, ..., *args, **kwargs):... signature = inspect.signature(func) - arguments = signature.bind(_sentinel, *args, **kwargs) + arguments = signature.bind(*[_sentinel] * num_required_args, *args, **kwargs) arguments.apply_defaults() # Ref: https://peps.python.org/pep-0362/ # Arguments which could be passed as part of either *args or **kwargs @@ -143,7 +131,16 @@ def prepare_function_arguments( args = arguments.args kwargs = arguments.kwargs - assert args[0] is _sentinel - args = args[1:] + if kwargs: + # Note: in case numba supports keyword-only arguments in + # a future version, we should remove this check. But this + # seems unlikely to happen soon. + + raise NumbaUtilError( + "numba does not support keyword-only arguments" + "https://github.com/numba/numba/issues/2916, " + "https://github.com/numba/numba/issues/6846" + ) + args = args[num_required_args:] return args, kwargs diff --git a/pandas/core/window/rolling.py b/pandas/core/window/rolling.py index 2243d8dd1a613..06ba6d4f6091e 100644 --- a/pandas/core/window/rolling.py +++ b/pandas/core/window/rolling.py @@ -66,6 +66,7 @@ from pandas.core.util.numba_ import ( get_jit_arguments, maybe_use_numba, + prepare_function_arguments, ) from pandas.core.window.common import ( flex_binary_moment, @@ -1458,14 +1459,14 @@ def apply( if maybe_use_numba(engine): if raw is False: raise ValueError("raw must be `True` when using the numba engine") - numba_args = args + numba_args, kwargs = prepare_function_arguments(func, args, kwargs, 1) if self.method == "single": apply_func = generate_numba_apply_func( - func, **get_jit_arguments(engine_kwargs, kwargs) + func, **get_jit_arguments(engine_kwargs) ) else: apply_func = generate_numba_table_func( - func, **get_jit_arguments(engine_kwargs, kwargs) + func, **get_jit_arguments(engine_kwargs) ) elif engine in ("cython", None): if engine_kwargs is not None: From 0de3224bd453f9760c71766e203fd331e1673f6d Mon Sep 17 00:00:00 2001 From: auderson Date: Thu, 13 Jun 2024 11:34:04 +0800 Subject: [PATCH 14/26] add tests --- pandas/tests/apply/test_frame_apply.py | 10 ++++++++++ pandas/tests/groupby/aggregate/test_numba.py | 21 ++++++++++++++++++-- pandas/tests/groupby/transform/test_numba.py | 21 ++++++++++++++++++-- pandas/tests/window/test_numba.py | 13 +++++++++++- 4 files changed, 60 insertions(+), 5 deletions(-) diff --git a/pandas/tests/apply/test_frame_apply.py b/pandas/tests/apply/test_frame_apply.py index 939997f44c1a9..8649a6bc90d2d 100644 --- a/pandas/tests/apply/test_frame_apply.py +++ b/pandas/tests/apply/test_frame_apply.py @@ -90,6 +90,16 @@ def test_apply_args(float_frame, axis, raw, engine, nopython): tm.assert_frame_equal(result, expected) if engine == "numba": + # py signature binding + with pytest.raises(TypeError, match="missing a required argument: 'a'"): + float_frame.apply( + lambda x, a: x + a, + b=2, + raw=raw, + engine=engine, + engine_kwargs=engine_kwargs, + ) + # keyword-only arguments are not supported in numba with pytest.raises( pd.errors.NumbaUtilError, diff --git a/pandas/tests/groupby/aggregate/test_numba.py b/pandas/tests/groupby/aggregate/test_numba.py index 964a80f8f3310..dc01bc5d5ada1 100644 --- a/pandas/tests/groupby/aggregate/test_numba.py +++ b/pandas/tests/groupby/aggregate/test_numba.py @@ -35,18 +35,35 @@ def incorrect_function(x): def test_check_nopython_kwargs(): pytest.importorskip("numba") - def incorrect_function(values, index): - return sum(values) * 2.7 + def incorrect_function(values, index, *, a): + return sum(values) * 2.7 + a + + def correct_function(values, index, a): + return sum(values) * 2.7 + a data = DataFrame( {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=["key", "data"], ) + # py signature binding + with pytest.raises(TypeError, match="missing a required argument: 'a'"): + data.groupby("key").agg(incorrect_function, engine="numba", b=1) + with pytest.raises(TypeError, match="missing a required argument: 'a'"): + data.groupby("key").agg(correct_function, engine="numba", b=1) + + with pytest.raises(TypeError, match="missing a required argument: 'a'"): + data.groupby("key")["data"].agg(incorrect_function, engine="numba", b=1) + with pytest.raises(TypeError, match="missing a required argument: 'a'"): + data.groupby("key")["data"].agg(correct_function, engine="numba", b=1) + + # numba signature check after binding with pytest.raises(NumbaUtilError, match="numba does not support"): data.groupby("key").agg(incorrect_function, engine="numba", a=1) + data.groupby("key").agg(correct_function, engine="numba", a=1) with pytest.raises(NumbaUtilError, match="numba does not support"): data.groupby("key")["data"].agg(incorrect_function, engine="numba", a=1) + data.groupby("key")["data"].agg(correct_function, engine="numba", a=1) @pytest.mark.filterwarnings("ignore") diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py index a17d25b2e7e2e..239fd7e428dab 100644 --- a/pandas/tests/groupby/transform/test_numba.py +++ b/pandas/tests/groupby/transform/test_numba.py @@ -33,18 +33,35 @@ def incorrect_function(x): def test_check_nopython_kwargs(): pytest.importorskip("numba") - def incorrect_function(values, index): - return values + 1 + def incorrect_function(values, index, *, a): + return values + a + + def correct_function(values, index, a): + return values + a data = DataFrame( {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=["key", "data"], ) + # py signature binding + with pytest.raises(TypeError, match="missing a required argument: 'a'"): + data.groupby("key").transform(incorrect_function, engine="numba", b=1) + with pytest.raises(TypeError, match="missing a required argument: 'a'"): + data.groupby("key").transform(correct_function, engine="numba", b=1) + + with pytest.raises(TypeError, match="missing a required argument: 'a'"): + data.groupby("key")["data"].transform(incorrect_function, engine="numba", b=1) + with pytest.raises(TypeError, match="missing a required argument: 'a'"): + data.groupby("key")["data"].transform(correct_function, engine="numba", b=1) + + # numba signature check after binding with pytest.raises(NumbaUtilError, match="numba does not support"): data.groupby("key").transform(incorrect_function, engine="numba", a=1) + data.groupby("key").transform(correct_function, engine="numba", a=1) with pytest.raises(NumbaUtilError, match="numba does not support"): data.groupby("key")["data"].transform(incorrect_function, engine="numba", a=1) + data.groupby("key")["data"].transform(correct_function, engine="numba", a=1) @pytest.mark.filterwarnings("ignore") diff --git a/pandas/tests/window/test_numba.py b/pandas/tests/window/test_numba.py index 23b17c651f08d..88eb40b46f730 100644 --- a/pandas/tests/window/test_numba.py +++ b/pandas/tests/window/test_numba.py @@ -319,13 +319,24 @@ def f(x): @td.skip_if_no("numba") def test_invalid_kwargs_nopython(): + with pytest.raises(TypeError, match="got an unexpected keyword argument 'a'"): + Series(range(1)).rolling(1).apply( + lambda x: x, kwargs={"a": 1}, engine="numba", raw=True + ) with pytest.raises( NumbaUtilError, match="numba does not support keyword-only arguments" ): Series(range(1)).rolling(1).apply( - lambda x: x, kwargs={"a": 1}, engine="numba", raw=True + lambda x, *, a: x, kwargs={"a": 1}, engine="numba", raw=True ) + tm.assert_series_equal( + Series(range(1), dtype=float) + 1, + Series(range(1)) + .rolling(1) + .apply(lambda x, a: (x + a).sum(), kwargs={"a": 1}, engine="numba", raw=True), + ) + @td.skip_if_no("numba") @pytest.mark.slow From 82252be711848eba16425e832b9db4521cdab76a Mon Sep 17 00:00:00 2001 From: auderson Date: Thu, 13 Jun 2024 13:54:33 +0800 Subject: [PATCH 15/26] add tests --- pandas/tests/groupby/aggregate/test_numba.py | 8 ++++++-- pandas/tests/groupby/transform/test_numba.py | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/pandas/tests/groupby/aggregate/test_numba.py b/pandas/tests/groupby/aggregate/test_numba.py index dc01bc5d5ada1..edbd7cf58fcbf 100644 --- a/pandas/tests/groupby/aggregate/test_numba.py +++ b/pandas/tests/groupby/aggregate/test_numba.py @@ -45,6 +45,8 @@ def correct_function(values, index, a): {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=["key", "data"], ) + expected = data.groupby("key").sum() * 2.7 + # py signature binding with pytest.raises(TypeError, match="missing a required argument: 'a'"): data.groupby("key").agg(incorrect_function, engine="numba", b=1) @@ -59,11 +61,13 @@ def correct_function(values, index, a): # numba signature check after binding with pytest.raises(NumbaUtilError, match="numba does not support"): data.groupby("key").agg(incorrect_function, engine="numba", a=1) - data.groupby("key").agg(correct_function, engine="numba", a=1) + actual = data.groupby("key").agg(correct_function, engine="numba", a=1) + tm.assert_frame_equal(expected + 1, actual) with pytest.raises(NumbaUtilError, match="numba does not support"): data.groupby("key")["data"].agg(incorrect_function, engine="numba", a=1) - data.groupby("key")["data"].agg(correct_function, engine="numba", a=1) + actual = data.groupby("key")["data"].agg(correct_function, engine="numba", a=1) + tm.assert_series_equal(expected["data"] + 1, actual) @pytest.mark.filterwarnings("ignore") diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py index 239fd7e428dab..bba0712447351 100644 --- a/pandas/tests/groupby/transform/test_numba.py +++ b/pandas/tests/groupby/transform/test_numba.py @@ -57,11 +57,15 @@ def correct_function(values, index, a): # numba signature check after binding with pytest.raises(NumbaUtilError, match="numba does not support"): data.groupby("key").transform(incorrect_function, engine="numba", a=1) - data.groupby("key").transform(correct_function, engine="numba", a=1) + actual = data.groupby("key").transform(correct_function, engine="numba", a=1) + tm.assert_frame_equal(data[["data"]] + 1, actual) with pytest.raises(NumbaUtilError, match="numba does not support"): data.groupby("key")["data"].transform(incorrect_function, engine="numba", a=1) - data.groupby("key")["data"].transform(correct_function, engine="numba", a=1) + actual = data.groupby("key")["data"].transform( + correct_function, engine="numba", a=1 + ) + tm.assert_series_equal(data["data"] + 1, actual) @pytest.mark.filterwarnings("ignore") From da6dbc7b46a5c10a6b230472464f9a3ac4bdc0f2 Mon Sep 17 00:00:00 2001 From: auderson Date: Thu, 13 Jun 2024 14:00:34 +0800 Subject: [PATCH 16/26] add whatsnew --- doc/source/whatsnew/v3.0.0.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index 4a02622ae9eda..8b4a063afee06 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -47,6 +47,8 @@ Other enhancements - :meth:`Series.plot` now correctly handle the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`) - Restore support for reading Stata 104-format and enable reading 103-format dta files (:issue:`58554`) - Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`) +- numba apply now supports positional arguments passed as kwargs (:issue:`58995`) + .. --------------------------------------------------------------------------- .. _whatsnew_300.notable_bug_fixes: From e72bfb20589f9e5d0e89dd62e23ba3fe12f8d848 Mon Sep 17 00:00:00 2001 From: auderson Date: Thu, 13 Jun 2024 15:21:37 +0800 Subject: [PATCH 17/26] compat for python 3.12 --- pandas/tests/groupby/aggregate/test_numba.py | 4 ++-- pandas/tests/groupby/transform/test_numba.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pandas/tests/groupby/aggregate/test_numba.py b/pandas/tests/groupby/aggregate/test_numba.py index edbd7cf58fcbf..3a843915661e5 100644 --- a/pandas/tests/groupby/aggregate/test_numba.py +++ b/pandas/tests/groupby/aggregate/test_numba.py @@ -48,12 +48,12 @@ def correct_function(values, index, a): expected = data.groupby("key").sum() * 2.7 # py signature binding - with pytest.raises(TypeError, match="missing a required argument: 'a'"): + with pytest.raises(TypeError, match="missing a required (keyword-only argument|argument): 'a'"): data.groupby("key").agg(incorrect_function, engine="numba", b=1) with pytest.raises(TypeError, match="missing a required argument: 'a'"): data.groupby("key").agg(correct_function, engine="numba", b=1) - with pytest.raises(TypeError, match="missing a required argument: 'a'"): + with pytest.raises(TypeError, match="missing a required (keyword-only argument|argument): 'a'"): data.groupby("key")["data"].agg(incorrect_function, engine="numba", b=1) with pytest.raises(TypeError, match="missing a required argument: 'a'"): data.groupby("key")["data"].agg(correct_function, engine="numba", b=1) diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py index bba0712447351..c90914c5caf09 100644 --- a/pandas/tests/groupby/transform/test_numba.py +++ b/pandas/tests/groupby/transform/test_numba.py @@ -44,12 +44,12 @@ def correct_function(values, index, a): columns=["key", "data"], ) # py signature binding - with pytest.raises(TypeError, match="missing a required argument: 'a'"): + with pytest.raises(TypeError, match="missing a required (keyword-only argument|argument): 'a'"): data.groupby("key").transform(incorrect_function, engine="numba", b=1) with pytest.raises(TypeError, match="missing a required argument: 'a'"): data.groupby("key").transform(correct_function, engine="numba", b=1) - with pytest.raises(TypeError, match="missing a required argument: 'a'"): + with pytest.raises(TypeError, match="missing a required (keyword-only argument|argument): 'a'"): data.groupby("key")["data"].transform(incorrect_function, engine="numba", b=1) with pytest.raises(TypeError, match="missing a required argument: 'a'"): data.groupby("key")["data"].transform(correct_function, engine="numba", b=1) From 8168d9b8974d59a4f0ee5c533d84481bc5542930 Mon Sep 17 00:00:00 2001 From: auderson Date: Thu, 13 Jun 2024 15:26:44 +0800 Subject: [PATCH 18/26] pre-commit --- pandas/tests/groupby/aggregate/test_numba.py | 8 ++++++-- pandas/tests/groupby/transform/test_numba.py | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/pandas/tests/groupby/aggregate/test_numba.py b/pandas/tests/groupby/aggregate/test_numba.py index 3a843915661e5..15c1efe5fd1ff 100644 --- a/pandas/tests/groupby/aggregate/test_numba.py +++ b/pandas/tests/groupby/aggregate/test_numba.py @@ -48,12 +48,16 @@ def correct_function(values, index, a): expected = data.groupby("key").sum() * 2.7 # py signature binding - with pytest.raises(TypeError, match="missing a required (keyword-only argument|argument): 'a'"): + with pytest.raises( + TypeError, match="missing a required (keyword-only argument|argument): 'a'" + ): data.groupby("key").agg(incorrect_function, engine="numba", b=1) with pytest.raises(TypeError, match="missing a required argument: 'a'"): data.groupby("key").agg(correct_function, engine="numba", b=1) - with pytest.raises(TypeError, match="missing a required (keyword-only argument|argument): 'a'"): + with pytest.raises( + TypeError, match="missing a required (keyword-only argument|argument): 'a'" + ): data.groupby("key")["data"].agg(incorrect_function, engine="numba", b=1) with pytest.raises(TypeError, match="missing a required argument: 'a'"): data.groupby("key")["data"].agg(correct_function, engine="numba", b=1) diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py index c90914c5caf09..969df8ef4c52b 100644 --- a/pandas/tests/groupby/transform/test_numba.py +++ b/pandas/tests/groupby/transform/test_numba.py @@ -44,12 +44,16 @@ def correct_function(values, index, a): columns=["key", "data"], ) # py signature binding - with pytest.raises(TypeError, match="missing a required (keyword-only argument|argument): 'a'"): + with pytest.raises( + TypeError, match="missing a required (keyword-only argument|argument): 'a'" + ): data.groupby("key").transform(incorrect_function, engine="numba", b=1) with pytest.raises(TypeError, match="missing a required argument: 'a'"): data.groupby("key").transform(correct_function, engine="numba", b=1) - with pytest.raises(TypeError, match="missing a required (keyword-only argument|argument): 'a'"): + with pytest.raises( + TypeError, match="missing a required (keyword-only argument|argument): 'a'" + ): data.groupby("key")["data"].transform(incorrect_function, engine="numba", b=1) with pytest.raises(TypeError, match="missing a required argument: 'a'"): data.groupby("key")["data"].transform(correct_function, engine="numba", b=1) From c211119725e112da17852a448cee410e17069b26 Mon Sep 17 00:00:00 2001 From: auderson Date: Thu, 13 Jun 2024 15:21:37 +0800 Subject: [PATCH 19/26] compat for python 3.12 --- pandas/tests/groupby/aggregate/test_numba.py | 8 ++++++-- pandas/tests/groupby/transform/test_numba.py | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/pandas/tests/groupby/aggregate/test_numba.py b/pandas/tests/groupby/aggregate/test_numba.py index edbd7cf58fcbf..15c1efe5fd1ff 100644 --- a/pandas/tests/groupby/aggregate/test_numba.py +++ b/pandas/tests/groupby/aggregate/test_numba.py @@ -48,12 +48,16 @@ def correct_function(values, index, a): expected = data.groupby("key").sum() * 2.7 # py signature binding - with pytest.raises(TypeError, match="missing a required argument: 'a'"): + with pytest.raises( + TypeError, match="missing a required (keyword-only argument|argument): 'a'" + ): data.groupby("key").agg(incorrect_function, engine="numba", b=1) with pytest.raises(TypeError, match="missing a required argument: 'a'"): data.groupby("key").agg(correct_function, engine="numba", b=1) - with pytest.raises(TypeError, match="missing a required argument: 'a'"): + with pytest.raises( + TypeError, match="missing a required (keyword-only argument|argument): 'a'" + ): data.groupby("key")["data"].agg(incorrect_function, engine="numba", b=1) with pytest.raises(TypeError, match="missing a required argument: 'a'"): data.groupby("key")["data"].agg(correct_function, engine="numba", b=1) diff --git a/pandas/tests/groupby/transform/test_numba.py b/pandas/tests/groupby/transform/test_numba.py index bba0712447351..969df8ef4c52b 100644 --- a/pandas/tests/groupby/transform/test_numba.py +++ b/pandas/tests/groupby/transform/test_numba.py @@ -44,12 +44,16 @@ def correct_function(values, index, a): columns=["key", "data"], ) # py signature binding - with pytest.raises(TypeError, match="missing a required argument: 'a'"): + with pytest.raises( + TypeError, match="missing a required (keyword-only argument|argument): 'a'" + ): data.groupby("key").transform(incorrect_function, engine="numba", b=1) with pytest.raises(TypeError, match="missing a required argument: 'a'"): data.groupby("key").transform(correct_function, engine="numba", b=1) - with pytest.raises(TypeError, match="missing a required argument: 'a'"): + with pytest.raises( + TypeError, match="missing a required (keyword-only argument|argument): 'a'" + ): data.groupby("key")["data"].transform(incorrect_function, engine="numba", b=1) with pytest.raises(TypeError, match="missing a required argument: 'a'"): data.groupby("key")["data"].transform(correct_function, engine="numba", b=1) From f7936f2c2292f2464dbdf84c6523448628b218dd Mon Sep 17 00:00:00 2001 From: auderson Date: Wed, 23 Oct 2024 10:44:28 +0800 Subject: [PATCH 20/26] update doc; use kw-only --- pandas/core/apply.py | 10 +++++++--- pandas/core/groupby/groupby.py | 8 ++++++-- pandas/core/util/numba_.py | 14 +++++++++----- pandas/core/window/rolling.py | 4 +++- 4 files changed, 25 insertions(+), 11 deletions(-) diff --git a/pandas/core/apply.py b/pandas/core/apply.py index d8ad295e627f0..af513d49bcfe0 100644 --- a/pandas/core/apply.py +++ b/pandas/core/apply.py @@ -994,7 +994,7 @@ def wrapper(*args, **kwargs): self.func, # type: ignore[arg-type] self.args, self.kwargs, - 1, + num_required_args=1, ) # error: Argument 1 to "__call__" of "_lru_cache_wrapper" has # incompatible type "Callable[..., Any] | str | list[Callable @@ -1159,7 +1159,9 @@ def numba_func(values, col_names, df_index, *args): def apply_with_numba(self) -> dict[int, Any]: func = cast(Callable, self.func) - args, kwargs = prepare_function_arguments(func, self.args, self.kwargs, 1) + args, kwargs = prepare_function_arguments( + func, self.args, self.kwargs, num_required_args=1 + ) nb_func = self.generate_numba_apply_func( func, **get_jit_arguments(self.engine_kwargs) ) @@ -1299,7 +1301,9 @@ def numba_func(values, col_names_index, index, *args): def apply_with_numba(self) -> dict[int, Any]: func = cast(Callable, self.func) - args, kwargs = prepare_function_arguments(func, self.args, self.kwargs, 1) + args, kwargs = prepare_function_arguments( + func, self.args, self.kwargs, num_required_args=1 + ) nb_func = self.generate_numba_apply_func( func, **get_jit_arguments(self.engine_kwargs) ) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 7d9f690a68e2b..221a7c876aa6f 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1276,7 +1276,9 @@ def _transform_with_numba(self, func, *args, engine_kwargs=None, **kwargs): starts, ends, sorted_index, sorted_data = self._numba_prep(df) numba_.validate_udf(func) - args, kwargs = prepare_function_arguments(func, args, kwargs, 2) + args, kwargs = prepare_function_arguments( + func, args, kwargs, num_required_args=2 + ) numba_transform_func = numba_.generate_numba_transform_func( func, **get_jit_arguments(engine_kwargs) ) @@ -1313,7 +1315,9 @@ def _aggregate_with_numba(self, func, *args, engine_kwargs=None, **kwargs): starts, ends, sorted_index, sorted_data = self._numba_prep(df) numba_.validate_udf(func) - args, kwargs = prepare_function_arguments(func, args, kwargs, 2) + args, kwargs = prepare_function_arguments( + func, args, kwargs, num_required_args=2 + ) numba_agg_func = numba_.generate_numba_agg_func( func, **get_jit_arguments(engine_kwargs) ) diff --git a/pandas/core/util/numba_.py b/pandas/core/util/numba_.py index 033939fce0744..d3f00c08e0e2c 100644 --- a/pandas/core/util/numba_.py +++ b/pandas/core/util/numba_.py @@ -95,7 +95,7 @@ def jit_user_function(func: Callable) -> Callable: def prepare_function_arguments( - func: Callable, args: tuple, kwargs: dict, num_required_args: int + func: Callable, args: tuple, kwargs: dict, *, num_required_args: int ) -> tuple[tuple, dict]: """ Prepare arguments for jitted function. As numba functions do not support kwargs, @@ -104,13 +104,17 @@ def prepare_function_arguments( Parameters ---------- func : function - user defined function + User defined function args : tuple - user input positional arguments + User input positional arguments kwargs : dict - user input keyword arguments + User input keyword arguments num_required_args : int - the number of required leading positional arguments for udf. + The number of leading positional arguments we will pass to udf. + These are not supplied by the user. + e.g. for groupby we require "values", "index" as the first two arguments: + `numba_func(group, group_index, *args)`, in this case num_required_args=2. + See :func:`pandas.core.groupby.numba_.generate_numba_agg_func` Returns ------- diff --git a/pandas/core/window/rolling.py b/pandas/core/window/rolling.py index aa7b6b953ef15..b1c37ab48fa57 100644 --- a/pandas/core/window/rolling.py +++ b/pandas/core/window/rolling.py @@ -1473,7 +1473,9 @@ def apply( if maybe_use_numba(engine): if raw is False: raise ValueError("raw must be `True` when using the numba engine") - numba_args, kwargs = prepare_function_arguments(func, args, kwargs, 1) + numba_args, kwargs = prepare_function_arguments( + func, args, kwargs, num_required_args=1 + ) if self.method == "single": apply_func = generate_numba_apply_func( func, **get_jit_arguments(engine_kwargs) From 2400d286481b8fa6343081cb8f6be7606b79f09d Mon Sep 17 00:00:00 2001 From: auderson Date: Wed, 30 Oct 2024 10:40:12 +0800 Subject: [PATCH 21/26] add more tests --- pandas/tests/window/test_apply.py | 14 ++++++++++ pandas/tests/window/test_expanding.py | 15 ++++++++++ pandas/tests/window/test_groupby.py | 40 +++++++++++++++++++++++++++ 3 files changed, 69 insertions(+) diff --git a/pandas/tests/window/test_apply.py b/pandas/tests/window/test_apply.py index 2398713585cfb..11ad9ea12b207 100644 --- a/pandas/tests/window/test_apply.py +++ b/pandas/tests/window/test_apply.py @@ -316,3 +316,17 @@ def test_center_reindex_frame(raw): ) frame_rs = frame.rolling(window=25, min_periods=minp, center=True).apply(f, raw=raw) tm.assert_frame_equal(frame_xp, frame_rs) + +def test_apply_numba_with_kwargs(): + # 58995 + def func(sr, a=0): + return sr.sum() + a + + data = DataFrame(range(10)) + + result = data.rolling(5).apply(func, engine="numba", raw=True, kwargs={"a": 1}) + expected = data.rolling(5).sum() + 1 + tm.assert_frame_equal(result, expected) + + result = data.rolling(5).apply(func, engine="numba", raw=True, args=(1,)) + tm.assert_frame_equal(result, expected) \ No newline at end of file diff --git a/pandas/tests/window/test_expanding.py b/pandas/tests/window/test_expanding.py index b2f76bdd0e2ad..0950dbc337f1b 100644 --- a/pandas/tests/window/test_expanding.py +++ b/pandas/tests/window/test_expanding.py @@ -691,3 +691,18 @@ def test_numeric_only_corr_cov_series(kernel, use_arg, numeric_only, dtype): op2 = getattr(expanding2, kernel) expected = op2(*arg2, numeric_only=numeric_only) tm.assert_series_equal(result, expected) + + +def test_apply_numba_with_kwargs(): + # 58995 + def func(sr, a=0): + return sr.sum() + a + + data = DataFrame(range(10)) + + result = data.expanding().apply(func, engine="numba", raw=True, kwargs={"a": 1}) + expected = data.expanding().sum() + 1 + tm.assert_frame_equal(result, expected) + + result = data.expanding().apply(func, engine="numba", raw=True, args=(1,)) + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/window/test_groupby.py b/pandas/tests/window/test_groupby.py index 4d37c6d57f788..804002641206f 100644 --- a/pandas/tests/window/test_groupby.py +++ b/pandas/tests/window/test_groupby.py @@ -1024,6 +1024,26 @@ def test_datelike_on_not_monotonic_within_each_group(self): with pytest.raises(ValueError, match="Each group within B must be monotonic."): df.groupby("A").rolling("365D", on="B") + def test_groupby_rolling_apply_numba_with_kwargs(self, roll_frame): + def func(sr, a=0): + return sr.sum() + a + + # 58995 + result = ( + roll_frame.groupby("A") + .rolling(5) + .apply(func, engine="numba", raw=True, kwargs={"a": 1}) + ) + expected = roll_frame.groupby("A").rolling(5).sum() + 1 + tm.assert_frame_equal(result, expected) + + result = ( + roll_frame.groupby("A") + .rolling(5) + .apply(func, engine="numba", raw=True, args=(1,)) + ) + tm.assert_frame_equal(result, expected) + class TestExpanding: @pytest.fixture @@ -1134,6 +1154,26 @@ def test_expanding_apply(self, raw, frame): expected.index = expected_index tm.assert_frame_equal(result, expected) + def test_groupby_expanding_apply_numba_with_kwargs(self, roll_frame): + # 58995 + def func(sr, a=0): + return sr.sum() + a + + result = ( + roll_frame.groupby("A") + .expanding() + .apply(func, engine="numba", raw=True, kwargs={"a": 1}) + ) + expected = roll_frame.groupby("A").expanding().sum() + 1 + tm.assert_frame_equal(result, expected) + + result = ( + roll_frame.groupby("A") + .expanding() + .apply(func, engine="numba", raw=True, args=(1,)) + ) + tm.assert_frame_equal(result, expected) + class TestEWM: @pytest.mark.parametrize( From 8d1021184bd36bdee057571c99ae394a90fd26d6 Mon Sep 17 00:00:00 2001 From: auderson Date: Wed, 30 Oct 2024 10:40:28 +0800 Subject: [PATCH 22/26] update whatsnew --- doc/source/whatsnew/v3.0.0.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index 2de18ecfc5c1b..89f84a93d6dd7 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -54,6 +54,7 @@ Other enhancements - :meth:`Series.cummin` and :meth:`Series.cummax` now supports :class:`CategoricalDtype` (:issue:`52335`) - :meth:`Series.plot` now correctly handle the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`) - :meth:`DataFrame.plot.scatter` argument ``c`` now accepts a column of strings, where rows with the same string are colored identically (:issue:`16827` and :issue:`16485`) +- :meth:`DataFrameGroupBy.transform`, :meth:`SeriesGroupBy.transform`, :meth:`DataFrameGroupBy.agg`, :meth:`SeriesGroupBy.agg`, :meth:`RollingGroupby.apply`, :meth:`ExpandingGroupby.apply`, :meth:`Rolling.apply`, :meth:`Expanding.apply`, :meth:`DataFrame.apply`: when using numba engine in these apply methods, positional arguments now can be passed as kwargs (:issue:`58995`) - :meth:`Series.map` can now accept kwargs to pass on to func (:issue:`59814`) - :meth:`pandas.concat` will raise a ``ValueError`` when ``ignore_index=True`` and ``keys`` is not ``None`` (:issue:`59274`) - :meth:`str.get_dummies` now accepts a ``dtype`` parameter to specify the dtype of the resulting DataFrame (:issue:`47872`) @@ -62,7 +63,6 @@ Other enhancements - Support passing a :class:`Iterable[Hashable]` input to :meth:`DataFrame.drop_duplicates` (:issue:`59237`) - Support reading Stata 102-format (Stata 1) dta files (:issue:`58978`) - Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`) -- numba apply now supports positional arguments passed as kwargs (:issue:`58995`) .. --------------------------------------------------------------------------- From f672f9bcadc1dc8253b7e3ee2f677e4ee9a3380c Mon Sep 17 00:00:00 2001 From: auderson Date: Wed, 30 Oct 2024 10:41:51 +0800 Subject: [PATCH 23/26] pre-commit --- pandas/tests/window/test_apply.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pandas/tests/window/test_apply.py b/pandas/tests/window/test_apply.py index 11ad9ea12b207..482f88f78c6f4 100644 --- a/pandas/tests/window/test_apply.py +++ b/pandas/tests/window/test_apply.py @@ -317,6 +317,7 @@ def test_center_reindex_frame(raw): frame_rs = frame.rolling(window=25, min_periods=minp, center=True).apply(f, raw=raw) tm.assert_frame_equal(frame_xp, frame_rs) + def test_apply_numba_with_kwargs(): # 58995 def func(sr, a=0): @@ -329,4 +330,4 @@ def func(sr, a=0): tm.assert_frame_equal(result, expected) result = data.rolling(5).apply(func, engine="numba", raw=True, args=(1,)) - tm.assert_frame_equal(result, expected) \ No newline at end of file + tm.assert_frame_equal(result, expected) From 1eba10b13adb37a9aef7ca57a907bcc01ffde7ba Mon Sep 17 00:00:00 2001 From: auderson Date: Wed, 30 Oct 2024 10:56:57 +0800 Subject: [PATCH 24/26] move the tests to test_numba.py --- pandas/tests/window/test_apply.py | 15 ------- pandas/tests/window/test_expanding.py | 15 ------- pandas/tests/window/test_groupby.py | 40 ------------------ pandas/tests/window/test_numba.py | 61 +++++++++++++++++++++++++++ 4 files changed, 61 insertions(+), 70 deletions(-) diff --git a/pandas/tests/window/test_apply.py b/pandas/tests/window/test_apply.py index 482f88f78c6f4..2398713585cfb 100644 --- a/pandas/tests/window/test_apply.py +++ b/pandas/tests/window/test_apply.py @@ -316,18 +316,3 @@ def test_center_reindex_frame(raw): ) frame_rs = frame.rolling(window=25, min_periods=minp, center=True).apply(f, raw=raw) tm.assert_frame_equal(frame_xp, frame_rs) - - -def test_apply_numba_with_kwargs(): - # 58995 - def func(sr, a=0): - return sr.sum() + a - - data = DataFrame(range(10)) - - result = data.rolling(5).apply(func, engine="numba", raw=True, kwargs={"a": 1}) - expected = data.rolling(5).sum() + 1 - tm.assert_frame_equal(result, expected) - - result = data.rolling(5).apply(func, engine="numba", raw=True, args=(1,)) - tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/window/test_expanding.py b/pandas/tests/window/test_expanding.py index 0950dbc337f1b..b2f76bdd0e2ad 100644 --- a/pandas/tests/window/test_expanding.py +++ b/pandas/tests/window/test_expanding.py @@ -691,18 +691,3 @@ def test_numeric_only_corr_cov_series(kernel, use_arg, numeric_only, dtype): op2 = getattr(expanding2, kernel) expected = op2(*arg2, numeric_only=numeric_only) tm.assert_series_equal(result, expected) - - -def test_apply_numba_with_kwargs(): - # 58995 - def func(sr, a=0): - return sr.sum() + a - - data = DataFrame(range(10)) - - result = data.expanding().apply(func, engine="numba", raw=True, kwargs={"a": 1}) - expected = data.expanding().sum() + 1 - tm.assert_frame_equal(result, expected) - - result = data.expanding().apply(func, engine="numba", raw=True, args=(1,)) - tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/window/test_groupby.py b/pandas/tests/window/test_groupby.py index 804002641206f..4d37c6d57f788 100644 --- a/pandas/tests/window/test_groupby.py +++ b/pandas/tests/window/test_groupby.py @@ -1024,26 +1024,6 @@ def test_datelike_on_not_monotonic_within_each_group(self): with pytest.raises(ValueError, match="Each group within B must be monotonic."): df.groupby("A").rolling("365D", on="B") - def test_groupby_rolling_apply_numba_with_kwargs(self, roll_frame): - def func(sr, a=0): - return sr.sum() + a - - # 58995 - result = ( - roll_frame.groupby("A") - .rolling(5) - .apply(func, engine="numba", raw=True, kwargs={"a": 1}) - ) - expected = roll_frame.groupby("A").rolling(5).sum() + 1 - tm.assert_frame_equal(result, expected) - - result = ( - roll_frame.groupby("A") - .rolling(5) - .apply(func, engine="numba", raw=True, args=(1,)) - ) - tm.assert_frame_equal(result, expected) - class TestExpanding: @pytest.fixture @@ -1154,26 +1134,6 @@ def test_expanding_apply(self, raw, frame): expected.index = expected_index tm.assert_frame_equal(result, expected) - def test_groupby_expanding_apply_numba_with_kwargs(self, roll_frame): - # 58995 - def func(sr, a=0): - return sr.sum() + a - - result = ( - roll_frame.groupby("A") - .expanding() - .apply(func, engine="numba", raw=True, kwargs={"a": 1}) - ) - expected = roll_frame.groupby("A").expanding().sum() + 1 - tm.assert_frame_equal(result, expected) - - result = ( - roll_frame.groupby("A") - .expanding() - .apply(func, engine="numba", raw=True, args=(1,)) - ) - tm.assert_frame_equal(result, expected) - class TestEWM: @pytest.mark.parametrize( diff --git a/pandas/tests/window/test_numba.py b/pandas/tests/window/test_numba.py index 88eb40b46f730..d9ab4723a8f2c 100644 --- a/pandas/tests/window/test_numba.py +++ b/pandas/tests/window/test_numba.py @@ -38,6 +38,11 @@ def arithmetic_numba_supported_operators(request): return request.param +@pytest.fixture +def roll_frame(): + return DataFrame({"A": [1] * 20 + [2] * 12 + [3] * 8, "B": np.arange(40)}) + + @td.skip_if_no("numba") @pytest.mark.filterwarnings("ignore") # Filter warnings when parallel=True and the function can't be parallelized by Numba @@ -67,6 +72,62 @@ def f(x, *args): ) tm.assert_series_equal(result, expected) + def test_apply_numba_with_kwargs(self, roll_frame): + # GH 58995 + # rolling apply + def func(sr, a=0): + return sr.sum() + a + + data = DataFrame(range(10)) + + result = data.rolling(5).apply(func, engine="numba", raw=True, kwargs={"a": 1}) + expected = data.rolling(5).sum() + 1 + tm.assert_frame_equal(result, expected) + + result = data.rolling(5).apply(func, engine="numba", raw=True, args=(1,)) + tm.assert_frame_equal(result, expected) + + # expanding apply + + result = data.expanding().apply(func, engine="numba", raw=True, kwargs={"a": 1}) + expected = data.expanding().sum() + 1 + tm.assert_frame_equal(result, expected) + + result = data.expanding().apply(func, engine="numba", raw=True, args=(1,)) + tm.assert_frame_equal(result, expected) + + # groupby rolling + result = ( + roll_frame.groupby("A") + .rolling(5) + .apply(func, engine="numba", raw=True, kwargs={"a": 1}) + ) + expected = roll_frame.groupby("A").rolling(5).sum() + 1 + tm.assert_frame_equal(result, expected) + + result = ( + roll_frame.groupby("A") + .rolling(5) + .apply(func, engine="numba", raw=True, args=(1,)) + ) + tm.assert_frame_equal(result, expected) + # groupby expanding + + result = ( + roll_frame.groupby("A") + .expanding() + .apply(func, engine="numba", raw=True, kwargs={"a": 1}) + ) + expected = roll_frame.groupby("A").expanding().sum() + 1 + tm.assert_frame_equal(result, expected) + + result = ( + roll_frame.groupby("A") + .expanding() + .apply(func, engine="numba", raw=True, args=(1,)) + ) + tm.assert_frame_equal(result, expected) + def test_numba_min_periods(self): # GH 58868 def last_row(x): From 93925ba7b8ad39edab3e64233826a46b1bf722dd Mon Sep 17 00:00:00 2001 From: auderson <48577571+auderson@users.noreply.github.com> Date: Thu, 31 Oct 2024 09:15:02 +0800 Subject: [PATCH 25/26] Update doc/source/whatsnew/v3.0.0.rst Co-authored-by: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> --- doc/source/whatsnew/v3.0.0.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index 89f84a93d6dd7..de9b2c3668e39 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -64,7 +64,6 @@ Other enhancements - Support reading Stata 102-format (Stata 1) dta files (:issue:`58978`) - Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`) - .. --------------------------------------------------------------------------- .. _whatsnew_300.notable_bug_fixes: From 09bdae0d0c378ac4bcb23da2421591b6a4438f1d Mon Sep 17 00:00:00 2001 From: auderson <48577571+auderson@users.noreply.github.com> Date: Thu, 31 Oct 2024 09:15:09 +0800 Subject: [PATCH 26/26] Update doc/source/whatsnew/v3.0.0.rst Co-authored-by: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> --- doc/source/whatsnew/v3.0.0.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index de9b2c3668e39..967f836a90d1f 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -54,7 +54,7 @@ Other enhancements - :meth:`Series.cummin` and :meth:`Series.cummax` now supports :class:`CategoricalDtype` (:issue:`52335`) - :meth:`Series.plot` now correctly handle the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`) - :meth:`DataFrame.plot.scatter` argument ``c`` now accepts a column of strings, where rows with the same string are colored identically (:issue:`16827` and :issue:`16485`) -- :meth:`DataFrameGroupBy.transform`, :meth:`SeriesGroupBy.transform`, :meth:`DataFrameGroupBy.agg`, :meth:`SeriesGroupBy.agg`, :meth:`RollingGroupby.apply`, :meth:`ExpandingGroupby.apply`, :meth:`Rolling.apply`, :meth:`Expanding.apply`, :meth:`DataFrame.apply`: when using numba engine in these apply methods, positional arguments now can be passed as kwargs (:issue:`58995`) +- :meth:`DataFrameGroupBy.transform`, :meth:`SeriesGroupBy.transform`, :meth:`DataFrameGroupBy.agg`, :meth:`SeriesGroupBy.agg`, :meth:`RollingGroupby.apply`, :meth:`ExpandingGroupby.apply`, :meth:`Rolling.apply`, :meth:`Expanding.apply`, :meth:`DataFrame.apply` with ``engine="numba"`` now supports positional arguments passed as kwargs (:issue:`58995`) - :meth:`Series.map` can now accept kwargs to pass on to func (:issue:`59814`) - :meth:`pandas.concat` will raise a ``ValueError`` when ``ignore_index=True`` and ``keys`` is not ``None`` (:issue:`59274`) - :meth:`str.get_dummies` now accepts a ``dtype`` parameter to specify the dtype of the resulting DataFrame (:issue:`47872`)