Skip to content

Commit 2ed97c0

Browse files
authored
ENH: Add method='table' for EWM.mean (#42339)
1 parent cd75f49 commit 2ed97c0

File tree

6 files changed

+127
-6
lines changed

6 files changed

+127
-6
lines changed

asv_bench/benchmarks/rolling.py

+3
Original file line numberDiff line numberDiff line change
@@ -296,5 +296,8 @@ def time_apply(self, method):
296296
table_method_func, raw=True, engine="numba"
297297
)
298298

299+
def time_ewm_mean(self, method):
    """Benchmark the numba-engine EWM ``mean`` on ``self.df`` for ``method``."""
    # Build the window object first, then time the numba-backed aggregation.
    ewm_window = self.df.ewm(1, method=method)
    ewm_window.mean(engine="numba")
301+
299302

300303
from .pandas_vb_common import setup # noqa: F401 isort:skip

doc/source/whatsnew/v1.4.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ enhancement2
3030
Other enhancements
3131
^^^^^^^^^^^^^^^^^^
3232
- :meth:`Series.sample`, :meth:`DataFrame.sample`, and :meth:`.GroupBy.sample` now accept a ``np.random.Generator`` as input to ``random_state``. A generator will be more performant, especially with ``replace=False`` (:issue:`38100`)
33+
- :meth:`Series.ewm` and :meth:`DataFrame.ewm` now support a ``method`` argument with a ``'table'`` option that performs the windowing operation over an entire :class:`DataFrame`. See :ref:`Window Overview <window.overview>` for performance and functional benefits (:issue:`42273`)
3334
-
3435

3536
.. ---------------------------------------------------------------------------

pandas/core/generic.py

+2
Original file line numberDiff line numberDiff line change
@@ -10846,6 +10846,7 @@ def ewm(
1084610846
ignore_na: bool_t = False,
1084710847
axis: Axis = 0,
1084810848
times: str | np.ndarray | FrameOrSeries | None = None,
10849+
method: str = "single",
1084910850
) -> ExponentialMovingWindow:
1085010851
axis = self._get_axis_number(axis)
1085110852
# error: Value of type variable "FrameOrSeries" of "ExponentialMovingWindow"
@@ -10861,6 +10862,7 @@ def ewm(
1086110862
ignore_na=ignore_na,
1086210863
axis=axis,
1086310864
times=times,
10865+
method=method,
1086410866
)
1086510867

1086610868
# ----------------------------------------------------------------------

pandas/core/window/ewm.py

+28-6
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,10 @@
4444
ExponentialMovingWindowIndexer,
4545
GroupbyIndexer,
4646
)
47-
from pandas.core.window.numba_ import generate_numba_ewma_func
47+
from pandas.core.window.numba_ import (
48+
generate_ewma_numba_table_func,
49+
generate_numba_ewma_func,
50+
)
4851
from pandas.core.window.online import (
4952
EWMMeanState,
5053
generate_online_numba_ewma_func,
@@ -204,6 +207,16 @@ class ExponentialMovingWindow(BaseWindow):
204207
If 1-D array like, a sequence with the same shape as the observations.
205208
206209
Only applicable to ``mean()``.
210+
method : str {'single', 'table'}, default 'single'
211+
Execute the rolling operation per single column or row (``'single'``)
212+
or over the entire object (``'table'``).
213+
214+
This argument is only implemented when specifying ``engine='numba'``
215+
in the method call.
216+
217+
Only applicable to ``mean()``.
218+
219+
.. versionadded:: 1.4.0
207220
208221
Returns
209222
-------
@@ -262,6 +275,7 @@ class ExponentialMovingWindow(BaseWindow):
262275
"ignore_na",
263276
"axis",
264277
"times",
278+
"method",
265279
]
266280

267281
def __init__(
@@ -276,6 +290,7 @@ def __init__(
276290
ignore_na: bool = False,
277291
axis: Axis = 0,
278292
times: str | np.ndarray | FrameOrSeries | None = None,
293+
method: str = "single",
279294
*,
280295
selection=None,
281296
):
@@ -285,7 +300,7 @@ def __init__(
285300
on=None,
286301
center=False,
287302
closed=None,
288-
method="single",
303+
method=method,
289304
axis=axis,
290305
selection=selection,
291306
)
@@ -441,12 +456,19 @@ def aggregate(self, func, *args, **kwargs):
441456
)
442457
def mean(self, *args, engine=None, engine_kwargs=None, **kwargs):
443458
if maybe_use_numba(engine):
444-
ewma_func = generate_numba_ewma_func(
445-
engine_kwargs, self._com, self.adjust, self.ignore_na, self._deltas
446-
)
459+
if self.method == "single":
460+
ewma_func = generate_numba_ewma_func(
461+
engine_kwargs, self._com, self.adjust, self.ignore_na, self._deltas
462+
)
463+
numba_cache_key = (lambda x: x, "ewma")
464+
else:
465+
ewma_func = generate_ewma_numba_table_func(
466+
engine_kwargs, self._com, self.adjust, self.ignore_na, self._deltas
467+
)
468+
numba_cache_key = (lambda x: x, "ewma_table")
447469
return self._apply(
448470
ewma_func,
449-
numba_cache_key=(lambda x: x, "ewma"),
471+
numba_cache_key=numba_cache_key,
450472
)
451473
elif engine in ("cython", None):
452474
if engine_kwargs is not None:

pandas/core/window/numba_.py

+79
Original file line numberDiff line numberDiff line change
@@ -250,3 +250,82 @@ def nan_agg_with_axis(table):
250250
return result
251251

252252
return nan_agg_with_axis
253+
254+
255+
def generate_ewma_numba_table_func(
    engine_kwargs: dict[str, bool] | None,
    com: float,
    adjust: bool,
    ignore_na: bool,
    deltas: np.ndarray,
):
    """
    Generate a numba jitted ewma function applied table wise specified
    by values from engine_kwargs.

    Parameters
    ----------
    engine_kwargs : dict
        dictionary of arguments to be passed into numba.jit
    com : float
        Center of mass; the smoothing factor used is ``1 / (1 + com)``.
    adjust : bool
        If True, every new observation is weighted 1 and the running weight
        of older observations accumulates; if False, a new observation is
        weighted by the smoothing factor and the running weight is reset
        to 1 after each observation.
    ignore_na : bool
        If True, the running weight of a column is only decayed on rows
        where that column holds an observation; if False, it is decayed on
        every row once the column has a running average.
    deltas : numpy.ndarray
        Exponents applied to the decay factor; ``deltas[i - 1]`` is used
        when processing row ``i``.

    Returns
    -------
    Numba function
        Callable ``(values, begin, end, minimum_periods)`` returning an array
        with the same shape as ``values``. ``begin`` and ``end`` are accepted
        but not read by the kernel.
    """
    nopython, nogil, parallel = get_jit_arguments(engine_kwargs)

    # NOTE(review): the key embeds a lambda created on this very line, so it
    # cannot compare equal to a key stored by an earlier call; this lookup
    # looks like it never hits. The kernel below also closes over ``com``,
    # ``adjust``, ``ignore_na`` and ``deltas``, so a hit would hand back a
    # function built for other parameters -- confirm before "fixing" the key.
    cache_key = (lambda x: x, "ewma_table")
    if cache_key in NUMBA_FUNC_CACHE:
        return NUMBA_FUNC_CACHE[cache_key]

    numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def ewma_table(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
    ) -> np.ndarray:
        alpha = 1.0 / (1.0 + com)
        old_wt_factor = 1.0 - alpha
        new_wt = 1.0 if adjust else alpha
        # One running "old" weight per column of the table.
        old_wt = np.ones(values.shape[1])

        result = np.empty(values.shape)
        # An empty table has no first row to seed the running average from;
        # return early rather than index row 0 of a zero-length array.
        if len(values) == 0:
            return result

        # Seed the running average and the observation counts from row 0.
        weighted_avg = values[0].copy()
        nobs = (~np.isnan(weighted_avg)).astype(np.int64)
        result[0] = np.where(nobs >= minimum_periods, weighted_avg, np.nan)
        for i in range(1, len(values)):
            cur = values[i]
            is_observations = ~np.isnan(cur)
            nobs += is_observations.astype(np.int64)
            # Columns are updated independently of one another.
            for j in numba.prange(len(cur)):
                if not np.isnan(weighted_avg[j]):
                    if is_observations[j] or not ignore_na:

                        # note that len(deltas) = len(vals) - 1 and deltas[i] is to be
                        # used in conjunction with vals[i+1]
                        # NOTE(review): assumes len(deltas) >= len(values) - 1;
                        # confirm this holds for every caller (e.g. axis=1).
                        old_wt[j] *= old_wt_factor ** deltas[i - 1]
                        if is_observations[j]:
                            # avoid numerical errors on constant series
                            if weighted_avg[j] != cur[j]:
                                weighted_avg[j] = (
                                    (old_wt[j] * weighted_avg[j]) + (new_wt * cur[j])
                                ) / (old_wt[j] + new_wt)
                            if adjust:
                                old_wt[j] += new_wt
                            else:
                                old_wt[j] = 1.0
                elif is_observations[j]:
                    # First observation seen for this column: start from it.
                    weighted_avg[j] = cur[j]

            # Mask columns that have not yet reached ``minimum_periods``.
            result[i] = np.where(nobs >= minimum_periods, weighted_avg, np.nan)

        return result

    return ewma_table

pandas/tests/window/test_numba.py

+14
Original file line numberDiff line numberDiff line change
@@ -332,3 +332,17 @@ def test_table_method_expanding_methods(
332332
engine_kwargs=engine_kwargs, engine="numba"
333333
)
334334
tm.assert_frame_equal(result, expected)
335+
336+
@pytest.mark.parametrize("data", [np.eye(3), np.ones((2, 3)), np.ones((3, 2))])
def test_table_method_ewm(self, data, axis, nogil, parallel, nopython):
    """EWM ``mean`` with ``method="table"`` must equal ``method="single"``."""
    frame = DataFrame(data)
    jit_options = {"nogil": nogil, "parallel": parallel, "nopython": nopython}

    def ewm_mean(method):
        # Same window parameters and numba options; only ``method`` varies.
        return frame.ewm(com=1, method=method, axis=axis).mean(
            engine_kwargs=jit_options, engine="numba"
        )

    result = ewm_mean("table")
    expected = ewm_mean("single")
    tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)