Skip to content

Commit 3d44221

Browse files
Licht-T and jreback
authored and committed
BUG: Fix inaccurate rolling.var calculation (#18481)
1 parent b69c1a2 commit 3d44221

File tree

3 files changed

+26
-9
lines changed

3 files changed

+26
-9
lines changed

doc/source/whatsnew/v0.21.1.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ Groupby/Resample/Rolling
103103
- Bug in ``DataFrame.resample(...).apply(...)`` when there is a callable that returns different columns (:issue:`15169`)
104104
- Bug in ``DataFrame.resample(...)`` when there is a time change (DST) and resampling frequency is 12h or higher (:issue:`15549`)
105105
- Bug in ``pd.DataFrameGroupBy.count()`` when counting over a datetimelike column (:issue:`13393`)
106-
-
106+
- Bug in ``rolling.var`` where calculation is inaccurate with a zero-valued array (:issue:`18430`)
107107
-
108108
-
109109

pandas/_libs/window.pyx

+17-8
Original file line numberDiff line numberDiff line change
@@ -661,9 +661,11 @@ cdef inline void add_var(double val, double *nobs, double *mean_x,
661661
if val == val:
662662
nobs[0] = nobs[0] + 1
663663

664-
delta = (val - mean_x[0])
664+
# a part of Welford's method for the online variance-calculation
665+
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
666+
delta = val - mean_x[0]
665667
mean_x[0] = mean_x[0] + delta / nobs[0]
666-
ssqdm_x[0] = ssqdm_x[0] + delta * (val - mean_x[0])
668+
ssqdm_x[0] = ssqdm_x[0] + ((nobs[0] - 1) * delta ** 2) / nobs[0]
667669

668670

669671
cdef inline void remove_var(double val, double *nobs, double *mean_x,
@@ -675,9 +677,11 @@ cdef inline void remove_var(double val, double *nobs, double *mean_x,
675677
if val == val:
676678
nobs[0] = nobs[0] - 1
677679
if nobs[0]:
678-
delta = (val - mean_x[0])
680+
# a part of Welford's method for the online variance-calculation
681+
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
682+
delta = val - mean_x[0]
679683
mean_x[0] = mean_x[0] - delta / nobs[0]
680-
ssqdm_x[0] = ssqdm_x[0] - delta * (val - mean_x[0])
684+
ssqdm_x[0] = ssqdm_x[0] - ((nobs[0] + 1) * delta ** 2) / nobs[0]
681685
else:
682686
mean_x[0] = 0
683687
ssqdm_x[0] = 0
@@ -689,7 +693,7 @@ def roll_var(ndarray[double_t] input, int64_t win, int64_t minp,
689693
Numerically stable implementation using Welford's method.
690694
"""
691695
cdef:
692-
double val, prev, mean_x = 0, ssqdm_x = 0, nobs = 0, delta
696+
double val, prev, mean_x = 0, ssqdm_x = 0, nobs = 0, delta, mean_x_old
693697
int64_t s, e
694698
bint is_variable
695699
Py_ssize_t i, j, N
@@ -749,6 +753,9 @@ def roll_var(ndarray[double_t] input, int64_t win, int64_t minp,
749753
add_var(input[i], &nobs, &mean_x, &ssqdm_x)
750754
output[i] = calc_var(minp, ddof, nobs, ssqdm_x)
751755

756+
# a part of Welford's method for the online variance-calculation
757+
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
758+
752759
# After the first window, observations can both be added and
753760
# removed
754761
for i from win <= i < N:
@@ -760,10 +767,12 @@ def roll_var(ndarray[double_t] input, int64_t win, int64_t minp,
760767

761768
# Adding one observation and removing another one
762769
delta = val - prev
763-
prev -= mean_x
770+
mean_x_old = mean_x
771+
764772
mean_x += delta / nobs
765-
val -= mean_x
766-
ssqdm_x += (val + prev) * delta
773+
ssqdm_x += ((nobs - 1) * val
774+
+ (nobs + 1) * prev
775+
- 2 * nobs * mean_x_old) * delta / nobs
767776

768777
else:
769778
add_var(val, &nobs, &mean_x, &ssqdm_x)

pandas/tests/test_window.py

+8
Original file line numberDiff line numberDiff line change
@@ -2482,6 +2482,14 @@ def test_rolling_corr_pairwise(self):
24822482
self._check_pairwise_moment('rolling', 'corr', window=10,
24832483
min_periods=5)
24842484

2485+
@pytest.mark.parametrize('window', range(7))
2486+
def test_rolling_corr_with_zero_variance(self, window):
2487+
# GH 18430
2488+
s = pd.Series(np.zeros(20))
2489+
other = pd.Series(np.arange(20))
2490+
2491+
assert s.rolling(window=window).corr(other=other).isna().all()
2492+
24852493
def _check_pairwise_moment(self, dispatch, name, **kwargs):
24862494
def get_result(obj, obj2=None):
24872495
return getattr(getattr(obj, dispatch)(**kwargs), name)(obj2)

0 commit comments

Comments
 (0)