Skip to content

Commit 8b93478

Browse files
Randomized deterministic unit tests replaced with 10-row prebuilt tests as requested by @mroeschke
1 parent 7679445 commit 8b93478

File tree

2 files changed

+49
-199
lines changed

2 files changed

+49
-199
lines changed

Diff for: pandas/tests/window/test_numba.py

+8-50
Original file line numberDiff line numberDiff line change
@@ -583,58 +583,16 @@ def test_npfunc_no_warnings():
583583
df.col1.rolling(2).apply(np.prod, raw=True, engine="numba")
584584

585585

586-
from .test_rolling import (
587-
ArbitraryWindowIndexer,
588-
CustomLengthWindowIndexer,
589-
)
586+
from .test_rolling import TestMinMax
590587

591588

592589
@td.skip_if_no("numba")
593-
class TestMinMax:
594-
@pytest.mark.parametrize("is_max", [True, False])
595-
@pytest.mark.parametrize(
596-
"seed, n, win_len, min_obs, frac_nan, indexer_t",
597-
[
598-
(42, 1000, 80, 15, 0.3, CustomLengthWindowIndexer),
599-
(52, 1000, 80, 15, 0.3, ArbitraryWindowIndexer),
600-
(1984, 1000, 40, 25, 0.3, None),
601-
],
602-
)
603-
def test_minmax(self, is_max, seed, n, win_len, min_obs, frac_nan, indexer_t):
604-
if seed is not None and isinstance(seed, np.random._generator.Generator):
605-
rng = np.random.default_rng(seed)
606-
rng.bit_generator.state = seed.bit_generator.state
607-
else:
608-
rng = np.random.default_rng(seed)
609-
610-
vals = DataFrame({"Data": rng.random(n)})
611-
if frac_nan > 0:
612-
is_nan = rng.random(len(vals)) < frac_nan
613-
vals.Data = np.where(is_nan, np.nan, vals.Data)
614-
615-
ind_param = indexer_t(rng, len(vals), win_len) if indexer_t else win_len
590+
class TestMinMaxNumba:
591+
parent = TestMinMax()
616592

617-
r = vals.rolling(ind_param, min_periods=min_obs)
618-
f = r.max if is_max else r.min
619-
test_cython = f(engine="cython")
620-
test_numba = f(engine="numba")
621-
tm.assert_series_equal(test_numba.Data, test_cython.Data)
593+
@pytest.mark.parametrize("is_max, has_nan, exp_list", TestMinMax.TestData)
594+
def test_minmax(self, is_max, has_nan, exp_list):
595+
TestMinMaxNumba.parent.test_minmax(is_max, has_nan, exp_list, "numba")
622596

623-
@pytest.mark.parametrize(
624-
"seed, n, win_len, indexer_t",
625-
[
626-
(42, 15, 7, ArbitraryWindowIndexer),
627-
],
628-
)
629-
def test_wrong_order(self, seed, n, win_len, indexer_t):
630-
rng = np.random.default_rng(seed)
631-
vals = DataFrame({"Data": rng.random(n)})
632-
633-
ind_obj = indexer_t(rng, len(vals), win_len)
634-
ind_obj._end[[14, 7]] = ind_obj._end[[7, 14]]
635-
636-
f = vals.rolling(ind_obj).max
637-
with pytest.raises(
638-
ValueError, match="Start/End ordering requirement is violated at index 8"
639-
):
640-
f(engine="numba")
597+
def test_wrong_order(self):
598+
TestMinMaxNumba.parent.test_wrong_order("numba")

Diff for: pandas/tests/window/test_rolling.py

+41-149
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
datetime,
33
timedelta,
44
)
5-
from typing import Any
65

76
import numpy as np
87
import pytest
@@ -1949,54 +1948,10 @@ def test_rolling_timedelta_window_non_nanoseconds(unit, tz):
19491948
tm.assert_frame_equal(ref_df, df)
19501949

19511950

1952-
class StandardWindowIndexer(BaseIndexer):
1953-
def __init__(self, n, win_len):
1954-
self.n = n
1955-
self.win_len = win_len
1956-
super().__init__()
1957-
1958-
def get_window_bounds(
1959-
self, num_values=None, min_periods=None, center=None, closed=None, step=None
1960-
):
1961-
if num_values is None:
1962-
num_values = self.n
1963-
end = np.arange(num_values, dtype="int64") + 1
1964-
start = np.clip(end - self.win_len, 0, num_values)
1965-
return start, end
1966-
1967-
1968-
class CustomLengthWindowIndexer(BaseIndexer):
1969-
def __init__(self, rnd: np.random.Generator, n, win_len):
1970-
self.window = rnd.integers(win_len, size=n)
1971-
super().__init__()
1972-
1973-
def get_window_bounds(
1974-
self, num_values=None, min_periods=None, center=None, closed=None, step=None
1975-
):
1976-
if num_values is None:
1977-
num_values = len(self.window)
1978-
end = np.arange(num_values, dtype="int64") + 1
1979-
start = np.clip(end - self.window, 0, num_values)
1980-
return start, end
1981-
1982-
1983-
class ArbitraryWindowIndexer(BaseIndexer):
1984-
def __init__(self, rnd: np.random.Generator, n, win_len):
1985-
start = rnd.integers(n, size=n)
1986-
win_len = rnd.integers(win_len, size=n)
1987-
end = np.where(start - win_len >= 0, start - win_len, start + win_len)
1988-
1989-
(start, end) = (
1990-
np.where(end >= start, start, end),
1991-
np.where(end >= start, end, start),
1992-
)
1993-
1994-
# It is extremely unlikely that a random array would come sorted,
1995-
# so we proceed with sort without checking if it is sorted.
1996-
prm = sorted(range(len(start)), key=lambda i: (end[i], start[i]))
1997-
1998-
self._start = np.array(start)[prm]
1999-
self._end = np.array(end)[prm]
1951+
class PrescribedWindowIndexer(BaseIndexer):
1952+
def __init__(self, start, end):
1953+
self._start = start
1954+
self._end = end
20001955
super().__init__()
20011956

20021957
def get_window_bounds(
@@ -2010,109 +1965,46 @@ def get_window_bounds(
20101965

20111966

20121967
class TestMinMax:
2013-
# Pytest cache will not be a good choice here, because it appears
2014-
# pytest persists data on disk, and we are not really interested
2015-
# in flooding your hard drive with random numbers.
2016-
# Thus we just cache control data in memory to avoid repetititve calculations.
2017-
class Cache:
2018-
def __init__(self) -> None:
2019-
self.ctrl: dict[Any, Any] = {}
2020-
2021-
@pytest.fixture(scope="class")
2022-
def cache(self) -> Cache:
2023-
return self.Cache()
2024-
2025-
@pytest.mark.parametrize("is_max", [True, False])
2026-
# @pytest.mark.parametrize("engine", ["python", "cython", "numba"])
2027-
@pytest.mark.parametrize("engine", ["cython"])
2028-
@pytest.mark.parametrize(
2029-
"seed, n, win_len, min_obs, frac_nan, indexer_t",
2030-
[
2031-
(42, 1000, 80, 15, 0.3, CustomLengthWindowIndexer),
2032-
(52, 1000, 80, 15, 0.3, ArbitraryWindowIndexer),
2033-
(1984, 1000, 40, 25, 0.3, None),
2034-
],
2035-
)
2036-
def test_minmax(
2037-
self, is_max, engine, seed, n, win_len, min_obs, frac_nan, indexer_t, cache
2038-
):
2039-
if seed is not None and isinstance(seed, np.random._generator.Generator):
2040-
rng = np.random.default_rng(seed)
2041-
rng.bit_generator.state = seed.bit_generator.state
2042-
else:
2043-
rng = np.random.default_rng(seed)
2044-
2045-
if seed is None or isinstance(seed, np.random._generator.Generator):
2046-
rng_state_for_key = (
2047-
rng.bit_generator.state["bit_generator"],
2048-
rng.bit_generator.state["state"]["state"],
2049-
rng.bit_generator.state["state"]["inc"],
2050-
rng.bit_generator.state["has_uint32"],
2051-
rng.bit_generator.state["uinteger"],
2052-
)
2053-
else:
2054-
rng_state_for_key = seed
2055-
self.last_rng_state = rng.bit_generator.state
2056-
vals = DataFrame({"Data": rng.random(n)})
2057-
if frac_nan > 0:
2058-
is_nan = rng.random(len(vals)) < frac_nan
2059-
vals.Data = np.where(is_nan, np.nan, vals.Data)
2060-
2061-
ind_obj = (
2062-
indexer_t(rng, len(vals), win_len)
2063-
if indexer_t
2064-
else StandardWindowIndexer(len(vals), win_len)
2065-
)
2066-
ind_param = ind_obj if indexer_t else win_len
1968+
TestData = [
1969+
(True, False, [3.0, 5.0, 2.0, 5.0, 1.0, 5.0, 6.0, 7.0, 8.0, 9.0]),
1970+
(True, True, [3.0, 4.0, 2.0, 4.0, 1.0, 4.0, 6.0, 7.0, 7.0, 9.0]),
1971+
(False, False, [3.0, 2.0, 2.0, 1.0, 1.0, 0.0, 0.0, 0.0, 7.0, 0.0]),
1972+
(False, True, [3.0, 2.0, 2.0, 1.0, 1.0, 1.0, 6.0, 6.0, 7.0, 1.0]),
1973+
]
20671974

2068-
(start, end) = ind_obj.get_window_bounds()
2069-
ctrl_key = (is_max, rng_state_for_key, n, win_len, min_obs, frac_nan, indexer_t)
2070-
if ctrl_key in cache.ctrl:
2071-
ctrl = cache.ctrl[ctrl_key]
1975+
@pytest.mark.parametrize("is_max, has_nan, exp_list", TestData)
1976+
def test_minmax(self, is_max, has_nan, exp_list, engine=None):
1977+
nan_idx = [0, 5, 8]
1978+
df = DataFrame(
1979+
{
1980+
"data": [5.0, 4.0, 3.0, 2.0, 1.0, 0.0, 6.0, 7.0, 8.0, 9.0],
1981+
"start": [2, 0, 3, 0, 4, 0, 5, 5, 7, 3],
1982+
"end": [3, 4, 4, 5, 5, 6, 7, 8, 9, 10],
1983+
}
1984+
)
1985+
if has_nan:
1986+
df.loc[nan_idx, "data"] = np.nan
1987+
expected = Series(exp_list, name="data")
1988+
r = df.data.rolling(
1989+
PrescribedWindowIndexer(df.start.to_numpy(), df.end.to_numpy())
1990+
)
1991+
if is_max:
1992+
result = r.max(engine=engine)
20721993
else:
2073-
# This is brute force calculation, and may get expensive when n is
2074-
# large, so we cache it.
2075-
ctrl = calc_minmax_control(vals.Data, start, end, min_obs, is_max)
2076-
cache.ctrl[ctrl_key] = ctrl
2077-
2078-
r = vals.rolling(ind_param, min_periods=min_obs)
2079-
f = r.max if is_max else r.min
2080-
test = f(engine=engine)
2081-
tm.assert_series_equal(test.Data, ctrl.Data)
2082-
2083-
# @pytest.mark.parametrize("engine", ["python", "cython", "numba"])
2084-
@pytest.mark.parametrize("engine", ["cython"])
2085-
@pytest.mark.parametrize(
2086-
"seed, n, win_len, indexer_t",
2087-
[
2088-
(42, 15, 7, ArbitraryWindowIndexer),
2089-
],
2090-
)
2091-
def test_wrong_order(self, engine, seed, n, win_len, indexer_t):
2092-
rng = np.random.default_rng(seed)
2093-
vals = DataFrame({"Data": rng.random(n)})
1994+
result = r.min(engine=engine)
1995+
1996+
tm.assert_series_equal(result, expected)
1997+
1998+
def test_wrong_order(self, engine=None):
1999+
start = np.array(range(5), dtype=np.int64)
2000+
end = start + 1
2001+
end[3] = end[2]
2002+
start[3] = start[2] - 1
20942003

2095-
ind_obj = indexer_t(rng, len(vals), win_len)
2096-
ind_obj._end[[14, 7]] = ind_obj._end[[7, 14]]
2004+
df = DataFrame({"data": start * 1.0, "start": start, "end": end})
20972005

2098-
f = vals.rolling(ind_obj).max
2006+
r = df.data.rolling(PrescribedWindowIndexer(start, end))
20992007
with pytest.raises(
2100-
ValueError, match="Start/End ordering requirement is violated at index 8"
2008+
ValueError, match="Start/End ordering requirement is violated at index 3"
21012009
):
2102-
f(engine=engine)
2103-
2104-
2105-
def calc_minmax_control(vals, start, end, min_periods, ismax):
2106-
func = np.nanmax if ismax else np.nanmin
2107-
outp = np.full(vals.shape, np.nan)
2108-
for i in range(len(start)):
2109-
if start[i] >= end[i]:
2110-
outp[i] = np.nan
2111-
else:
2112-
rng = vals[start[i] : end[i]]
2113-
non_nan_cnt = np.count_nonzero(~np.isnan(rng))
2114-
if non_nan_cnt >= min_periods:
2115-
outp[i] = func(rng)
2116-
else:
2117-
outp[i] = np.nan
2118-
return DataFrame({"Data": outp})
2010+
r.max(engine=engine)

0 commit comments

Comments
 (0)