Skip to content

Commit bba53fc

Browse files
authored
ENH: Add method='table' for EWM.mean (#42273)
1 parent b8c8aca commit bba53fc

File tree

6 files changed

+127
-7
lines changed

6 files changed

+127
-7
lines changed

asv_bench/benchmarks/rolling.py

+3
Original file line numberDiff line numberDiff line change
@@ -296,5 +296,8 @@ def time_apply(self, method):
296296
table_method_func, raw=True, engine="numba"
297297
)
298298

299+
def time_ewm_mean(self, method):
    """Time the numba-engine EWM mean for the given ``method`` on ``self.df``."""
    ewm_window = self.df.ewm(1, method=method)
    ewm_window.mean(engine="numba")
301+
299302

300303
from .pandas_vb_common import setup # noqa: F401 isort:skip

doc/source/whatsnew/v1.3.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ For example:
240240
Other enhancements
241241
^^^^^^^^^^^^^^^^^^
242242

243-
- :meth:`DataFrame.rolling`, :meth:`Series.rolling`, :meth:`DataFrame.expanding`, and :meth:`Series.expanding` now support a ``method`` argument with a ``'table'`` option that performs the windowing operation over an entire :class:`DataFrame`. See :ref:`Window Overview <window.overview>` for performance and functional benefits (:issue:`15095`, :issue:`38995`)
243+
- :meth:`DataFrame.rolling`, :meth:`Series.rolling`, :meth:`DataFrame.expanding`, :meth:`Series.expanding`, :meth:`DataFrame.ewm`, and :meth:`Series.ewm` now support a ``method`` argument with a ``'table'`` option that performs the windowing operation over an entire :class:`DataFrame`. See :ref:`Window Overview <window.overview>` for performance and functional benefits (:issue:`15095`, :issue:`38995`, :issue:`42273`)
244244
- :class:`.ExponentialMovingWindow` now supports an ``online`` method that can perform ``mean`` calculations in an online fashion. See :ref:`Window Overview <window.overview>` (:issue:`41673`)
245245
- Added :meth:`MultiIndex.dtypes` (:issue:`37062`)
246246
- Added ``end`` and ``end_day`` options for the ``origin`` argument in :meth:`DataFrame.resample` (:issue:`37804`)

pandas/core/generic.py

+2
Original file line numberDiff line numberDiff line change
@@ -10905,6 +10905,7 @@ def ewm(
1090510905
ignore_na: bool_t = False,
1090610906
axis: Axis = 0,
1090710907
times: str | np.ndarray | FrameOrSeries | None = None,
10908+
method: str = "single",
1090810909
) -> ExponentialMovingWindow:
1090910910
axis = self._get_axis_number(axis)
1091010911
# error: Value of type variable "FrameOrSeries" of "ExponentialMovingWindow"
@@ -10920,6 +10921,7 @@ def ewm(
1092010921
ignore_na=ignore_na,
1092110922
axis=axis,
1092210923
times=times,
10924+
method=method,
1092310925
)
1092410926

1092510927
# ----------------------------------------------------------------------

pandas/core/window/ewm.py

+28-6
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@
4040
ExponentialMovingWindowIndexer,
4141
GroupbyIndexer,
4242
)
43-
from pandas.core.window.numba_ import generate_numba_ewma_func
43+
from pandas.core.window.numba_ import (
44+
generate_ewma_numba_table_func,
45+
generate_numba_ewma_func,
46+
)
4447
from pandas.core.window.online import (
4548
EWMMeanState,
4649
generate_online_numba_ewma_func,
@@ -200,6 +203,16 @@ class ExponentialMovingWindow(BaseWindow):
200203
If 1-D array like, a sequence with the same shape as the observations.
201204
202205
Only applicable to ``mean()``.
206+
method : str {'single', 'table'}, default 'single'
207+
Execute the rolling operation per single column or row (``'single'``)
208+
or over the entire object (``'table'``).
209+
210+
This argument is only implemented when specifying ``engine='numba'``
211+
in the method call.
212+
213+
Only applicable to ``mean()``.
214+
215+
.. versionadded:: 1.3.0
203216
204217
Returns
205218
-------
@@ -258,6 +271,7 @@ class ExponentialMovingWindow(BaseWindow):
258271
"ignore_na",
259272
"axis",
260273
"times",
274+
"method",
261275
]
262276

263277
def __init__(
@@ -272,6 +286,7 @@ def __init__(
272286
ignore_na: bool = False,
273287
axis: Axis = 0,
274288
times: str | np.ndarray | FrameOrSeries | None = None,
289+
method: str = "single",
275290
*,
276291
selection=None,
277292
):
@@ -281,7 +296,7 @@ def __init__(
281296
on=None,
282297
center=False,
283298
closed=None,
284-
method="single",
299+
method=method,
285300
axis=axis,
286301
selection=selection,
287302
)
@@ -437,12 +452,19 @@ def aggregate(self, func, *args, **kwargs):
437452
)
438453
def mean(self, *args, engine=None, engine_kwargs=None, **kwargs):
439454
if maybe_use_numba(engine):
440-
ewma_func = generate_numba_ewma_func(
441-
engine_kwargs, self._com, self.adjust, self.ignore_na, self._deltas
442-
)
455+
if self.method == "single":
456+
ewma_func = generate_numba_ewma_func(
457+
engine_kwargs, self._com, self.adjust, self.ignore_na, self._deltas
458+
)
459+
numba_cache_key = (lambda x: x, "ewma")
460+
else:
461+
ewma_func = generate_ewma_numba_table_func(
462+
engine_kwargs, self._com, self.adjust, self.ignore_na, self._deltas
463+
)
464+
numba_cache_key = (lambda x: x, "ewma_table")
443465
return self._apply(
444466
ewma_func,
445-
numba_cache_key=(lambda x: x, "ewma"),
467+
numba_cache_key=numba_cache_key,
446468
)
447469
elif engine in ("cython", None):
448470
if engine_kwargs is not None:

pandas/core/window/numba_.py

+80
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,83 @@ def nan_agg_with_axis(table):
248248
return result
249249

250250
return nan_agg_with_axis
251+
252+
253+
def generate_ewma_numba_table_func(
    engine_kwargs: dict[str, bool] | None,
    com: float,
    adjust: bool,
    ignore_na: bool,
    deltas: np.ndarray,
):
    """
    Generate a numba jitted ewma function applied table wise specified
    by values from engine_kwargs.

    Parameters
    ----------
    engine_kwargs : dict
        dictionary of arguments to be passed into numba.jit
    com : float
        Decay parameter; the smoothing factor is ``alpha = 1 / (1 + com)``.
    adjust : bool
        If True, each new observation enters with weight 1 and the
        accumulated weight keeps growing; if False, a new observation
        enters with weight ``alpha`` and the accumulated weight is reset
        to 1 after every update.
    ignore_na : bool
        If True, a missing value in a column does not decay that column's
        accumulated weight.
    deltas : numpy.ndarray
        Exponents applied to the decay factor between consecutive rows.
        ``len(deltas) == len(values) - 1`` and ``deltas[i - 1]`` is used
        when stepping from row ``i - 1`` to row ``i``.

    Returns
    -------
    Numba function
    """
    nopython, nogil, parallel = get_jit_arguments(engine_kwargs)

    # NOTE(review): the key holds a lambda created right here, and functions
    # compare by identity, so this lookup looks like it can never hit --
    # confirm against how NUMBA_FUNC_CACHE is populated by the caller.
    cache_key = (lambda x: x, "ewma_table")
    if cache_key in NUMBA_FUNC_CACHE:
        return NUMBA_FUNC_CACHE[cache_key]

    numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def ewma_table(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
    ) -> np.ndarray:
        # ``begin`` and ``end`` are accepted but not used by this kernel.
        alpha = 1.0 / (1.0 + com)
        old_wt_factor = 1.0 - alpha
        new_wt = 1.0 if adjust else alpha
        # One accumulated weight per column: ``old_wt`` is indexed by the
        # column counter ``j`` below, so it must be sized by axis 1.
        old_wt = np.ones(values.shape[1])

        result = np.empty(values.shape)
        weighted_avg = values[0].copy()
        # Per-column count of non-NaN observations seen so far.
        nobs = (~np.isnan(weighted_avg)).astype(np.int64)
        result[0] = np.where(nobs >= minimum_periods, weighted_avg, np.nan)

        for i in range(1, len(values)):
            cur = values[i]
            is_observations = ~np.isnan(cur)
            nobs += is_observations.astype(np.int64)
            # Columns are updated independently of one another.
            for j in numba.prange(len(cur)):
                if not np.isnan(weighted_avg[j]):
                    if is_observations[j] or not ignore_na:

                        # note that len(deltas) = len(vals) - 1 and deltas[i] is to be
                        # used in conjunction with vals[i+1], so the delta is
                        # selected by the row index ``i``, not the column ``j``
                        old_wt[j] *= old_wt_factor ** deltas[i - 1]
                        if is_observations[j]:
                            # avoid numerical errors on constant series
                            if weighted_avg[j] != cur[j]:
                                weighted_avg[j] = (
                                    (old_wt[j] * weighted_avg[j]) + (new_wt * cur[j])
                                ) / (old_wt[j] + new_wt)
                            if adjust:
                                old_wt[j] += new_wt
                            else:
                                old_wt[j] = 1.0
                elif is_observations[j]:
                    # First valid observation for this column seeds the average.
                    weighted_avg[j] = cur[j]

            result[i] = np.where(nobs >= minimum_periods, weighted_avg, np.nan)

        return result

    return ewma_table

pandas/tests/window/test_numba.py

+13
Original file line numberDiff line numberDiff line change
@@ -304,3 +304,16 @@ def test_table_method_expanding_methods(
304304
engine_kwargs=engine_kwargs, engine="numba"
305305
)
306306
tm.assert_frame_equal(result, expected)
307+
308+
def test_table_method_ewm(self, axis, nogil, parallel, nopython):
    """Table-wise numba EWM mean must match the column-wise numba result."""
    numba_options = {"nogil": nogil, "parallel": parallel, "nopython": nopython}

    # NOTE: np.eye(3) is a square frame, so non-square shapes are not covered.
    frame = DataFrame(np.eye(3))

    def ewm_mean(method):
        # Same window and engine settings; only ``method`` varies.
        window = frame.ewm(com=1, method=method, axis=axis)
        return window.mean(engine_kwargs=numba_options, engine="numba")

    result = ewm_mean("table")
    expected = ewm_mean("single")
    tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)