Skip to content

Commit 020f290

Browse files
authored
BUG: Fix window aggregations to skip over unused elements (GH-45647) (#45655)
1 parent f0bd3af commit 020f290

File tree

3 files changed

+168
-36
lines changed

3 files changed

+168
-36
lines changed

doc/source/whatsnew/v1.4.1.rst

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ Fixed regressions
2828
Bug fixes
2929
~~~~~~~~~
3030
- Fixed segfault in :meth:`DataFrame.to_json` when dumping tz-aware datetimes in Python 3.10 (:issue:`42130`)
31+
- Fixed window aggregations in :meth:`DataFrame.rolling` and :meth:`Series.rolling` to skip over unused elements (:issue:`45647`)
3132
-
3233

3334
.. ---------------------------------------------------------------------------

pandas/_libs/window/aggregations.pyx

+56-36
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,9 @@ def roll_sum(const float64_t[:] values, ndarray[int64_t] start,
122122
ndarray[int64_t] end, int64_t minp) -> np.ndarray:
123123
cdef:
124124
Py_ssize_t i, j
125-
float64_t sum_x = 0, compensation_add = 0, compensation_remove = 0
125+
float64_t sum_x, compensation_add, compensation_remove
126126
int64_t s, e
127-
int64_t nobs = 0, N = len(values)
127+
int64_t nobs = 0, N = len(start)
128128
ndarray[float64_t] output
129129
bint is_monotonic_increasing_bounds
130130

@@ -139,10 +139,12 @@ def roll_sum(const float64_t[:] values, ndarray[int64_t] start,
139139
s = start[i]
140140
e = end[i]
141141

142-
if i == 0 or not is_monotonic_increasing_bounds:
142+
if i == 0 or not is_monotonic_increasing_bounds or s >= end[i - 1]:
143143

144144
# setup
145145

146+
sum_x = compensation_add = compensation_remove = 0
147+
nobs = 0
146148
for j in range(s, e):
147149
add_sum(values[j], &nobs, &sum_x, &compensation_add)
148150

@@ -226,9 +228,9 @@ cdef inline void remove_mean(float64_t val, Py_ssize_t *nobs, float64_t *sum_x,
226228
def roll_mean(const float64_t[:] values, ndarray[int64_t] start,
227229
ndarray[int64_t] end, int64_t minp) -> np.ndarray:
228230
cdef:
229-
float64_t val, compensation_add = 0, compensation_remove = 0, sum_x = 0
231+
float64_t val, compensation_add, compensation_remove, sum_x
230232
int64_t s, e
231-
Py_ssize_t nobs = 0, i, j, neg_ct = 0, N = len(values)
233+
Py_ssize_t nobs, i, j, neg_ct, N = len(start)
232234
ndarray[float64_t] output
233235
bint is_monotonic_increasing_bounds
234236

@@ -243,8 +245,10 @@ def roll_mean(const float64_t[:] values, ndarray[int64_t] start,
243245
s = start[i]
244246
e = end[i]
245247

246-
if i == 0 or not is_monotonic_increasing_bounds:
248+
if i == 0 or not is_monotonic_increasing_bounds or s >= end[i - 1]:
247249

250+
compensation_add = compensation_remove = sum_x = 0
251+
nobs = neg_ct = 0
248252
# setup
249253
for j in range(s, e):
250254
val = values[j]
@@ -349,11 +353,11 @@ def roll_var(const float64_t[:] values, ndarray[int64_t] start,
349353
Numerically stable implementation using Welford's method.
350354
"""
351355
cdef:
352-
float64_t mean_x = 0, ssqdm_x = 0, nobs = 0, compensation_add = 0,
353-
float64_t compensation_remove = 0,
356+
float64_t mean_x, ssqdm_x, nobs, compensation_add,
357+
float64_t compensation_remove,
354358
float64_t val, prev, delta, mean_x_old
355359
int64_t s, e
356-
Py_ssize_t i, j, N = len(values)
360+
Py_ssize_t i, j, N = len(start)
357361
ndarray[float64_t] output
358362
bint is_monotonic_increasing_bounds
359363

@@ -372,8 +376,9 @@ def roll_var(const float64_t[:] values, ndarray[int64_t] start,
372376

373377
# Over the first window, observations can only be added
374378
# never removed
375-
if i == 0 or not is_monotonic_increasing_bounds:
379+
if i == 0 or not is_monotonic_increasing_bounds or s >= end[i - 1]:
376380

381+
mean_x = ssqdm_x = nobs = compensation_add = compensation_remove = 0
377382
for j in range(s, e):
378383
add_var(values[j], &nobs, &mean_x, &ssqdm_x, &compensation_add)
379384

@@ -500,11 +505,11 @@ def roll_skew(ndarray[float64_t] values, ndarray[int64_t] start,
500505
cdef:
501506
Py_ssize_t i, j
502507
float64_t val, prev, min_val, mean_val, sum_val = 0
503-
float64_t compensation_xxx_add = 0, compensation_xxx_remove = 0
504-
float64_t compensation_xx_add = 0, compensation_xx_remove = 0
505-
float64_t compensation_x_add = 0, compensation_x_remove = 0
506-
float64_t x = 0, xx = 0, xxx = 0
507-
int64_t nobs = 0, N = len(values), nobs_mean = 0
508+
float64_t compensation_xxx_add, compensation_xxx_remove
509+
float64_t compensation_xx_add, compensation_xx_remove
510+
float64_t compensation_x_add, compensation_x_remove
511+
float64_t x, xx, xxx
512+
int64_t nobs = 0, N = len(start), V = len(values), nobs_mean = 0
508513
int64_t s, e
509514
ndarray[float64_t] output, mean_array, values_copy
510515
bint is_monotonic_increasing_bounds
@@ -518,7 +523,7 @@ def roll_skew(ndarray[float64_t] values, ndarray[int64_t] start,
518523
values_copy = np.copy(values)
519524

520525
with nogil:
521-
for i in range(0, N):
526+
for i in range(0, V):
522527
val = values_copy[i]
523528
if notnan(val):
524529
nobs_mean += 1
@@ -527,7 +532,7 @@ def roll_skew(ndarray[float64_t] values, ndarray[int64_t] start,
527532
# Other cases would lead to imprecision for smallest values
528533
if min_val - mean_val > -1e5:
529534
mean_val = round(mean_val)
530-
for i in range(0, N):
535+
for i in range(0, V):
531536
values_copy[i] = values_copy[i] - mean_val
532537

533538
for i in range(0, N):
@@ -537,8 +542,13 @@ def roll_skew(ndarray[float64_t] values, ndarray[int64_t] start,
537542

538543
# Over the first window, observations can only be added
539544
# never removed
540-
if i == 0 or not is_monotonic_increasing_bounds:
545+
if i == 0 or not is_monotonic_increasing_bounds or s >= end[i - 1]:
541546

547+
compensation_xxx_add = compensation_xxx_remove = 0
548+
compensation_xx_add = compensation_xx_remove = 0
549+
compensation_x_add = compensation_x_remove = 0
550+
x = xx = xxx = 0
551+
nobs = 0
542552
for j in range(s, e):
543553
val = values_copy[j]
544554
add_skew(val, &nobs, &x, &xx, &xxx, &compensation_x_add,
@@ -682,12 +692,12 @@ def roll_kurt(ndarray[float64_t] values, ndarray[int64_t] start,
682692
cdef:
683693
Py_ssize_t i, j
684694
float64_t val, prev, mean_val, min_val, sum_val = 0
685-
float64_t compensation_xxxx_add = 0, compensation_xxxx_remove = 0
686-
float64_t compensation_xxx_remove = 0, compensation_xxx_add = 0
687-
float64_t compensation_xx_remove = 0, compensation_xx_add = 0
688-
float64_t compensation_x_remove = 0, compensation_x_add = 0
689-
float64_t x = 0, xx = 0, xxx = 0, xxxx = 0
690-
int64_t nobs = 0, s, e, N = len(values), nobs_mean = 0
695+
float64_t compensation_xxxx_add, compensation_xxxx_remove
696+
float64_t compensation_xxx_remove, compensation_xxx_add
697+
float64_t compensation_xx_remove, compensation_xx_add
698+
float64_t compensation_x_remove, compensation_x_add
699+
float64_t x, xx, xxx, xxxx
700+
int64_t nobs, s, e, N = len(start), V = len(values), nobs_mean = 0
691701
ndarray[float64_t] output, values_copy
692702
bint is_monotonic_increasing_bounds
693703

@@ -700,7 +710,7 @@ def roll_kurt(ndarray[float64_t] values, ndarray[int64_t] start,
700710
min_val = np.nanmin(values)
701711

702712
with nogil:
703-
for i in range(0, N):
713+
for i in range(0, V):
704714
val = values_copy[i]
705715
if notnan(val):
706716
nobs_mean += 1
@@ -709,7 +719,7 @@ def roll_kurt(ndarray[float64_t] values, ndarray[int64_t] start,
709719
# Other cases would lead to imprecision for smallest values
710720
if min_val - mean_val > -1e4:
711721
mean_val = round(mean_val)
712-
for i in range(0, N):
722+
for i in range(0, V):
713723
values_copy[i] = values_copy[i] - mean_val
714724

715725
for i in range(0, N):
@@ -719,8 +729,14 @@ def roll_kurt(ndarray[float64_t] values, ndarray[int64_t] start,
719729

720730
# Over the first window, observations can only be added
721731
# never removed
722-
if i == 0 or not is_monotonic_increasing_bounds:
732+
if i == 0 or not is_monotonic_increasing_bounds or s >= end[i - 1]:
723733

734+
compensation_xxxx_add = compensation_xxxx_remove = 0
735+
compensation_xxx_remove = compensation_xxx_add = 0
736+
compensation_xx_remove = compensation_xx_add = 0
737+
compensation_x_remove = compensation_x_add = 0
738+
x = xx = xxx = xxxx = 0
739+
nobs = 0
724740
for j in range(s, e):
725741
add_kurt(values_copy[j], &nobs, &x, &xx, &xxx, &xxxx,
726742
&compensation_x_add, &compensation_xx_add,
@@ -764,7 +780,7 @@ def roll_median_c(const float64_t[:] values, ndarray[int64_t] start,
764780
Py_ssize_t i, j
765781
bint err = False, is_monotonic_increasing_bounds
766782
int midpoint, ret = 0
767-
int64_t nobs = 0, N = len(values), s, e, win
783+
int64_t nobs = 0, N = len(start), s, e, win
768784
float64_t val, res, prev
769785
skiplist_t *sl
770786
ndarray[float64_t] output
@@ -791,8 +807,12 @@ def roll_median_c(const float64_t[:] values, ndarray[int64_t] start,
791807
s = start[i]
792808
e = end[i]
793809

794-
if i == 0 or not is_monotonic_increasing_bounds:
810+
if i == 0 or not is_monotonic_increasing_bounds or s >= end[i - 1]:
795811

812+
if i != 0:
813+
skiplist_destroy(sl)
814+
sl = skiplist_init(<int>win)
815+
nobs = 0
796816
# setup
797817
for j in range(s, e):
798818
val = values[j]
@@ -948,7 +968,7 @@ cdef _roll_min_max(ndarray[numeric_t] values,
948968
cdef:
949969
numeric_t ai
950970
int64_t curr_win_size, start
951-
Py_ssize_t i, k, nobs = 0, N = len(values)
971+
Py_ssize_t i, k, nobs = 0, N = len(starti)
952972
deque Q[int64_t] # min/max always the front
953973
deque W[int64_t] # track the whole window for nobs compute
954974
ndarray[float64_t, ndim=1] output
@@ -1031,7 +1051,7 @@ def roll_quantile(const float64_t[:] values, ndarray[int64_t] start,
10311051
O(N log(window)) implementation using skip list
10321052
"""
10331053
cdef:
1034-
Py_ssize_t i, j, s, e, N = len(values), idx
1054+
Py_ssize_t i, j, s, e, N = len(start), idx
10351055
int ret = 0
10361056
int64_t nobs = 0, win
10371057
float64_t val, prev, midpoint, idx_with_fraction
@@ -1068,8 +1088,8 @@ def roll_quantile(const float64_t[:] values, ndarray[int64_t] start,
10681088
s = start[i]
10691089
e = end[i]
10701090

1071-
if i == 0 or not is_monotonic_increasing_bounds:
1072-
if not is_monotonic_increasing_bounds:
1091+
if i == 0 or not is_monotonic_increasing_bounds or s >= end[i - 1]:
1092+
if i != 0:
10731093
nobs = 0
10741094
skiplist_destroy(skiplist)
10751095
skiplist = skiplist_init(<int>win)
@@ -1160,7 +1180,7 @@ def roll_rank(const float64_t[:] values, ndarray[int64_t] start,
11601180
derived from roll_quantile
11611181
"""
11621182
cdef:
1163-
Py_ssize_t i, j, s, e, N = len(values), idx
1183+
Py_ssize_t i, j, s, e, N = len(start), idx
11641184
float64_t rank_min = 0, rank = 0
11651185
int64_t nobs = 0, win
11661186
float64_t val
@@ -1193,8 +1213,8 @@ def roll_rank(const float64_t[:] values, ndarray[int64_t] start,
11931213
s = start[i]
11941214
e = end[i]
11951215

1196-
if i == 0 or not is_monotonic_increasing_bounds:
1197-
if not is_monotonic_increasing_bounds:
1216+
if i == 0 or not is_monotonic_increasing_bounds or s >= end[i - 1]:
1217+
if i != 0:
11981218
nobs = 0
11991219
skiplist_destroy(skiplist)
12001220
skiplist = skiplist_init(<int>win)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from functools import partial
2+
import sys
3+
4+
import numpy as np
5+
import pytest
6+
7+
import pandas._libs.window.aggregations as window_aggregations
8+
9+
from pandas import Series
10+
import pandas._testing as tm
11+
12+
13+
def _get_rolling_aggregations():
    """Enumerate the rolling aggregation functions exercised by these tests.

    Aggregations that need extra keyword arguments (``ddof``, ``quantile``,
    ``interpolation``, ``percentile``, ``method``, ``ascending``) are bound
    with ``functools.partial`` so that every returned callable is invoked the
    same way: ``func(values, start, end, minp)``.

    Returns
    -------
    dict
        ``{"ids": names, "params": functions}`` where both values are tuples
        of equal length and ``names[k]`` is the readable label of
        ``functions[k]``.
    """
    # list pairs of name and function
    # each function has this signature:
    # (const float64_t[:] values, ndarray[int64_t] start,
    #  ndarray[int64_t] end, int64_t minp) -> np.ndarray
    named_roll_aggs = [
        ("roll_sum", window_aggregations.roll_sum),
        ("roll_mean", window_aggregations.roll_mean),
    ]
    named_roll_aggs += [
        (f"roll_var({ddof})", partial(window_aggregations.roll_var, ddof=ddof))
        for ddof in [0, 1]
    ]
    named_roll_aggs += [
        ("roll_skew", window_aggregations.roll_skew),
        ("roll_kurt", window_aggregations.roll_kurt),
        ("roll_median_c", window_aggregations.roll_median_c),
        ("roll_max", window_aggregations.roll_max),
        ("roll_min", window_aggregations.roll_min),
    ]
    named_roll_aggs += [
        (
            f"roll_quantile({quantile},{interpolation})",
            partial(
                window_aggregations.roll_quantile,
                quantile=quantile,
                interpolation=interpolation,
            ),
        )
        for quantile in [0.0001, 0.5, 0.9999]
        for interpolation in window_aggregations.interpolation_types
    ]
    named_roll_aggs += [
        (
            f"roll_rank({percentile},{method},{ascending})",
            partial(
                window_aggregations.roll_rank,
                percentile=percentile,
                method=method,
                ascending=ascending,
            ),
        )
        for percentile in [True, False]
        # iterating the mapping yields its keys; no ``.keys()`` call needed
        for method in window_aggregations.rolling_rank_tiebreakers
        for ascending in [True, False]
    ]
    # unzip the (name, function) pairs into two parallel tuples
    names, functions = zip(*named_roll_aggs)
    return {"ids": names, "params": functions}
64+
65+
66+
# Build the table once at import time so the fixture decorator can use it.
_rolling_aggregations = _get_rolling_aggregations()


@pytest.fixture(
    ids=_rolling_aggregations["ids"],
    params=_rolling_aggregations["params"],
)
def rolling_aggregation(request):
    """Provide one rolling aggregation callable per parametrized run."""
    return request.param
75+
76+
77+
def test_rolling_aggregation_boundary_consistency(rolling_aggregation):
78+
# GH-45647
79+
minp, step, width, size, selection = 0, 1, 3, 11, [2, 7]
80+
values = np.arange(1, 1 + size, dtype=np.float64)
81+
end = np.arange(width, size, step, dtype=np.int64)
82+
start = end - width
83+
selarr = np.array(selection, dtype=np.int32)
84+
result = Series(rolling_aggregation(values, start[selarr], end[selarr], minp))
85+
expected = Series(rolling_aggregation(values, start, end, minp)[selarr])
86+
tm.assert_equal(expected, result)
87+
88+
89+
def test_rolling_aggregation_with_unused_elements(rolling_aggregation):
90+
# GH-45647
91+
minp, width = 0, 5 # width at least 4 for kurt
92+
size = 2 * width + 5
93+
values = np.arange(1, size + 1, dtype=np.float64)
94+
values[width : width + 2] = sys.float_info.min
95+
values[width + 2] = np.nan
96+
values[width + 3 : width + 5] = sys.float_info.max
97+
start = np.array([0, size - width], dtype=np.int64)
98+
end = np.array([width, size], dtype=np.int64)
99+
loc = np.array(
100+
[j for i in range(len(start)) for j in range(start[i], end[i])],
101+
dtype=np.int32,
102+
)
103+
result = Series(rolling_aggregation(values, start, end, minp))
104+
compact_values = np.array(values[loc], dtype=np.float64)
105+
compact_start = np.arange(0, len(start) * width, width, dtype=np.int64)
106+
compact_end = compact_start + width
107+
expected = Series(
108+
rolling_aggregation(compact_values, compact_start, compact_end, minp)
109+
)
110+
assert np.isfinite(expected.values).all(), "Not all expected values are finite"
111+
tm.assert_equal(expected, result)

0 commit comments

Comments
 (0)