Skip to content

Commit aa42037

Browse files
committed
Implemented NumbaExecutionEngine
1 parent 65bf9cd commit aa42037

File tree

1 file changed

+58
-15
lines changed

1 file changed

+58
-15
lines changed

pandas/core/apply.py

Lines changed: 58 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@
4545
ABCSeries,
4646
)
4747

48-
from pandas.core._numba.executor import generate_apply_looper
4948
import pandas.core.common as com
5049
from pandas.core.construction import ensure_wrapped_if_datetimelike
50+
from pandas.core._numba.executor import generate_apply_looper
5151
from pandas.core.util.numba_ import (
5252
get_jit_arguments,
5353
prepare_function_arguments,
@@ -178,6 +178,57 @@ def apply(
178178
"""
179179

180180

181+
class NumbaExecutionEngine(BaseExecutionEngine):
182+
"""
183+
Numba-based execution engine for pandas apply and map operations.
184+
"""
185+
186+
@staticmethod
187+
def map(
188+
data: np.ndarray | Series | DataFrame,
189+
func,
190+
args: tuple,
191+
kwargs: dict,
192+
engine_kwargs: dict | None,
193+
skip_na: bool,
194+
):
195+
"""
196+
Elementwise map for the Numba engine. Currently not supported.
197+
"""
198+
raise NotImplementedError("Numba map is not implemented yet.")
199+
200+
@staticmethod
201+
def apply(
202+
data: np.ndarray | Series | DataFrame,
203+
func,
204+
args: tuple,
205+
kwargs: dict,
206+
engine_kwargs: dict | None,
207+
axis: int | str,
208+
):
209+
"""
210+
Apply `func` along the given axis using Numba.
211+
"""
212+
213+
looper_args, looper_kwargs = prepare_function_arguments(
214+
func, # type: ignore[arg-type]
215+
args,
216+
kwargs,
217+
num_required_args=1,
218+
)
219+
# error: Argument 1 to "__call__" of "_lru_cache_wrapper" has
220+
# incompatible type "Callable[..., Any] | str | list[Callable
221+
# [..., Any] | str] | dict[Hashable,Callable[..., Any] | str |
222+
# list[Callable[..., Any] | str]]"; expected "Hashable"
223+
nb_looper = generate_apply_looper(
224+
func, # type: ignore[arg-type]
225+
**get_jit_arguments(engine_kwargs)
226+
)
227+
result = nb_looper(data, axis, *looper_args)
228+
# If we made the result 2-D, squeeze it back to 1-D
229+
return np.squeeze(result)
230+
231+
181232
def frame_apply(
182233
obj: DataFrame,
183234
func: AggFuncType,
@@ -1094,23 +1145,15 @@ def wrapper(*args, **kwargs):
10941145
return wrapper
10951146

10961147
if engine == "numba":
1097-
args, kwargs = prepare_function_arguments(
1098-
self.func, # type: ignore[arg-type]
1148+
engine_obj = NumbaExecutionEngine()
1149+
result = engine_obj.apply(
1150+
self.values,
1151+
self.func,
10991152
self.args,
11001153
self.kwargs,
1101-
num_required_args=1,
1102-
)
1103-
# error: Argument 1 to "__call__" of "_lru_cache_wrapper" has
1104-
# incompatible type "Callable[..., Any] | str | list[Callable
1105-
# [..., Any] | str] | dict[Hashable,Callable[..., Any] | str |
1106-
# list[Callable[..., Any] | str]]"; expected "Hashable"
1107-
nb_looper = generate_apply_looper(
1108-
self.func, # type: ignore[arg-type]
1109-
**get_jit_arguments(engine_kwargs),
1154+
engine_kwargs,
1155+
self.axis,
11101156
)
1111-
result = nb_looper(self.values, self.axis, *args)
1112-
# If we made the result 2-D, squeeze it back to 1-D
1113-
result = np.squeeze(result)
11141157
else:
11151158
result = np.apply_along_axis(
11161159
wrap_function(self.func),

0 commit comments

Comments
 (0)