diff --git a/pandas/core/groupby/numba_.py b/pandas/core/groupby/numba_.py index 4d841228579ca..6bdb8d6b4d863 100644 --- a/pandas/core/groupby/numba_.py +++ b/pandas/core/groupby/numba_.py @@ -92,7 +92,7 @@ def generate_numba_agg_func( ------- Numba function """ - numba_func = jit_user_function(func, nopython, nogil, parallel) + numba_func = jit_user_function(func) if TYPE_CHECKING: import numba else: @@ -152,7 +152,7 @@ def generate_numba_transform_func( ------- Numba function """ - numba_func = jit_user_function(func, nopython, nogil, parallel) + numba_func = jit_user_function(func) if TYPE_CHECKING: import numba else: diff --git a/pandas/core/util/numba_.py b/pandas/core/util/numba_.py index be798e022ac6e..b8d489179338b 100644 --- a/pandas/core/util/numba_.py +++ b/pandas/core/util/numba_.py @@ -1,14 +1,11 @@ """Common utilities for Numba operations""" from __future__ import annotations -import types from typing import ( TYPE_CHECKING, Callable, ) -import numpy as np - from pandas.compat._optional import import_optional_dependency from pandas.errors import NumbaUtilError @@ -63,27 +60,20 @@ def get_jit_arguments( return {"nopython": nopython, "nogil": nogil, "parallel": parallel} -def jit_user_function( - func: Callable, nopython: bool, nogil: bool, parallel: bool -) -> Callable: +def jit_user_function(func: Callable) -> Callable: """ - JIT the user's function given the configurable arguments. + If user function is not jitted already, mark the user's function + as jitable. Parameters ---------- func : function user defined function - nopython : bool - nopython parameter for numba.JIT - nogil : bool - nogil parameter for numba.JIT - parallel : bool - parallel parameter for numba.JIT Returns ------- function - Numba JITed function + Numba JITed function, or function marked as JITable by numba """ if TYPE_CHECKING: import numba @@ -94,19 +84,6 @@ def jit_user_function( # Don't jit a user passed jitted function numba_func = func else: - - @numba.generated_jit(nopython=nopython, nogil=nogil, parallel=parallel) - def numba_func(data, *_args): - if getattr(np, func.__name__, False) is func or isinstance( - func, types.BuiltinFunctionType - ): - jf = func - else: - jf = numba.jit(func, nopython=nopython, nogil=nogil) - - def impl(data, *_args): - return jf(data, *_args) - - return impl + numba_func = numba.extending.register_jitable(func) return numba_func diff --git a/pandas/core/window/numba_.py b/pandas/core/window/numba_.py index 14775cc7f457e..71c7327665b65 100644 --- a/pandas/core/window/numba_.py +++ b/pandas/core/window/numba_.py @@ -48,7 +48,7 @@ def generate_numba_apply_func( ------- Numba function """ - numba_func = jit_user_function(func, nopython, nogil, parallel) + numba_func = jit_user_function(func) if TYPE_CHECKING: import numba else: @@ -207,7 +207,7 @@ def generate_numba_table_func( ------- Numba function """ - numba_func = jit_user_function(func, nopython, nogil, parallel) + numba_func = jit_user_function(func) if TYPE_CHECKING: import numba else: