Skip to content

Commit 184e84d

Browse files
BUG: New unit tests for bug pandas-dev#46726: wrong result with varying window size min/max rolling calc.
1 parent 38dd653 commit 184e84d

File tree

1 file changed

+175
-0
lines changed

1 file changed

+175
-0
lines changed

Diff for: pandas/tests/window/test_minmax.py

+175
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
from typing import (
2+
Any,
3+
)
4+
5+
import numpy as np
6+
import pytest
7+
8+
import pandas as pd
9+
from pandas import api
10+
import pandas._testing as tm
11+
12+
13+
class StandardWindowIndexer(api.indexers.BaseIndexer):
    """
    Fixed-length trailing window expressed as a custom indexer.

    Row ``i`` gets the half-open window ``[max(i + 1 - win_len, 0), i + 1)``.

    Parameters
    ----------
    n : int
        Number of rows assumed when ``get_window_bounds`` is called without
        ``num_values``.
    win_len : int
        Maximum number of rows in each window.
    """

    def __init__(self, n, win_len):
        self.n = n
        self.win_len = win_len
        super().__init__()

    def get_window_bounds(
        self, num_values=None, min_periods=None, center=None, closed=None, step=None
    ):
        """
        Return the ``(start, end)`` int64 arrays of window bounds.

        Only ``num_values`` is used; the remaining arguments are accepted for
        signature compatibility and ignored.
        """
        if num_values is None:
            num_values = self.n
        # Window i ends just past row i ...
        end = np.arange(num_values, dtype="int64") + 1
        # ... and reaches back win_len rows, truncated at the first row.
        start = np.clip(end - self.win_len, 0, num_values)
        return start, end
27+
28+
29+
class CustomLengthWindowIndexer(api.indexers.BaseIndexer):
    """
    Trailing windows whose length is drawn at random for every row.

    Parameters
    ----------
    rnd : np.random.Generator
        Source of randomness; ``n`` integers are drawn from it.
    n : int
        Number of rows (one window length is drawn per row).
    win_len : int
        Exclusive upper bound of the drawn window lengths.
    """

    def __init__(self, rnd: np.random.Generator, n, win_len):
        # Lengths fall in [0, win_len); a length of 0 yields an empty window.
        self.window = rnd.integers(win_len, size=n)
        super().__init__()

    def get_window_bounds(
        self, num_values=None, min_periods=None, center=None, closed=None, step=None
    ):
        """
        Return the ``(start, end)`` arrays of window bounds.

        Only ``num_values`` is used; the remaining arguments are accepted for
        signature compatibility and ignored. ``num_values`` has to be
        broadcastable against the lengths drawn in ``__init__``.
        """
        if num_values is None:
            num_values = len(self.window)
        # Window i ends just past row i ...
        end = np.arange(num_values, dtype="int64") + 1
        # ... and reaches back by its own random length, truncated at row 0.
        start = np.clip(end - self.window, 0, num_values)
        return start, end
42+
43+
44+
class ArbitraryWindowIndexer(api.indexers.BaseIndexer):
    """
    Randomly placed windows, stored sorted by ``(end, start)``.

    Parameters
    ----------
    rnd : np.random.Generator
        Source of randomness; two arrays of ``n`` integers are drawn from it
        (anchors first, then lengths).
    n : int
        Number of windows; also the exclusive upper bound of the anchors.
    win_len : int
        Exclusive upper bound of the drawn window lengths.
    """

    def __init__(self, rnd: np.random.Generator, n, win_len):
        anchor = rnd.integers(n, size=n)
        lengths = rnd.integers(win_len, size=n)
        # Extend the window to the left of the anchor when that stays
        # non-negative, otherwise extend it to the right.
        other = np.where(anchor - lengths >= 0, anchor - lengths, anchor + lengths)

        # Whichever edge is smaller becomes the start of the window.
        start = np.minimum(anchor, other)
        end = np.maximum(anchor, other)

        # lexsort treats its *last* key as the primary one, so this orders
        # the windows by end and breaks ties by start.
        order = np.lexsort((start, end))

        self._start = start[order]
        self._end = end[order]
        super().__init__()

    def get_window_bounds(
        self, num_values=None, min_periods=None, center=None, closed=None, step=None
    ):
        """
        Return the stored ``(start, end)`` bounds clipped to ``[0, num_values]``.

        Only ``num_values`` is used; the remaining arguments are accepted for
        signature compatibility and ignored.
        """
        if num_values is None:
            num_values = len(self._start)
        start = np.clip(self._start, 0, num_values)
        end = np.clip(self._end, 0, num_values)
        return start, end
69+
70+
71+
class TestMinMax:
    """Check rolling min/max against a brute-force reference (GH 46726)."""

    # Pytest's own cache is not a good choice here, because it appears to
    # persist data on disk, and we are not really interested in flooding the
    # hard drive with random numbers.
    # Thus control data is cached in memory to avoid repetitive calculations.
    class Cache:
        """In-memory store of control results keyed by the test parameters."""

        def __init__(self) -> None:
            self.ctrl: dict[Any, Any] = {}

    @pytest.fixture(scope="class")
    def cache(self) -> Cache:
        """Provide one ``Cache`` shared by all tests of this class."""
        return self.Cache()

    @pytest.mark.parametrize("is_max", [True, False])
    # NOTE(review): "python" is not among the engines documented for
    # Rolling.max / Rolling.min ("cython", "numba") -- confirm it is accepted.
    @pytest.mark.parametrize("engine", ["python", "cython", "numba"])
    @pytest.mark.parametrize(
        "seed, n, win_len, min_obs, frac_nan, indexer_t",
        [
            (42, 1000, 80, 15, 0.3, CustomLengthWindowIndexer),
            (52, 1000, 80, 15, 0.3, ArbitraryWindowIndexer),
            (1984, 1000, 40, 25, 0.3, None),
        ],
    )
    def test_minmax(
        self, is_max, engine, seed, n, win_len, min_obs, frac_nan, indexer_t, cache
    ):
        """Rolling min/max must equal the brute-force control result."""
        if engine == "numba":
            # Skip instead of erroring out where numba is not installed.
            pytest.importorskip("numba")

        seed_is_generator = isinstance(seed, np.random.Generator)
        # default_rng returns a Generator argument unaltered, so an existing
        # generator is used (and advanced) as is; anything else seeds a new one.
        rng = np.random.default_rng(seed)

        if seed is None or seed_is_generator:
            # No reproducible seed value is available, so the cache is keyed
            # on the generator state (PCG64 state layout).
            state = rng.bit_generator.state
            rng_state_for_key = (
                state["bit_generator"],
                state["state"]["state"],
                state["state"]["inc"],
                state["has_uint32"],
                state["uinteger"],
            )
        else:
            rng_state_for_key = seed
        # Not read anywhere in this file; presumably kept so that a failing
        # run can be reproduced.
        self.last_rng_state = rng.bit_generator.state

        vals = pd.DataFrame({"Data": rng.random(n)})
        if frac_nan > 0:
            is_nan = rng.random(len(vals)) < frac_nan
            vals["Data"] = np.where(is_nan, np.nan, vals["Data"])

        if indexer_t:
            ind_obj = indexer_t(rng, len(vals), win_len)
            ind_param = ind_obj
        else:
            # A plain integer window is passed to rolling(); the indexer only
            # supplies the equivalent bounds for the control calculation.
            ind_obj = StandardWindowIndexer(len(vals), win_len)
            ind_param = win_len

        start, end = ind_obj.get_window_bounds()
        # The engine is deliberately not part of the key: every engine is
        # compared against the same control data.
        ctrl_key = (is_max, rng_state_for_key, n, win_len, min_obs, frac_nan, indexer_t)
        if ctrl_key in cache.ctrl:
            ctrl = cache.ctrl[ctrl_key]
        else:
            # This is a brute-force calculation and may get expensive when n
            # is large, so it is cached.
            ctrl = calc_minmax_control(vals["Data"], start, end, min_obs, is_max)
            cache.ctrl[ctrl_key] = ctrl

        r = vals.rolling(ind_param, min_periods=min_obs)
        f = r.max if is_max else r.min
        test = f(engine=engine)
        tm.assert_series_equal(test["Data"], ctrl["Data"])

    @pytest.mark.parametrize("engine", ["python", "cython", "numba"])
    @pytest.mark.parametrize(
        "seed, n, win_len, indexer_t",
        [
            (42, 15, 7, ArbitraryWindowIndexer),
        ],
    )
    def test_wrong_order(self, engine, seed, n, win_len, indexer_t):
        """Window ends that are out of order must raise a ValueError."""
        if engine == "numba":
            # Skip instead of erroring out where numba is not installed.
            pytest.importorskip("numba")

        rng = np.random.default_rng(seed)
        vals = pd.DataFrame({"Data": rng.random(n)})

        ind_obj = indexer_t(rng, len(vals), win_len)
        # Swap two window ends to break the ordering the indexer established.
        ind_obj._end[[14, 7]] = ind_obj._end[[7, 14]]

        f = vals.rolling(ind_obj).max
        with pytest.raises(
            ValueError, match="Start/End ordering requirement is violated at index 8"
        ):
            f(engine=engine)
160+
161+
162+
def calc_minmax_control(vals, start, end, min_periods, ismax):
    """
    Compute rolling min/max by brute force, for use as a reference result.

    Parameters
    ----------
    vals : Series or array-like of float
        One-dimensional input values; NaN marks a missing observation.
    start, end : sequence of int
        Half-open positional window bounds; window ``i`` covers
        ``vals[start[i]:end[i]]``.
    min_periods : int
        Minimum number of non-NaN observations a window needs to produce a
        value.
    ismax : bool
        Compute the maximum if True, the minimum otherwise.

    Returns
    -------
    DataFrame
        Single column ``"Data"`` with one entry per element of ``vals``.
        Entries are NaN for empty windows, for windows with fewer than
        ``min_periods`` observations, and for positions without a window.
    """
    func = np.nanmax if ismax else np.nanmin
    # Convert once up front so every window is a cheap ndarray slice.
    data = np.asarray(vals, dtype="float64")
    # Pre-filled with NaN: only windows that qualify overwrite their slot.
    outp = np.full(data.shape, np.nan)
    for i, (lo, hi) in enumerate(zip(start, end)):
        if lo >= hi:
            continue
        window = data[lo:hi]
        valid = np.count_nonzero(~np.isnan(window))
        # The "valid > 0" guard keeps nanmax/nanmin from warning about an
        # all-NaN slice when min_periods <= 0; the result stays NaN either way.
        if valid >= min_periods and valid > 0:
            outp[i] = func(window)
    return pd.DataFrame({"Data": outp})

0 commit comments

Comments
 (0)