Skip to content

Commit f74b988

Browse files
lithomas1pre-commit-ci[bot]
authored andcommitted
ENH: numba engine in df.apply (pandas-dev#54666)
* ENH: numba engine in df.apply * fixes * more fixes * try to fix * address code review * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * go for green * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update type --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 01ea1b1 commit f74b988

File tree

5 files changed

+155
-18
lines changed

5 files changed

+155
-18
lines changed

doc/source/whatsnew/v2.2.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ enhancement2
2828

2929
Other enhancements
3030
^^^^^^^^^^^^^^^^^^
31-
-
31+
- DataFrame.apply now allows the usage of numba (via ``engine="numba"``) to JIT compile the passed function, allowing for potential speedups (:issue:`54666`)
3232
-
3333

3434
.. ---------------------------------------------------------------------------

pandas/core/_numba/executor.py

+39
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,45 @@
1515
from pandas.compat._optional import import_optional_dependency
1616

1717

18+
@functools.cache
19+
def generate_apply_looper(func, nopython=True, nogil=True, parallel=False):
20+
if TYPE_CHECKING:
21+
import numba
22+
else:
23+
numba = import_optional_dependency("numba")
24+
nb_compat_func = numba.extending.register_jitable(func)
25+
26+
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
27+
def nb_looper(values, axis):
28+
# Operate on the first row/col in order to get
29+
# the output shape
30+
if axis == 0:
31+
first_elem = values[:, 0]
32+
dim0 = values.shape[1]
33+
else:
34+
first_elem = values[0]
35+
dim0 = values.shape[0]
36+
res0 = nb_compat_func(first_elem)
37+
# Use np.asarray to get shape for
38+
# https://github.com/numba/numba/issues/4202#issuecomment-1185981507
39+
buf_shape = (dim0,) + np.atleast_1d(np.asarray(res0)).shape
40+
if axis == 0:
41+
buf_shape = buf_shape[::-1]
42+
buff = np.empty(buf_shape)
43+
44+
if axis == 1:
45+
buff[0] = res0
46+
for i in numba.prange(1, values.shape[0]):
47+
buff[i] = nb_compat_func(values[i])
48+
else:
49+
buff[:, 0] = res0
50+
for j in numba.prange(1, values.shape[1]):
51+
buff[:, j] = nb_compat_func(values[:, j])
52+
return buff
53+
54+
return nb_looper
55+
56+
1857
@functools.cache
1958
def make_looper(func, result_dtype, is_grouped_kernel, nopython, nogil, parallel):
2059
if TYPE_CHECKING:

pandas/core/apply.py

+34-3
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
ABCSeries,
5050
)
5151

52+
from pandas.core._numba.executor import generate_apply_looper
5253
import pandas.core.common as com
5354
from pandas.core.construction import ensure_wrapped_if_datetimelike
5455

@@ -80,6 +81,8 @@ def frame_apply(
8081
raw: bool = False,
8182
result_type: str | None = None,
8283
by_row: Literal[False, "compat"] = "compat",
84+
engine: str = "python",
85+
engine_kwargs: dict[str, bool] | None = None,
8386
args=None,
8487
kwargs=None,
8588
) -> FrameApply:
@@ -100,6 +103,8 @@ def frame_apply(
100103
raw=raw,
101104
result_type=result_type,
102105
by_row=by_row,
106+
engine=engine,
107+
engine_kwargs=engine_kwargs,
103108
args=args,
104109
kwargs=kwargs,
105110
)
@@ -756,11 +761,15 @@ def __init__(
756761
result_type: str | None,
757762
*,
758763
by_row: Literal[False, "compat"] = False,
764+
engine: str = "python",
765+
engine_kwargs: dict[str, bool] | None = None,
759766
args,
760767
kwargs,
761768
) -> None:
762769
if by_row is not False and by_row != "compat":
763770
raise ValueError(f"by_row={by_row} not allowed")
771+
self.engine = engine
772+
self.engine_kwargs = engine_kwargs
764773
super().__init__(
765774
obj, func, raw, result_type, by_row=by_row, args=args, kwargs=kwargs
766775
)
@@ -805,6 +814,12 @@ def values(self):
805814

806815
def apply(self) -> DataFrame | Series:
807816
"""compute the results"""
817+
818+
if self.engine == "numba" and not self.raw:
819+
raise ValueError(
820+
"The numba engine in DataFrame.apply can only be used when raw=True"
821+
)
822+
808823
# dispatch to handle list-like or dict-like
809824
if is_list_like(self.func):
810825
return self.apply_list_or_dict_like()
@@ -834,7 +849,7 @@ def apply(self) -> DataFrame | Series:
834849

835850
# raw
836851
elif self.raw:
837-
return self.apply_raw()
852+
return self.apply_raw(engine=self.engine, engine_kwargs=self.engine_kwargs)
838853

839854
return self.apply_standard()
840855

@@ -907,7 +922,7 @@ def apply_empty_result(self):
907922
else:
908923
return self.obj.copy()
909924

910-
def apply_raw(self):
925+
def apply_raw(self, engine="python", engine_kwargs=None):
911926
"""apply to the values as a numpy array"""
912927

913928
def wrap_function(func):
@@ -925,7 +940,23 @@ def wrapper(*args, **kwargs):
925940

926941
return wrapper
927942

928-
result = np.apply_along_axis(wrap_function(self.func), self.axis, self.values)
943+
if engine == "numba":
944+
engine_kwargs = {} if engine_kwargs is None else engine_kwargs
945+
946+
# error: Argument 1 to "__call__" of "_lru_cache_wrapper" has
947+
# incompatible type "Callable[..., Any] | str | list[Callable
948+
# [..., Any] | str] | dict[Hashable,Callable[..., Any] | str |
949+
# list[Callable[..., Any] | str]]"; expected "Hashable"
950+
nb_looper = generate_apply_looper(
951+
self.func, **engine_kwargs # type: ignore[arg-type]
952+
)
953+
result = nb_looper(self.values, self.axis)
954+
# If we made the result 2-D, squeeze it back to 1-D
955+
result = np.squeeze(result)
956+
else:
957+
result = np.apply_along_axis(
958+
wrap_function(self.func), self.axis, self.values
959+
)
929960

930961
# TODO: mixed type case
931962
if result.ndim == 2:

pandas/core/frame.py

+33
Original file line numberDiff line numberDiff line change
@@ -9925,6 +9925,8 @@ def apply(
99259925
result_type: Literal["expand", "reduce", "broadcast"] | None = None,
99269926
args=(),
99279927
by_row: Literal[False, "compat"] = "compat",
9928+
engine: Literal["python", "numba"] = "python",
9929+
engine_kwargs: dict[str, bool] | None = None,
99289930
**kwargs,
99299931
):
99309932
"""
@@ -9984,6 +9986,35 @@ def apply(
99849986
If False, the funcs will be passed the whole Series at once.
99859987
99869988
.. versionadded:: 2.1.0
9989+
9990+
engine : {'python', 'numba'}, default 'python'
9991+
Choose between the python (default) engine or the numba engine in apply.
9992+
9993+
The numba engine will attempt to JIT compile the passed function,
9994+
which may result in speedups for large DataFrames.
9995+
It also supports the following engine_kwargs :
9996+
9997+
- nopython (compile the function in nopython mode)
9998+
- nogil (release the GIL inside the JIT compiled function)
9999+
- parallel (try to apply the function in parallel over the DataFrame)
10000+
10001+
Note: The numba compiler only supports a subset of
10002+
valid Python/numpy operations.
10003+
10004+
Please read more about the `supported python features
10005+
<https://numba.pydata.org/numba-doc/dev/reference/pysupported.html>`_
10006+
and `supported numpy features
10007+
<https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html>`_
10008+
in numba to learn what you can or cannot use in the passed function.
10009+
10010+
As of right now, the numba engine can only be used with raw=True.
10011+
10012+
.. versionadded:: 2.2.0
10013+
10014+
engine_kwargs : dict
10015+
Pass keyword arguments to the engine.
10016+
This is currently only used by the numba engine,
10017+
see the documentation for the engine argument for more information.
998710018
**kwargs
998810019
Additional keyword arguments to pass as keywords arguments to
998910020
`func`.
@@ -10084,6 +10115,8 @@ def apply(
1008410115
raw=raw,
1008510116
result_type=result_type,
1008610117
by_row=by_row,
10118+
engine=engine,
10119+
engine_kwargs=engine_kwargs,
1008710120
args=args,
1008810121
kwargs=kwargs,
1008910122
)

pandas/tests/apply/test_frame_apply.py

+48-14
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@
1818
from pandas.tests.frame.common import zip_frames
1919

2020

21+
@pytest.fixture(params=["python", "numba"])
22+
def engine(request):
23+
if request.param == "numba":
24+
pytest.importorskip("numba")
25+
return request.param
26+
27+
2128
def test_apply(float_frame):
2229
with np.errstate(all="ignore"):
2330
# ufunc
@@ -234,36 +241,42 @@ def test_apply_broadcast_series_lambda_func(int_frame_const_col):
234241

235242

236243
@pytest.mark.parametrize("axis", [0, 1])
237-
def test_apply_raw_float_frame(float_frame, axis):
244+
def test_apply_raw_float_frame(float_frame, axis, engine):
245+
if engine == "numba":
246+
pytest.skip("numba can't handle when UDF returns None.")
247+
238248
def _assert_raw(x):
239249
assert isinstance(x, np.ndarray)
240250
assert x.ndim == 1
241251

242-
float_frame.apply(_assert_raw, axis=axis, raw=True)
252+
float_frame.apply(_assert_raw, axis=axis, engine=engine, raw=True)
243253

244254

245255
@pytest.mark.parametrize("axis", [0, 1])
246-
def test_apply_raw_float_frame_lambda(float_frame, axis):
247-
result = float_frame.apply(np.mean, axis=axis, raw=True)
256+
def test_apply_raw_float_frame_lambda(float_frame, axis, engine):
257+
result = float_frame.apply(np.mean, axis=axis, engine=engine, raw=True)
248258
expected = float_frame.apply(lambda x: x.values.mean(), axis=axis)
249259
tm.assert_series_equal(result, expected)
250260

251261

252-
def test_apply_raw_float_frame_no_reduction(float_frame):
262+
def test_apply_raw_float_frame_no_reduction(float_frame, engine):
253263
# no reduction
254-
result = float_frame.apply(lambda x: x * 2, raw=True)
264+
result = float_frame.apply(lambda x: x * 2, engine=engine, raw=True)
255265
expected = float_frame * 2
256266
tm.assert_frame_equal(result, expected)
257267

258268

259269
@pytest.mark.parametrize("axis", [0, 1])
260-
def test_apply_raw_mixed_type_frame(mixed_type_frame, axis):
270+
def test_apply_raw_mixed_type_frame(mixed_type_frame, axis, engine):
271+
if engine == "numba":
272+
pytest.skip("isinstance check doesn't work with numba")
273+
261274
def _assert_raw(x):
262275
assert isinstance(x, np.ndarray)
263276
assert x.ndim == 1
264277

265278
# Mixed dtype (GH-32423)
266-
mixed_type_frame.apply(_assert_raw, axis=axis, raw=True)
279+
mixed_type_frame.apply(_assert_raw, axis=axis, engine=engine, raw=True)
267280

268281

269282
def test_apply_axis1(float_frame):
@@ -300,14 +313,20 @@ def test_apply_mixed_dtype_corner_indexing():
300313
)
301314
@pytest.mark.parametrize("raw", [True, False])
302315
@pytest.mark.parametrize("axis", [0, 1])
303-
def test_apply_empty_infer_type(ax, func, raw, axis):
316+
def test_apply_empty_infer_type(ax, func, raw, axis, engine, request):
304317
df = DataFrame(**{ax: ["a", "b", "c"]})
305318

306319
with np.errstate(all="ignore"):
307320
test_res = func(np.array([], dtype="f8"))
308321
is_reduction = not isinstance(test_res, np.ndarray)
309322

310-
result = df.apply(func, axis=axis, raw=raw)
323+
if engine == "numba" and raw is False:
324+
mark = pytest.mark.xfail(
325+
reason="numba engine only supports raw=True at the moment"
326+
)
327+
request.node.add_marker(mark)
328+
329+
result = df.apply(func, axis=axis, engine=engine, raw=raw)
311330
if is_reduction:
312331
agg_axis = df._get_agg_axis(axis)
313332
assert isinstance(result, Series)
@@ -607,8 +626,10 @@ def non_reducing_function(row):
607626
assert names == list(df.index)
608627

609628

610-
def test_apply_raw_function_runs_once():
629+
def test_apply_raw_function_runs_once(engine):
611630
# https://github.com/pandas-dev/pandas/issues/34506
631+
if engine == "numba":
632+
pytest.skip("appending to list outside of numba func is not supported")
612633

613634
df = DataFrame({"a": [1, 2, 3]})
614635
values = [] # Save row values function is applied to
@@ -623,7 +644,7 @@ def non_reducing_function(row):
623644
for func in [reducing_function, non_reducing_function]:
624645
del values[:]
625646

626-
df.apply(func, raw=True, axis=1)
647+
df.apply(func, engine=engine, raw=True, axis=1)
627648
assert values == list(df.a.to_list())
628649

629650

@@ -1449,10 +1470,12 @@ def test_apply_no_suffix_index():
14491470
tm.assert_frame_equal(result, expected)
14501471

14511472

1452-
def test_apply_raw_returns_string():
1473+
def test_apply_raw_returns_string(engine):
14531474
# https://github.com/pandas-dev/pandas/issues/35940
1475+
if engine == "numba":
1476+
pytest.skip("No object dtype support in numba")
14541477
df = DataFrame({"A": ["aa", "bbb"]})
1455-
result = df.apply(lambda x: x[0], axis=1, raw=True)
1478+
result = df.apply(lambda x: x[0], engine=engine, axis=1, raw=True)
14561479
expected = Series(["aa", "bbb"])
14571480
tm.assert_series_equal(result, expected)
14581481

@@ -1632,3 +1655,14 @@ def test_agg_dist_like_and_nonunique_columns():
16321655
result = df.agg({"A": "count"})
16331656
expected = df["A"].count()
16341657
tm.assert_series_equal(result, expected)
1658+
1659+
1660+
def test_numba_unsupported():
1661+
df = DataFrame(
1662+
{"A": [None, 2, 3], "B": [1.0, np.nan, 3.0], "C": ["foo", None, "bar"]}
1663+
)
1664+
with pytest.raises(
1665+
ValueError,
1666+
match="The numba engine in DataFrame.apply can only be used when raw=True",
1667+
):
1668+
df.apply(lambda x: x, engine="numba", raw=False)

0 commit comments

Comments
 (0)