|
45 | 45 | ABCSeries,
|
46 | 46 | )
|
47 | 47 |
|
48 |
| -from pandas.core._numba.executor import generate_apply_looper |
49 | 48 | import pandas.core.common as com
|
50 | 49 | from pandas.core.construction import ensure_wrapped_if_datetimelike
|
| 50 | +from pandas.core._numba.executor import generate_apply_looper |
51 | 51 | from pandas.core.util.numba_ import (
|
52 | 52 | get_jit_arguments,
|
53 | 53 | prepare_function_arguments,
|
@@ -178,6 +178,57 @@ def apply(
|
178 | 178 | """
|
179 | 179 |
|
180 | 180 |
|
| 181 | +class NumbaExecutionEngine(BaseExecutionEngine): |
| 182 | + """ |
| 183 | + Numba-based execution engine for pandas apply and map operations. |
| 184 | + """ |
| 185 | + |
| 186 | + @staticmethod |
| 187 | + def map( |
| 188 | + data: np.ndarray | Series | DataFrame, |
| 189 | + func, |
| 190 | + args: tuple, |
| 191 | + kwargs: dict, |
| 192 | + engine_kwargs: dict | None, |
| 193 | + skip_na: bool, |
| 194 | + ): |
| 195 | + """ |
| 196 | + Elementwise map for the Numba engine. Currently not supported. |
| 197 | + """ |
| 198 | + raise NotImplementedError("Numba map is not implemented yet.") |
| 199 | + |
| 200 | + @staticmethod |
| 201 | + def apply( |
| 202 | + data: np.ndarray | Series | DataFrame, |
| 203 | + func, |
| 204 | + args: tuple, |
| 205 | + kwargs: dict, |
| 206 | + engine_kwargs: dict | None, |
| 207 | + axis: int | str, |
| 208 | + ): |
| 209 | + """ |
| 210 | + Apply `func` along the given axis using Numba. |
| 211 | + """ |
| 212 | + |
| 213 | + looper_args, looper_kwargs = prepare_function_arguments( |
| 214 | + func, # type: ignore[arg-type] |
| 215 | + args, |
| 216 | + kwargs, |
| 217 | + num_required_args=1, |
| 218 | + ) |
| 219 | + # error: Argument 1 to "__call__" of "_lru_cache_wrapper" has |
| 220 | + # incompatible type "Callable[..., Any] | str | list[Callable |
| 221 | + # [..., Any] | str] | dict[Hashable,Callable[..., Any] | str | |
| 222 | + # list[Callable[..., Any] | str]]"; expected "Hashable" |
| 223 | + nb_looper = generate_apply_looper( |
| 224 | + func, # type: ignore[arg-type] |
| 225 | + **get_jit_arguments(engine_kwargs) |
| 226 | + ) |
| 227 | + result = nb_looper(data, axis, *looper_args) |
| 228 | + # If we made the result 2-D, squeeze it back to 1-D |
| 229 | + return np.squeeze(result) |
| 230 | + |
| 231 | + |
181 | 232 | def frame_apply(
|
182 | 233 | obj: DataFrame,
|
183 | 234 | func: AggFuncType,
|
@@ -1094,23 +1145,15 @@ def wrapper(*args, **kwargs):
|
1094 | 1145 | return wrapper
|
1095 | 1146 |
|
1096 | 1147 | if engine == "numba":
|
1097 |
| - args, kwargs = prepare_function_arguments( |
1098 |
| - self.func, # type: ignore[arg-type] |
| 1148 | + engine_obj = NumbaExecutionEngine() |
| 1149 | + result = engine_obj.apply( |
| 1150 | + self.values, |
| 1151 | + self.func, |
1099 | 1152 | self.args,
|
1100 | 1153 | self.kwargs,
|
1101 |
| - num_required_args=1, |
1102 |
| - ) |
1103 |
| - # error: Argument 1 to "__call__" of "_lru_cache_wrapper" has |
1104 |
| - # incompatible type "Callable[..., Any] | str | list[Callable |
1105 |
| - # [..., Any] | str] | dict[Hashable,Callable[..., Any] | str | |
1106 |
| - # list[Callable[..., Any] | str]]"; expected "Hashable" |
1107 |
| - nb_looper = generate_apply_looper( |
1108 |
| - self.func, # type: ignore[arg-type] |
1109 |
| - **get_jit_arguments(engine_kwargs), |
| 1154 | + engine_kwargs, |
| 1155 | + self.axis, |
1110 | 1156 | )
|
1111 |
| - result = nb_looper(self.values, self.axis, *args) |
1112 |
| - # If we made the result 2-D, squeeze it back to 1-D |
1113 |
| - result = np.squeeze(result) |
1114 | 1157 | else:
|
1115 | 1158 | result = np.apply_along_axis(
|
1116 | 1159 | wrap_function(self.func),
|
|
0 commit comments