Skip to content

Commit a7d6878

Browse files
committed
Refactored and cleaned up (pandas-devGH-45647)
1 parent 221ca7b commit a7d6878

File tree

5 files changed

+119
-104
lines changed

5 files changed

+119
-104
lines changed

doc/source/whatsnew/v1.4.1.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ Fixed regressions
2525
Bug fixes
2626
~~~~~~~~~
2727
- Fixed segfault in :meth:`DataFrame.to_json` when dumping tz-aware datetimes in Python 3.10 (:issue:`42130`)
28-
- Fixed window aggregations to skip over unused elements (:issue:`45647`)
28+
- Fixed window aggregations in :meth:`DataFrame.rolling` and :meth:`Series.rolling` to skip over unused elements (:issue:`45647`)
2929
-
3030

3131
.. ---------------------------------------------------------------------------

pandas/_libs/window/aggregations.pyx

+7-6
Original file line numberDiff line numberDiff line change
@@ -780,7 +780,7 @@ def roll_median_c(const float64_t[:] values, ndarray[int64_t] start,
780780
Py_ssize_t i, j
781781
bint err = False, is_monotonic_increasing_bounds
782782
int midpoint, ret = 0
783-
int64_t nobs, N = len(start), s, e, win
783+
int64_t nobs = 0, N = len(start), s, e, win
784784
float64_t val, res, prev
785785
skiplist_t *sl
786786
ndarray[float64_t] output
@@ -809,9 +809,10 @@ def roll_median_c(const float64_t[:] values, ndarray[int64_t] start,
809809

810810
if i == 0 or not is_monotonic_increasing_bounds or s >= end[i - 1]:
811811

812-
skiplist_destroy(sl)
813-
sl = skiplist_init(<int>win)
814-
nobs = 0
812+
if i != 0:
813+
skiplist_destroy(sl)
814+
sl = skiplist_init(<int>win)
815+
nobs = 0
815816
# setup
816817
for j in range(s, e):
817818
val = values[j]
@@ -1088,7 +1089,7 @@ def roll_quantile(const float64_t[:] values, ndarray[int64_t] start,
10881089
e = end[i]
10891090

10901091
if i == 0 or not is_monotonic_increasing_bounds or s >= end[i - 1]:
1091-
if not is_monotonic_increasing_bounds or s >= end[i - 1]:
1092+
if i != 0:
10921093
nobs = 0
10931094
skiplist_destroy(skiplist)
10941095
skiplist = skiplist_init(<int>win)
@@ -1213,7 +1214,7 @@ def roll_rank(const float64_t[:] values, ndarray[int64_t] start,
12131214
e = end[i]
12141215

12151216
if i == 0 or not is_monotonic_increasing_bounds or s >= end[i - 1]:
1216-
if not is_monotonic_increasing_bounds or s >= end[i - 1]:
1217+
if i != 0:
12171218
nobs = 0
12181219
skiplist_destroy(skiplist)
12191220
skiplist = skiplist_init(<int>win)

pandas/tests/window/conftest.py

-61
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@
22
datetime,
33
timedelta,
44
)
5-
from functools import partial
65

76
import numpy as np
87
import pytest
98

10-
import pandas._libs.window.aggregations as window_aggregations
119
import pandas.util._test_decorators as td
1210

1311
from pandas import (
@@ -261,62 +259,3 @@ def frame():
261259
index=bdate_range(datetime(2009, 1, 1), periods=100),
262260
columns=np.arange(10),
263261
)
264-
265-
266-
def _named_func(name_and_func):
267-
name, func = name_and_func
268-
if not hasattr(func, "func"):
269-
func = partial(func)
270-
func.__name__ = name
271-
return func
272-
273-
274-
@pytest.fixture(
275-
params=[
276-
_named_func(x)
277-
for x in [
278-
("roll_sum", window_aggregations.roll_sum),
279-
("roll_mean", window_aggregations.roll_mean),
280-
]
281-
+ [
282-
(f"roll_var({ddof})", partial(window_aggregations.roll_var, ddof=ddof))
283-
for ddof in [0, 1]
284-
]
285-
+ [
286-
("roll_skew", window_aggregations.roll_skew),
287-
("roll_kurt", window_aggregations.roll_kurt),
288-
("roll_median_c", window_aggregations.roll_median_c),
289-
("roll_max", window_aggregations.roll_max),
290-
("roll_min", window_aggregations.roll_min),
291-
]
292-
+ [
293-
(
294-
f"roll_quantile({quantile},{interpolation})",
295-
partial(
296-
window_aggregations.roll_quantile,
297-
quantile=quantile,
298-
interpolation=interpolation,
299-
),
300-
)
301-
for quantile in [0.0001, 0.5, 0.9999]
302-
for interpolation in window_aggregations.interpolation_types
303-
]
304-
+ [
305-
(
306-
f"roll_rank({percentile},{method},{ascending})",
307-
partial(
308-
window_aggregations.roll_rank,
309-
percentile=percentile,
310-
method=method,
311-
ascending=ascending,
312-
),
313-
)
314-
for percentile in [True, False]
315-
for method in window_aggregations.rolling_rank_tiebreakers.keys()
316-
for ascending in [True, False]
317-
]
318-
]
319-
)
320-
def rolling_aggregation(request):
321-
"""Make a named rolling aggregation function as fixture."""
322-
return request.param
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from functools import partial
2+
import sys
3+
4+
import numpy as np
5+
import pytest
6+
7+
import pandas._libs.window.aggregations as window_aggregations
8+
9+
from pandas import Series
10+
import pandas._testing as tm
11+
12+
13+
def _get_rolling_aggregations():
14+
# list pairs of name and function
15+
# each function has this signature:
16+
# (const float64_t[:] values, ndarray[int64_t] start,
17+
# ndarray[int64_t] end, int64_t minp) -> np.ndarray
18+
named_roll_aggs = (
19+
[
20+
("roll_sum", window_aggregations.roll_sum),
21+
("roll_mean", window_aggregations.roll_mean),
22+
]
23+
+ [
24+
(f"roll_var({ddof})", partial(window_aggregations.roll_var, ddof=ddof))
25+
for ddof in [0, 1]
26+
]
27+
+ [
28+
("roll_skew", window_aggregations.roll_skew),
29+
("roll_kurt", window_aggregations.roll_kurt),
30+
("roll_median_c", window_aggregations.roll_median_c),
31+
("roll_max", window_aggregations.roll_max),
32+
("roll_min", window_aggregations.roll_min),
33+
]
34+
+ [
35+
(
36+
f"roll_quantile({quantile},{interpolation})",
37+
partial(
38+
window_aggregations.roll_quantile,
39+
quantile=quantile,
40+
interpolation=interpolation,
41+
),
42+
)
43+
for quantile in [0.0001, 0.5, 0.9999]
44+
for interpolation in window_aggregations.interpolation_types
45+
]
46+
+ [
47+
(
48+
f"roll_rank({percentile},{method},{ascending})",
49+
partial(
50+
window_aggregations.roll_rank,
51+
percentile=percentile,
52+
method=method,
53+
ascending=ascending,
54+
),
55+
)
56+
for percentile in [True, False]
57+
for method in window_aggregations.rolling_rank_tiebreakers.keys()
58+
for ascending in [True, False]
59+
]
60+
)
61+
# unzip the (name, function) pairs into a list of two tuples: names, then functions
62+
unzipped = list(zip(*named_roll_aggs))
63+
return {"ids": unzipped[0], "params": unzipped[1]}
64+
65+
66+
_rolling_aggregations = _get_rolling_aggregations()
67+
68+
69+
@pytest.fixture(
70+
params=_rolling_aggregations["params"], ids=_rolling_aggregations["ids"]
71+
)
72+
def rolling_aggregation(request):
73+
"""Make a rolling aggregation function as fixture."""
74+
return request.param
75+
76+
77+
def test_rolling_aggregation_boundary_consistency(rolling_aggregation):
78+
# GH-45647
79+
minp, step, width, size, selection = 0, 1, 3, 11, [2, 7]
80+
values = np.arange(1, 1 + size, dtype=np.float64)
81+
end = np.arange(width, size, step, dtype=np.int64)
82+
start = end - width
83+
selarr = np.array(selection, dtype=np.int32)
84+
result = Series(rolling_aggregation(values, start[selarr], end[selarr], minp))
85+
expected = Series(rolling_aggregation(values, start, end, minp)[selarr])
86+
tm.assert_equal(expected, result)
87+
88+
89+
def test_rolling_aggregation_with_unused_elements(rolling_aggregation):
90+
# GH-45647
91+
minp, width = 0, 5 # width at least 4 for kurt
92+
size = 2 * width + 5
93+
values = np.arange(1, size + 1, dtype=np.float64)
94+
values[width : width + 2] = sys.float_info.min
95+
values[width + 2] = np.nan
96+
values[width + 3 : width + 5] = sys.float_info.max
97+
start = np.array([0, size - width], dtype=np.int64)
98+
end = np.array([width, size], dtype=np.int64)
99+
loc = np.array(
100+
[j for i in range(len(start)) for j in range(start[i], end[i])],
101+
dtype=np.int32,
102+
)
103+
result = Series(rolling_aggregation(values, start, end, minp))
104+
compact_values = np.array(values[loc], dtype=np.float64)
105+
compact_start = np.arange(0, len(start) * width, width, dtype=np.int64)
106+
compact_end = compact_start + width
107+
expected = Series(
108+
rolling_aggregation(compact_values, compact_start, compact_end, minp)
109+
)
110+
assert np.isfinite(expected.values).all(), "Not all expected values are finite"
111+
tm.assert_equal(expected, result)

pandas/tests/window/test_rolling.py

-36
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
datetime,
33
timedelta,
44
)
5-
import sys
65

76
import numpy as np
87
import pytest
@@ -1735,38 +1734,3 @@ def test_rolling_std_neg_sqrt():
17351734

17361735
b = a.ewm(span=3).std()
17371736
assert np.isfinite(b[2:]).all()
1738-
1739-
1740-
def test_rolling_aggregation_boundary_consistency(rolling_aggregation):
1741-
# GH-45647
1742-
minp, step, width, size, selection = 0, 1, 3, 11, [2, 7]
1743-
s = Series(np.arange(1, 1 + size, dtype=np.float64))
1744-
end = np.arange(width, size, step, dtype=np.int64)
1745-
start = end - width
1746-
selarr = np.array(selection, dtype=np.int32)
1747-
result = Series(rolling_aggregation(s.values, start[selarr], end[selarr], minp))
1748-
expected = Series(rolling_aggregation(s.values, start, end, minp)[selarr])
1749-
tm.assert_equal(expected, result)
1750-
1751-
1752-
def test_rolling_aggregation_with_unused_elements(rolling_aggregation):
1753-
# GH-45647
1754-
minp, width = 0, 5 # width at least 4 for kurt
1755-
size = 2 * width + 5
1756-
s = Series(np.arange(1, size + 1, dtype=np.float64))
1757-
s[width : width + 2] = sys.float_info.min
1758-
s[width + 2] = np.nan
1759-
s[width + 3 : width + 5] = sys.float_info.max
1760-
start = np.array([0, size - width], dtype=np.int64)
1761-
end = np.array([width, size], dtype=np.int64)
1762-
loc = np.array(
1763-
[j for i in range(len(start)) for j in range(start[i], end[i])],
1764-
dtype=np.int32,
1765-
)
1766-
result = Series(rolling_aggregation(s.values, start, end, minp))
1767-
compact_s = np.array(s.iloc[loc], dtype=np.float64)
1768-
compact_start = np.arange(0, len(start) * width, width, dtype=np.int64)
1769-
compact_end = compact_start + width
1770-
expected = Series(rolling_aggregation(compact_s, compact_start, compact_end, minp))
1771-
assert np.isfinite(expected.values).all(), "Not all expected values are finite"
1772-
tm.assert_equal(expected, result)

0 commit comments

Comments
 (0)