Skip to content

Commit e4a96bb

Browse files
mroeschke authored and jreback committed
ENH: Add engine keyword to expanding.apply to utilize Numba (#30937)
1 parent 171a1ed commit e4a96bb

File tree

7 files changed

+68
-18
lines changed

7 files changed

+68
-18
lines changed

asv_bench/asv.conf.json

+1
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,7 @@
4343
"matplotlib": [],
4444
"sqlalchemy": [],
4545
"scipy": [],
46+
"numba": [],
4647
"numexpr": [],
4748
"pytables": [null, ""], // platform dependent, see excludes below
4849
"tables": [null, ""],

asv_bench/benchmarks/rolling.py

+21
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,27 @@ def time_rolling(self, constructor, window, dtype, function, raw):
4444
self.roll.apply(function, raw=raw)
4545

4646

47+
class Engine:
48+
params = (
49+
["DataFrame", "Series"],
50+
["int", "float"],
51+
[np.sum, lambda x: np.sum(x) + 5],
52+
["cython", "numba"],
53+
)
54+
param_names = ["constructor", "dtype", "function", "engine"]
55+
56+
def setup(self, constructor, dtype, function, engine):
57+
N = 10 ** 3
58+
arr = (100 * np.random.random(N)).astype(dtype)
59+
self.data = getattr(pd, constructor)(arr)
60+
61+
def time_rolling_apply(self, constructor, dtype, function, engine):
62+
self.data.rolling(10).apply(function, raw=True, engine=engine)
63+
64+
def time_expanding_apply(self, constructor, dtype, function, engine):
65+
self.data.expanding().apply(function, raw=True, engine=engine)
66+
67+
4768
class ExpandingMethods:
4869

4970
params = (

doc/source/user_guide/computation.rst

+1
Original file line number | Diff line number | Diff line change
@@ -348,6 +348,7 @@ Numba will be applied in potentially two routines:
348348

349349
1. If ``func`` is a standard Python function, the engine will `JIT <http://numba.pydata.org/numba-doc/latest/user/overview.html>`__
350350
the passed function. ``func`` can also be a JITed function in which case the engine will not JIT the function again.
351+
351352
2. The engine will JIT the for loop where the apply function is applied to each window.
352353

353354
The ``engine_kwargs`` argument is a dictionary of keyword arguments that will be passed into the

doc/source/whatsnew/v1.0.0.rst

+6 -6
Original file line number | Diff line number | Diff line change
@@ -159,14 +159,14 @@ You can use the alias ``"boolean"`` as well.
159159
160160
.. _whatsnew_100.numba_rolling_apply:
161161

162-
Using Numba in ``rolling.apply``
163-
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
162+
Using Numba in ``rolling.apply`` and ``expanding.apply``
163+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
164164

165-
We've added an ``engine`` keyword to :meth:`~core.window.rolling.Rolling.apply` that allows the user to execute the
166-
routine using `Numba <https://numba.pydata.org/>`__ instead of Cython. Using the Numba engine
167-
can yield significant performance gains if the apply function can operate on numpy arrays and
165+
We've added an ``engine`` keyword to :meth:`~core.window.rolling.Rolling.apply` and :meth:`~core.window.expanding.Expanding.apply`
166+
that allows the user to execute the routine using `Numba <https://numba.pydata.org/>`__ instead of Cython.
167+
Using the Numba engine can yield significant performance gains if the apply function can operate on numpy arrays and
168168
the data set is larger (1 million rows or greater). For more details, see
169-
:ref:`rolling apply documentation <stats.rolling_apply>` (:issue:`28987`)
169+
:ref:`rolling apply documentation <stats.rolling_apply>` (:issue:`28987`, :issue:`30936`)
170170

171171
.. _whatsnew_100.custom_window:
172172

pandas/core/window/expanding.py

+18 -2
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
from textwrap import dedent
2+
from typing import Dict, Optional
23

34
from pandas.compat.numpy import function as nv
45
from pandas.util._decorators import Appender, Substitution
@@ -148,8 +149,23 @@ def count(self, **kwargs):
148149

149150
@Substitution(name="expanding")
150151
@Appender(_shared_docs["apply"])
151-
def apply(self, func, raw=False, args=(), kwargs={}):
152-
return super().apply(func, raw=raw, args=args, kwargs=kwargs)
152+
def apply(
153+
self,
154+
func,
155+
raw: bool = False,
156+
engine: str = "cython",
157+
engine_kwargs: Optional[Dict[str, bool]] = None,
158+
args=None,
159+
kwargs=None,
160+
):
161+
return super().apply(
162+
func,
163+
raw=raw,
164+
engine=engine,
165+
engine_kwargs=engine_kwargs,
166+
args=args,
167+
kwargs=kwargs,
168+
)
153169

154170
@Substitution(name="expanding")
155171
@Appender(_shared_docs["sum"])

pandas/core/window/rolling.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -1203,7 +1203,7 @@ def count(self):
12031203

12041204
_shared_docs["apply"] = dedent(
12051205
r"""
1206-
The %(name)s function's apply function.
1206+
Apply an arbitrary function to each %(name)s window.
12071207
12081208
Parameters
12091209
----------

pandas/tests/window/moments/test_moments_expanding.py

+20 -9
Original file line number | Diff line number | Diff line change
@@ -13,15 +13,17 @@ class TestExpandingMomentsConsistency(ConsistencyBase):
1313
def setup_method(self, method):
1414
self._create_data()
1515

16-
def test_expanding_apply_args_kwargs(self, raw):
16+
def test_expanding_apply_args_kwargs(self, engine_and_raw):
1717
def mean_w_arg(x, const):
1818
return np.mean(x) + const
1919

20+
engine, raw = engine_and_raw
21+
2022
df = DataFrame(np.random.rand(20, 3))
2123

22-
expected = df.expanding().apply(np.mean, raw=raw) + 20.0
24+
expected = df.expanding().apply(np.mean, engine=engine, raw=raw) + 20.0
2325

24-
result = df.expanding().apply(mean_w_arg, raw=raw, args=(20,))
26+
result = df.expanding().apply(mean_w_arg, engine=engine, raw=raw, args=(20,))
2527
tm.assert_frame_equal(result, expected)
2628

2729
result = df.expanding().apply(mean_w_arg, raw=raw, kwargs={"const": 20})
@@ -190,26 +192,35 @@ def expanding_func(x, min_periods=1, center=False, axis=0):
190192
)
191193

192194
@pytest.mark.parametrize("has_min_periods", [True, False])
193-
def test_expanding_apply(self, raw, has_min_periods):
195+
def test_expanding_apply(self, engine_and_raw, has_min_periods):
196+
197+
engine, raw = engine_and_raw
198+
194199
def expanding_mean(x, min_periods=1):
195200

196201
exp = x.expanding(min_periods=min_periods)
197-
result = exp.apply(lambda x: x.mean(), raw=raw)
202+
result = exp.apply(lambda x: x.mean(), raw=raw, engine=engine)
198203
return result
199204

200205
# TODO(jreback), needed to add preserve_nan=False
201206
# here to make this pass
202207
self._check_expanding(expanding_mean, np.mean, preserve_nan=False)
203208
self._check_expanding_has_min_periods(expanding_mean, np.mean, has_min_periods)
204209

205-
def test_expanding_apply_empty_series(self, raw):
210+
def test_expanding_apply_empty_series(self, engine_and_raw):
211+
engine, raw = engine_and_raw
206212
ser = Series([], dtype=np.float64)
207-
tm.assert_series_equal(ser, ser.expanding().apply(lambda x: x.mean(), raw=raw))
213+
tm.assert_series_equal(
214+
ser, ser.expanding().apply(lambda x: x.mean(), raw=raw, engine=engine)
215+
)
208216

209-
def test_expanding_apply_min_periods_0(self, raw):
217+
def test_expanding_apply_min_periods_0(self, engine_and_raw):
210218
# GH 8080
219+
engine, raw = engine_and_raw
211220
s = Series([None, None, None])
212-
result = s.expanding(min_periods=0).apply(lambda x: len(x), raw=raw)
221+
result = s.expanding(min_periods=0).apply(
222+
lambda x: len(x), raw=raw, engine=engine
223+
)
213224
expected = Series([1.0, 2.0, 3.0])
214225
tm.assert_series_equal(result, expected)
215226

0 commit comments

Comments
 (0)