Skip to content

Commit 5230845

Browse files
BUG: New unit tests for bug pandas-dev#46726: wrong result with varying window size min/max rolling calc.
1 parent 38dd653 commit 5230845

File tree

1 file changed

+179
-0
lines changed

1 file changed

+179
-0
lines changed

pandas/tests/window/test_minmax.py

+179
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
from typing import (
2+
Any,
3+
)
4+
5+
import numpy as np
6+
import pytest
7+
8+
import pandas as pd
9+
from pandas import api
10+
import pandas._testing as tm
11+
12+
13+
class StandardWindowIndexer(api.indexers.BaseIndexer):
14+
def __init__(self, n, win_len):
15+
self.n = n
16+
self.win_len = win_len
17+
super().__init__()
18+
19+
def get_window_bounds(
20+
self, num_values=None, min_periods=None, center=None, closed=None, step=None
21+
):
22+
if num_values is None:
23+
num_values = self.n
24+
end = np.arange(num_values, dtype="int64") + 1
25+
start = np.clip(end - self.win_len, 0, num_values)
26+
return start, end
27+
28+
29+
class CustomLengthWindowIndexer(api.indexers.BaseIndexer):
30+
def __init__(self, rnd: np.random.Generator, n, win_len):
31+
self.window = rnd.integers(win_len, size=n)
32+
super().__init__()
33+
34+
def get_window_bounds(
35+
self, num_values=None, min_periods=None, center=None, closed=None, step=None
36+
):
37+
if num_values is None:
38+
num_values = len(self.window)
39+
end = np.arange(num_values, dtype="int64") + 1
40+
start = np.clip(end - self.window, 0, num_values)
41+
return start, end
42+
43+
44+
class ArbitraryWindowIndexer(api.indexers.BaseIndexer):
45+
def __init__(self, rnd: np.random.Generator, n, win_len):
46+
start = rnd.integers(n, size=n)
47+
win_len = rnd.integers(win_len, size=n)
48+
end = np.where(start - win_len >= 0, start - win_len, start + win_len)
49+
50+
(start, end) = (
51+
np.where(end >= start, start, end),
52+
np.where(end >= start, end, start),
53+
)
54+
55+
# It is extremely unlikely that a random array would come sorted,
56+
# so we proceed with sort without checking if it is sorted.
57+
prm = sorted(range(len(start)), key=lambda i: (end[i], start[i]))
58+
59+
self._start = np.array(start)[prm]
60+
self._end = np.array(end)[prm]
61+
super().__init__()
62+
63+
def get_window_bounds(
64+
self, num_values=None, min_periods=None, center=None, closed=None, step=None
65+
):
66+
if num_values is None:
67+
num_values = len(self._start)
68+
start = np.clip(self._start, 0, num_values)
69+
end = np.clip(self._end, 0, num_values)
70+
return start, end
71+
72+
73+
class TestMinMax:
74+
# Pytest cache will not be a good choice here, because it appears
75+
# pytest persists data on disk, and we are not really interested
76+
# in flooding your hard drive with random numbers.
77+
# Thus we just cache control data in memory to avoid repetititve calculations.
78+
class Cache:
79+
def __init__(self) -> None:
80+
self.ctrl: dict[Any, Any] = {}
81+
82+
@pytest.fixture(scope="class")
83+
def cache(self) -> Cache:
84+
return self.Cache()
85+
86+
@pytest.mark.parametrize("is_max", [True, False])
87+
# @pytest.mark.parametrize("engine", ["python", "cython", "numba"])
88+
@pytest.mark.parametrize("engine", ["python", "cython"])
89+
@pytest.mark.parametrize(
90+
"seed, n, win_len, min_obs, frac_nan, indexer_t",
91+
[
92+
(42, 1000, 80, 15, 0.3, CustomLengthWindowIndexer),
93+
(52, 1000, 80, 15, 0.3, ArbitraryWindowIndexer),
94+
(1984, 1000, 40, 25, 0.3, None),
95+
],
96+
)
97+
def test_minmax(
98+
self, is_max, engine, seed, n, win_len, min_obs, frac_nan, indexer_t, cache
99+
):
100+
if seed is not None and isinstance(seed, np.random._generator.Generator):
101+
rng = np.random.default_rng(seed)
102+
rng.bit_generator.state = seed.bit_generator.state
103+
else:
104+
rng = np.random.default_rng(seed)
105+
106+
if seed is None or isinstance(seed, np.random._generator.Generator):
107+
rng_state_for_key = (
108+
rng.bit_generator.state["bit_generator"],
109+
rng.bit_generator.state["state"]["state"],
110+
rng.bit_generator.state["state"]["inc"],
111+
rng.bit_generator.state["has_uint32"],
112+
rng.bit_generator.state["uinteger"],
113+
)
114+
else:
115+
rng_state_for_key = seed
116+
self.last_rng_state = rng.bit_generator.state
117+
vals = pd.DataFrame({"Data": rng.random(n)})
118+
if frac_nan > 0:
119+
is_nan = rng.random(len(vals)) < frac_nan
120+
vals.Data = np.where(is_nan, np.nan, vals.Data)
121+
122+
ind_obj = (
123+
indexer_t(rng, len(vals), win_len)
124+
if indexer_t
125+
else StandardWindowIndexer(len(vals), win_len)
126+
)
127+
ind_param = ind_obj if indexer_t else win_len
128+
129+
(start, end) = ind_obj.get_window_bounds()
130+
ctrl_key = (is_max, rng_state_for_key, n, win_len, min_obs, frac_nan, indexer_t)
131+
if ctrl_key in cache.ctrl:
132+
ctrl = cache.ctrl[ctrl_key]
133+
else:
134+
# This is brute force calculation, and may get expensive when n is
135+
# large, so we cache it.
136+
ctrl = calc_minmax_control(vals.Data, start, end, min_obs, is_max)
137+
cache.ctrl[ctrl_key] = ctrl
138+
139+
r = vals.rolling(ind_param, min_periods=min_obs)
140+
f = r.max if is_max else r.min
141+
test = f(engine=engine)
142+
tm.assert_series_equal(test.Data, ctrl.Data)
143+
144+
# @pytest.mark.parametrize("engine", ["python", "cython", "numba"])
145+
@pytest.mark.parametrize("engine", ["python", "cython"])
146+
@pytest.mark.parametrize(
147+
"seed, n, win_len, indexer_t",
148+
[
149+
(42, 15, 7, ArbitraryWindowIndexer),
150+
],
151+
)
152+
def test_wrong_order(self, engine, seed, n, win_len, indexer_t):
153+
rng = np.random.default_rng(seed)
154+
vals = pd.DataFrame({"Data": rng.random(n)})
155+
156+
ind_obj = indexer_t(rng, len(vals), win_len)
157+
ind_obj._end[[14, 7]] = ind_obj._end[[7, 14]]
158+
159+
f = vals.rolling(ind_obj).max
160+
with pytest.raises(
161+
ValueError, match="Start/End ordering requirement is violated at index 8"
162+
):
163+
f(engine=engine)
164+
165+
166+
def calc_minmax_control(vals, start, end, min_periods, ismax):
167+
func = np.nanmax if ismax else np.nanmin
168+
outp = np.full(vals.shape, np.nan)
169+
for i in range(len(start)):
170+
if start[i] >= end[i]:
171+
outp[i] = np.nan
172+
else:
173+
rng = vals[start[i] : end[i]]
174+
non_nan_cnt = np.count_nonzero(~np.isnan(rng))
175+
if non_nan_cnt >= min_periods:
176+
outp[i] = func(rng)
177+
else:
178+
outp[i] = np.nan
179+
return pd.DataFrame({"Data": outp})

0 commit comments

Comments
 (0)