Skip to content

Commit 03f5066

Browse files
mroeschkeMatt Roeschke
and
Matt Roeschke
authored
CLN: EWMA cython code and function dispatch (#34636)
Co-authored-by: Matt Roeschke <[email protected]>
1 parent 24857a2 commit 03f5066

File tree

2 files changed

+26
-34
lines changed

2 files changed

+26
-34
lines changed

pandas/_libs/window/aggregations.pyx

+14-14
Original file line numberDiff line numberDiff line change
@@ -1793,19 +1793,19 @@ def ewma(float64_t[:] vals, float64_t com, int adjust, bint ignore_na, int minp)
17931793
new_wt = 1. if adjust else alpha
17941794

17951795
weighted_avg = vals[0]
1796-
is_observation = (weighted_avg == weighted_avg)
1796+
is_observation = weighted_avg == weighted_avg
17971797
nobs = int(is_observation)
1798-
output[0] = weighted_avg if (nobs >= minp) else NaN
1798+
output[0] = weighted_avg if nobs >= minp else NaN
17991799
old_wt = 1.
18001800

18011801
with nogil:
18021802
for i in range(1, N):
18031803
cur = vals[i]
1804-
is_observation = (cur == cur)
1804+
is_observation = cur == cur
18051805
nobs += is_observation
18061806
if weighted_avg == weighted_avg:
18071807

1808-
if is_observation or (not ignore_na):
1808+
if is_observation or not ignore_na:
18091809

18101810
old_wt *= old_wt_factor
18111811
if is_observation:
@@ -1821,7 +1821,7 @@ def ewma(float64_t[:] vals, float64_t com, int adjust, bint ignore_na, int minp)
18211821
elif is_observation:
18221822
weighted_avg = cur
18231823

1824-
output[i] = weighted_avg if (nobs >= minp) else NaN
1824+
output[i] = weighted_avg if nobs >= minp else NaN
18251825

18261826
return output
18271827

@@ -1851,16 +1851,16 @@ def ewmcov(float64_t[:] input_x, float64_t[:] input_y,
18511851
"""
18521852

18531853
cdef:
1854-
Py_ssize_t N = len(input_x)
1854+
Py_ssize_t N = len(input_x), M = len(input_y)
18551855
float64_t alpha, old_wt_factor, new_wt, mean_x, mean_y, cov
18561856
float64_t sum_wt, sum_wt2, old_wt, cur_x, cur_y, old_mean_x, old_mean_y
18571857
float64_t numerator, denominator
18581858
Py_ssize_t i, nobs
18591859
ndarray[float64_t] output
18601860
bint is_observation
18611861

1862-
if <Py_ssize_t>len(input_y) != N:
1863-
raise ValueError(f"arrays are of different lengths ({N} and {len(input_y)})")
1862+
if M != N:
1863+
raise ValueError(f"arrays are of different lengths ({N} and {M})")
18641864

18651865
output = np.empty(N, dtype=float)
18661866
if N == 0:
@@ -1874,12 +1874,12 @@ def ewmcov(float64_t[:] input_x, float64_t[:] input_y,
18741874

18751875
mean_x = input_x[0]
18761876
mean_y = input_y[0]
1877-
is_observation = ((mean_x == mean_x) and (mean_y == mean_y))
1877+
is_observation = (mean_x == mean_x) and (mean_y == mean_y)
18781878
nobs = int(is_observation)
18791879
if not is_observation:
18801880
mean_x = NaN
18811881
mean_y = NaN
1882-
output[0] = (0. if bias else NaN) if (nobs >= minp) else NaN
1882+
output[0] = (0. if bias else NaN) if nobs >= minp else NaN
18831883
cov = 0.
18841884
sum_wt = 1.
18851885
sum_wt2 = 1.
@@ -1890,10 +1890,10 @@ def ewmcov(float64_t[:] input_x, float64_t[:] input_y,
18901890
for i in range(1, N):
18911891
cur_x = input_x[i]
18921892
cur_y = input_y[i]
1893-
is_observation = ((cur_x == cur_x) and (cur_y == cur_y))
1893+
is_observation = (cur_x == cur_x) and (cur_y == cur_y)
18941894
nobs += is_observation
18951895
if mean_x == mean_x:
1896-
if is_observation or (not ignore_na):
1896+
if is_observation or not ignore_na:
18971897
sum_wt *= old_wt_factor
18981898
sum_wt2 *= (old_wt_factor * old_wt_factor)
18991899
old_wt *= old_wt_factor
@@ -1929,8 +1929,8 @@ def ewmcov(float64_t[:] input_x, float64_t[:] input_y,
19291929
if not bias:
19301930
numerator = sum_wt * sum_wt
19311931
denominator = numerator - sum_wt2
1932-
if (denominator > 0.):
1933-
output[i] = ((numerator / denominator) * cov)
1932+
if denominator > 0:
1933+
output[i] = (numerator / denominator) * cov
19341934
else:
19351935
output[i] = NaN
19361936
else:

pandas/core/window/ewm.py

+12-20
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from functools import partial
12
from textwrap import dedent
23

34
import numpy as np
@@ -219,7 +220,7 @@ def aggregate(self, func, *args, **kwargs):
219220

220221
agg = aggregate
221222

222-
def _apply(self, func, **kwargs):
223+
def _apply(self, func):
223224
"""
224225
Rolling statistical measure using supplied function. Designed to be
225226
used with passed-in Cython array-based functions.
@@ -253,23 +254,6 @@ def _apply(self, func, **kwargs):
253254
results.append(values.copy())
254255
continue
255256

256-
# if we have a string function name, wrap it
257-
if isinstance(func, str):
258-
cfunc = getattr(window_aggregations, func, None)
259-
if cfunc is None:
260-
raise ValueError(
261-
f"we do not support this function in window_aggregations.{func}"
262-
)
263-
264-
def func(arg):
265-
return cfunc(
266-
arg,
267-
self.com,
268-
int(self.adjust),
269-
int(self.ignore_na),
270-
int(self.min_periods),
271-
)
272-
273257
results.append(np.apply_along_axis(func, self.axis, values))
274258

275259
return self._wrap_results(results, block_list, obj, exclude)
@@ -286,7 +270,15 @@ def mean(self, *args, **kwargs):
286270
Arguments and keyword arguments to be passed into func.
287271
"""
288272
nv.validate_window_func("mean", args, kwargs)
289-
return self._apply("ewma", **kwargs)
273+
window_func = self._get_roll_func("ewma")
274+
window_func = partial(
275+
window_func,
276+
com=self.com,
277+
adjust=int(self.adjust),
278+
ignore_na=self.ignore_na,
279+
minp=int(self.min_periods),
280+
)
281+
return self._apply(window_func)
290282

291283
@Substitution(name="ewm", func_name="std")
292284
@Appender(_doc_template)
@@ -320,7 +312,7 @@ def f(arg):
320312
int(bias),
321313
)
322314

323-
return self._apply(f, **kwargs)
315+
return self._apply(f)
324316

325317
@Substitution(name="ewm", func_name="cov")
326318
@Appender(_doc_template)

0 commit comments

Comments
 (0)