Skip to content

Commit a0c7028

Browse files
authored
ENH: Numba engine for EWM.mean (#41267)
1 parent 94f20fa commit a0c7028

File tree

4 files changed

+66
-77
lines changed

4 files changed

+66
-77
lines changed

doc/source/whatsnew/v1.3.0.rst

+1 -1
Original file line number | Diff line number | Diff line change
@@ -197,7 +197,7 @@ Other enhancements
197197
- Add support for dict-like names in :class:`MultiIndex.set_names` and :class:`MultiIndex.rename` (:issue:`20421`)
198198
- :func:`pandas.read_excel` can now auto detect .xlsb files (:issue:`35416`)
199199
- :class:`pandas.ExcelWriter` now accepts an ``if_sheet_exists`` parameter to control the behaviour of append mode when writing to existing sheets (:issue:`40230`)
200-
- :meth:`.Rolling.sum`, :meth:`.Expanding.sum`, :meth:`.Rolling.mean`, :meth:`.Expanding.mean`, :meth:`.Rolling.median`, :meth:`.Expanding.median`, :meth:`.Rolling.max`, :meth:`.Expanding.max`, :meth:`.Rolling.min`, and :meth:`.Expanding.min` now support ``Numba`` execution with the ``engine`` keyword (:issue:`38895`)
200+
- :meth:`.Rolling.sum`, :meth:`.Expanding.sum`, :meth:`.Rolling.mean`, :meth:`.Expanding.mean`, :meth:`.ExponentialMovingWindow.mean`, :meth:`.Rolling.median`, :meth:`.Expanding.median`, :meth:`.Rolling.max`, :meth:`.Expanding.max`, :meth:`.Rolling.min`, and :meth:`.Expanding.min` now support ``Numba`` execution with the ``engine`` keyword (:issue:`38895`, :issue:`41267`)
201201
- :meth:`DataFrame.apply` can now accept NumPy unary operators as strings, e.g. ``df.apply("sqrt")``, which was already the case for :meth:`Series.apply` (:issue:`39116`)
202202
- :meth:`DataFrame.apply` can now accept non-callable DataFrame properties as strings, e.g. ``df.apply("size")``, which was already the case for :meth:`Series.apply` (:issue:`39116`)
203203
- :meth:`DataFrame.applymap` can now accept kwargs to pass on to func (:issue:`39987`)

pandas/core/window/ewm.py

+30 -55
Original file line number | Diff line number | Diff line change
@@ -29,16 +29,18 @@
2929
args_compat,
3030
create_section_header,
3131
kwargs_compat,
32+
numba_notes,
3233
template_header,
3334
template_returns,
3435
template_see_also,
36+
window_agg_numba_parameters,
3537
)
3638
from pandas.core.window.indexers import (
3739
BaseIndexer,
3840
ExponentialMovingWindowIndexer,
3941
GroupbyIndexer,
4042
)
41-
from pandas.core.window.numba_ import generate_numba_groupby_ewma_func
43+
from pandas.core.window.numba_ import generate_numba_ewma_func
4244
from pandas.core.window.rolling import (
4345
BaseWindow,
4446
BaseWindowGroupby,
@@ -372,26 +374,41 @@ def aggregate(self, func, *args, **kwargs):
372374
template_header,
373375
create_section_header("Parameters"),
374376
args_compat,
377+
window_agg_numba_parameters,
375378
kwargs_compat,
376379
create_section_header("Returns"),
377380
template_returns,
378381
create_section_header("See Also"),
379-
template_see_also[:-1],
382+
template_see_also,
383+
create_section_header("Notes"),
384+
numba_notes.replace("\n", "", 1),
380385
window_method="ewm",
381386
aggregation_description="(exponential weighted moment) mean",
382387
agg_method="mean",
383388
)
384-
def mean(self, *args, **kwargs):
385-
nv.validate_window_func("mean", args, kwargs)
386-
window_func = window_aggregations.ewma
387-
window_func = partial(
388-
window_func,
389-
com=self._com,
390-
adjust=self.adjust,
391-
ignore_na=self.ignore_na,
392-
deltas=self._deltas,
393-
)
394-
return self._apply(window_func)
389+
def mean(self, *args, engine=None, engine_kwargs=None, **kwargs):
390+
if maybe_use_numba(engine):
391+
ewma_func = generate_numba_ewma_func(
392+
engine_kwargs, self._com, self.adjust, self.ignore_na, self._deltas
393+
)
394+
return self._apply(
395+
ewma_func,
396+
numba_cache_key=(lambda x: x, "ewma"),
397+
)
398+
elif engine in ("cython", None):
399+
if engine_kwargs is not None:
400+
raise ValueError("cython engine does not accept engine_kwargs")
401+
nv.validate_window_func("mean", args, kwargs)
402+
window_func = partial(
403+
window_aggregations.ewma,
404+
com=self._com,
405+
adjust=self.adjust,
406+
ignore_na=self.ignore_na,
407+
deltas=self._deltas,
408+
)
409+
return self._apply(window_func)
410+
else:
411+
raise ValueError("engine must be either 'numba' or 'cython'")
395412

396413
@doc(
397414
template_header,
@@ -635,45 +652,3 @@ def _get_window_indexer(self) -> GroupbyIndexer:
635652
window_indexer=ExponentialMovingWindowIndexer,
636653
)
637654
return window_indexer
638-
639-
def mean(self, engine=None, engine_kwargs=None):
640-
"""
641-
Parameters
642-
----------
643-
engine : str, default None
644-
* ``'cython'`` : Runs mean through C-extensions from cython.
645-
* ``'numba'`` : Runs mean through JIT compiled code from numba.
646-
Only available when ``raw`` is set to ``True``.
647-
* ``None`` : Defaults to ``'cython'`` or globally setting
648-
``compute.use_numba``
649-
650-
.. versionadded:: 1.2.0
651-
652-
engine_kwargs : dict, default None
653-
* For ``'cython'`` engine, there are no accepted ``engine_kwargs``
654-
* For ``'numba'`` engine, the engine can accept ``nopython``, ``nogil``
655-
and ``parallel`` dictionary keys. The values must either be ``True`` or
656-
``False``. The default ``engine_kwargs`` for the ``'numba'`` engine is
657-
``{'nopython': True, 'nogil': False, 'parallel': False}``.
658-
659-
.. versionadded:: 1.2.0
660-
661-
Returns
662-
-------
663-
Series or DataFrame
664-
Return type is determined by the caller.
665-
"""
666-
if maybe_use_numba(engine):
667-
groupby_ewma_func = generate_numba_groupby_ewma_func(
668-
engine_kwargs, self._com, self.adjust, self.ignore_na, self._deltas
669-
)
670-
return self._apply(
671-
groupby_ewma_func,
672-
numba_cache_key=(lambda x: x, "groupby_ewma"),
673-
)
674-
elif engine in ("cython", None):
675-
if engine_kwargs is not None:
676-
raise ValueError("cython engine does not accept engine_kwargs")
677-
return super().mean()
678-
else:
679-
raise ValueError("engine must be either 'numba' or 'cython'")

pandas/core/window/numba_.py

+8 -8
Original file line number | Diff line number | Diff line change
@@ -80,15 +80,15 @@ def roll_apply(
8080
return roll_apply
8181

8282

83-
def generate_numba_groupby_ewma_func(
83+
def generate_numba_ewma_func(
8484
engine_kwargs: Optional[Dict[str, bool]],
8585
com: float,
8686
adjust: bool,
8787
ignore_na: bool,
8888
deltas: np.ndarray,
8989
):
9090
"""
91-
Generate a numba jitted groupby ewma function specified by values
91+
Generate a numba jitted ewma function specified by values
9292
from engine_kwargs.
9393
9494
Parameters
@@ -106,30 +106,30 @@ def generate_numba_groupby_ewma_func(
106106
"""
107107
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
108108

109-
cache_key = (lambda x: x, "groupby_ewma")
109+
cache_key = (lambda x: x, "ewma")
110110
if cache_key in NUMBA_FUNC_CACHE:
111111
return NUMBA_FUNC_CACHE[cache_key]
112112

113113
numba = import_optional_dependency("numba")
114114

115115
@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
116-
def groupby_ewma(
116+
def ewma(
117117
values: np.ndarray,
118118
begin: np.ndarray,
119119
end: np.ndarray,
120120
minimum_periods: int,
121121
) -> np.ndarray:
122122
result = np.empty(len(values))
123123
alpha = 1.0 / (1.0 + com)
124+
old_wt_factor = 1.0 - alpha
125+
new_wt = 1.0 if adjust else alpha
126+
124127
for i in numba.prange(len(begin)):
125128
start = begin[i]
126129
stop = end[i]
127130
window = values[start:stop]
128131
sub_result = np.empty(len(window))
129132

130-
old_wt_factor = 1.0 - alpha
131-
new_wt = 1.0 if adjust else alpha
132-
133133
weighted_avg = window[0]
134134
nobs = int(not np.isnan(weighted_avg))
135135
sub_result[0] = weighted_avg if nobs >= minimum_periods else np.nan
@@ -166,7 +166,7 @@ def groupby_ewma(
166166

167167
return result
168168

169-
return groupby_ewma
169+
return ewma
170170

171171

172172
def generate_numba_table_func(

pandas/tests/window/test_numba.py

+27 -13
Original file line number | Diff line number | Diff line change
@@ -123,30 +123,44 @@ def func_2(x):
123123

124124

125125
@td.skip_if_no("numba", "0.46.0")
126-
class TestGroupbyEWMMean:
127-
def test_invalid_engine(self):
126+
class TestEWMMean:
127+
@pytest.mark.parametrize(
128+
"grouper", [lambda x: x, lambda x: x.groupby("A")], ids=["None", "groupby"]
129+
)
130+
def test_invalid_engine(self, grouper):
128131
df = DataFrame({"A": ["a", "b", "a", "b"], "B": range(4)})
129132
with pytest.raises(ValueError, match="engine must be either"):
130-
df.groupby("A").ewm(com=1.0).mean(engine="foo")
133+
grouper(df).ewm(com=1.0).mean(engine="foo")
131134

132-
def test_invalid_engine_kwargs(self):
135+
@pytest.mark.parametrize(
136+
"grouper", [lambda x: x, lambda x: x.groupby("A")], ids=["None", "groupby"]
137+
)
138+
def test_invalid_engine_kwargs(self, grouper):
133139
df = DataFrame({"A": ["a", "b", "a", "b"], "B": range(4)})
134140
with pytest.raises(ValueError, match="cython engine does not"):
135-
df.groupby("A").ewm(com=1.0).mean(
141+
grouper(df).ewm(com=1.0).mean(
136142
engine="cython", engine_kwargs={"nopython": True}
137143
)
138144

139-
def test_cython_vs_numba(self, nogil, parallel, nopython, ignore_na, adjust):
145+
@pytest.mark.parametrize(
146+
"grouper", [lambda x: x, lambda x: x.groupby("A")], ids=["None", "groupby"]
147+
)
148+
def test_cython_vs_numba(
149+
self, grouper, nogil, parallel, nopython, ignore_na, adjust
150+
):
140151
df = DataFrame({"A": ["a", "b", "a", "b"], "B": range(4)})
141-
gb_ewm = df.groupby("A").ewm(com=1.0, adjust=adjust, ignore_na=ignore_na)
152+
ewm = grouper(df).ewm(com=1.0, adjust=adjust, ignore_na=ignore_na)
142153

143154
engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
144-
result = gb_ewm.mean(engine="numba", engine_kwargs=engine_kwargs)
145-
expected = gb_ewm.mean(engine="cython")
155+
result = ewm.mean(engine="numba", engine_kwargs=engine_kwargs)
156+
expected = ewm.mean(engine="cython")
146157

147158
tm.assert_frame_equal(result, expected)
148159

149-
def test_cython_vs_numba_times(self, nogil, parallel, nopython, ignore_na):
160+
@pytest.mark.parametrize(
161+
"grouper", [lambda x: x, lambda x: x.groupby("A")], ids=["None", "groupby"]
162+
)
163+
def test_cython_vs_numba_times(self, grouper, nogil, parallel, nopython, ignore_na):
150164
# GH 40951
151165
halflife = "23 days"
152166
times = to_datetime(
@@ -160,13 +174,13 @@ def test_cython_vs_numba_times(self, nogil, parallel, nopython, ignore_na):
160174
]
161175
)
162176
df = DataFrame({"A": ["a", "b", "a", "b", "b", "a"], "B": [0, 0, 1, 1, 2, 2]})
163-
gb_ewm = df.groupby("A").ewm(
177+
ewm = grouper(df).ewm(
164178
halflife=halflife, adjust=True, ignore_na=ignore_na, times=times
165179
)
166180

167181
engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
168-
result = gb_ewm.mean(engine="numba", engine_kwargs=engine_kwargs)
169-
expected = gb_ewm.mean(engine="cython")
182+
result = ewm.mean(engine="numba", engine_kwargs=engine_kwargs)
183+
expected = ewm.mean(engine="cython")
170184

171185
tm.assert_frame_equal(result, expected)
172186

0 commit comments

Comments
 (0)