Implemented NumbaExecutionEngine #61487
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
cast, | ||
) | ||
|
||
import numba | ||
import numpy as np | ||
|
||
from pandas._libs.internals import BlockValuesRefs | ||
|
@@ -178,6 +179,60 @@ def apply( | |
""" | ||
|
||
|
||
class NumbaExecutionEngine(BaseExecutionEngine): | ||
""" | ||
Numba-based execution engine for pandas apply and map operations. | ||
""" | ||
|
||
@staticmethod | ||
def map( | ||
data: np.ndarray | Series | DataFrame, | ||
func, | ||
args: tuple, | ||
kwargs: dict, | ||
decorator: Callable | None, | ||
skip_na: bool, | ||
): | ||
""" | ||
Elementwise map for the Numba engine. Currently not supported. | ||
""" | ||
raise NotImplementedError("Numba map is not implemented yet.") | ||
This is the error when users write something like |
||
|
||
@staticmethod | ||
def apply( | ||
data: np.ndarray | Series | DataFrame, | ||
func, | ||
args: tuple, | ||
kwargs: dict, | ||
decorator: Callable, | ||
axis: int | str, | ||
): | ||
""" | ||
Apply `func` along the given axis using Numba. | ||
""" | ||
engine_kwargs: dict[str, bool] | None = ( | ||
decorator if isinstance(decorator, dict) else None | ||
) | ||
|
||
looper_args, looper_kwargs = prepare_function_arguments( | ||
func, | ||
args, | ||
kwargs, | ||
num_required_args=1, | ||
) | ||
# error: Argument 1 to "__call__" of "_lru_cache_wrapper" has | ||
# incompatible type "Callable[..., Any] | str | list[Callable | ||
# [..., Any] | str] | dict[Hashable,Callable[..., Any] | str | | ||
# list[Callable[..., Any] | str]]"; expected "Hashable" | ||
nb_looper = generate_apply_looper( | ||
I personally wouldn't abbreviate to nb, it's not super clear imho. Just calling this |
||
func, | ||
**get_jit_arguments(engine_kwargs), | ||
I think we can make this simpler if you change What you are doing now is to extract the
I also noticed that |
||
) | ||
result = nb_looper(data, axis, *looper_args) | ||
# If we made the result 2-D, squeeze it back to 1-D | ||
return np.squeeze(result) | ||
|
||
|
||
def frame_apply( | ||
obj: DataFrame, | ||
func: AggFuncType, | ||
|
@@ -1094,23 +1149,16 @@ def wrapper(*args, **kwargs): | |
return wrapper | ||
|
||
if engine == "numba": | ||
args, kwargs = prepare_function_arguments( | ||
self.func, # type: ignore[arg-type] | ||
if not hasattr(numba.jit, "__pandas_udf__"): | ||
numba.jit.__pandas_udf__ = NumbaExecutionEngine | ||
I think a simpler approach would be to implement this logic here: https://github.com/pandas-dev/pandas/blob/main/pandas/core/frame.py#L10563 There, we are currently considering two cases:
I would simplify that and just support engines with the engine interface
Since we want to support
    def apply(...):
        if engine == "numba":
            numba = import_optional_dependency("numba")
            numba_jit = numba.jit(**engine_kwargs)
            numba_jit.__pandas_udf__ = NumbaExecutionEngine
From this point, all the code can pretend engine is going to be
The challenge is that numba and the default engine share some code, and with this approach they'll be running independently. The default engine won't know anything about an When we move the default engine to a
Does this approach make sense to you?
Yes, this makes sense, thank you. |
||
result = numba.jit.__pandas_udf__.apply( | ||
self.values, | ||
self.func, | ||
self.args, | ||
self.kwargs, | ||
num_required_args=1, | ||
) | ||
# error: Argument 1 to "__call__" of "_lru_cache_wrapper" has | ||
# incompatible type "Callable[..., Any] | str | list[Callable | ||
# [..., Any] | str] | dict[Hashable,Callable[..., Any] | str | | ||
# list[Callable[..., Any] | str]]"; expected "Hashable" | ||
nb_looper = generate_apply_looper( | ||
self.func, # type: ignore[arg-type] | ||
**get_jit_arguments(engine_kwargs), | ||
engine_kwargs, | ||
self.axis, | ||
) | ||
result = nb_looper(self.values, self.axis, *args) | ||
# If we made the result 2-D, squeeze it back to 1-D | ||
result = np.squeeze(result) | ||
else: | ||
result = np.apply_along_axis( | ||
wrap_function(self.func), | ||
|
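The review suggestion above (attach `__pandas_udf__` to the configured jit decorator, then treat it like any other engine) can be sketched without numba or pandas. Everything here is a hypothetical stand-in: `make_jit` plays the role of `numba.jit(**engine_kwargs)`, `ListExecutionEngine` the role of `NumbaExecutionEngine`, and `dispatch_apply` the role of the dispatch code in `DataFrame.apply`. It is an illustration of the pattern under those assumptions, not the pandas implementation.

```python
from typing import Callable


class ListExecutionEngine:
    """Hypothetical engine; stands in for NumbaExecutionEngine."""

    @staticmethod
    def apply(data, func, args, kwargs, decorator, axis):
        # Let the decorator wrap (or, for a real JIT, compile) the function,
        # then call it once per row. `axis` is ignored in this toy version.
        compiled = decorator(func)
        return [compiled(row, *args, **kwargs) for row in data]


def make_jit(**engine_kwargs) -> Callable:
    """Hypothetical stand-in for numba.jit(**engine_kwargs)."""

    def decorator(func):
        return func  # a real JIT would compile here

    return decorator


def dispatch_apply(data, func, engine, args=(), kwargs=None, axis=0):
    # Only the engine interface is supported: the engine object must
    # carry a __pandas_udf__ attribute pointing at an engine class.
    engine_cls = getattr(engine, "__pandas_udf__", None)
    if engine_cls is None:
        raise ValueError("engine must expose a __pandas_udf__ attribute")
    return engine_cls.apply(data, func, args, kwargs or {}, engine, axis)


# Build the decorator first, then tag it, as in the review comment.
jit = make_jit(nopython=True)
jit.__pandas_udf__ = ListExecutionEngine
result = dispatch_apply([[1, 2], [3, 4]], sum, engine=jit)
```

Because the attribute is set on the decorator returned for one call, nothing is mutated on a shared module-level object, which is the difference from assigning to `numba.jit.__pandas_udf__` directly.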
`numba` is not a required pandas dependency, I think this will make `import pandas` raise an `ImportError` for users with no numba installed, even if they don't call any numba functionality. I think we've got an `import_optional_dependency` function that we call inside the functions where numba is used to avoid this problem.
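The deferred-import pattern described in this comment can be sketched with the standard library alone. `import_optional` below is a hypothetical helper that imitates the idea behind `import_optional_dependency`; it is not the pandas function and its signature is an assumption. `apply_with_engine` is likewise a made-up caller used only to show where the import happens.

```python
import importlib
from types import ModuleType


def import_optional(name: str) -> ModuleType:
    """Hypothetical helper: import `name` only when it is actually needed."""
    try:
        return importlib.import_module(name)
    except ImportError as err:
        raise ImportError(
            f"Missing optional dependency '{name}'. "
            "Install it to use this feature."
        ) from err


def apply_with_engine(values, func, engine=None):
    # The optional package is imported inside the function, so importing
    # this module never fails for users who do not have it installed.
    if engine == "numba":
        numba = import_optional("numba")
        func = numba.jit(func)
    return [func(v) for v in values]


# The default path works whether or not numba is installed.
incremented = apply_with_engine([1, 2, 3], lambda v: v + 1)
```

With a module-level `import numba`, the failure would surface at `import pandas` time; here it surfaces only when a caller passes `engine="numba"`.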