|
| 1 | +from typing import ( |
| 2 | + Any, |
| 3 | +) |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +import pytest |
| 7 | + |
| 8 | +import pandas as pd |
| 9 | +from pandas import api |
| 10 | +import pandas._testing as tm |
| 11 | + |
| 12 | + |
| 13 | +class StandardWindowIndexer(api.indexers.BaseIndexer): |
| 14 | + def __init__(self, n, win_len): |
| 15 | + self.n = n |
| 16 | + self.win_len = win_len |
| 17 | + super().__init__() |
| 18 | + |
| 19 | + def get_window_bounds( |
| 20 | + self, num_values=None, min_periods=None, center=None, closed=None, step=None |
| 21 | + ): |
| 22 | + if num_values is None: |
| 23 | + num_values = self.n |
| 24 | + end = np.arange(num_values, dtype="int64") + 1 |
| 25 | + start = np.clip(end - self.win_len, 0, num_values) |
| 26 | + return start, end |
| 27 | + |
| 28 | + |
| 29 | +class CustomLengthWindowIndexer(api.indexers.BaseIndexer): |
| 30 | + def __init__(self, rnd: np.random.Generator, n, win_len): |
| 31 | + self.window = rnd.integers(win_len, size=n) |
| 32 | + super().__init__() |
| 33 | + |
| 34 | + def get_window_bounds( |
| 35 | + self, num_values=None, min_periods=None, center=None, closed=None, step=None |
| 36 | + ): |
| 37 | + if num_values is None: |
| 38 | + num_values = len(self.window) |
| 39 | + end = np.arange(num_values, dtype="int64") + 1 |
| 40 | + start = np.clip(end - self.window, 0, num_values) |
| 41 | + return start, end |
| 42 | + |
| 43 | + |
| 44 | +class ArbitraryWindowIndexer(api.indexers.BaseIndexer): |
| 45 | + def __init__(self, rnd: np.random.Generator, n, win_len): |
| 46 | + start = rnd.integers(n, size=n) |
| 47 | + win_len = rnd.integers(win_len, size=n) |
| 48 | + end = np.where(start - win_len >= 0, start - win_len, start + win_len) |
| 49 | + |
| 50 | + (start, end) = ( |
| 51 | + np.where(end >= start, start, end), |
| 52 | + np.where(end >= start, end, start), |
| 53 | + ) |
| 54 | + |
| 55 | + # It is extremely unlikely that a random array would come sorted, |
| 56 | + # so we proceed with sort without checking if it is sorted. |
| 57 | + prm = sorted(range(len(start)), key=lambda i: (end[i], start[i])) |
| 58 | + |
| 59 | + self._start = np.array(start)[prm] |
| 60 | + self._end = np.array(end)[prm] |
| 61 | + super().__init__() |
| 62 | + |
| 63 | + def get_window_bounds( |
| 64 | + self, num_values=None, min_periods=None, center=None, closed=None, step=None |
| 65 | + ): |
| 66 | + if num_values is None: |
| 67 | + num_values = len(self._start) |
| 68 | + start = np.clip(self._start, 0, num_values) |
| 69 | + end = np.clip(self._end, 0, num_values) |
| 70 | + return start, end |
| 71 | + |
| 72 | + |
| 73 | +class TestMinMax: |
| 74 | + # Pytest cache will not be a good choice here, because it appears |
| 75 | + # pytest persists data on disk, and we are not really interested |
| 76 | + # in flooding your hard drive with random numbers. |
| 77 | + # Thus we just cache control data in memory to avoid repetititve calculations. |
| 78 | + class Cache: |
| 79 | + def __init__(self) -> None: |
| 80 | + self.ctrl: dict[Any, Any] = {} |
| 81 | + |
| 82 | + @pytest.fixture(scope="class") |
| 83 | + def cache(self) -> Cache: |
| 84 | + return self.Cache() |
| 85 | + |
| 86 | + @pytest.mark.parametrize("is_max", [True, False]) |
| 87 | + # @pytest.mark.parametrize("engine", ["python", "cython", "numba"]) |
| 88 | + @pytest.mark.parametrize("engine", ["python", "cython"]) |
| 89 | + @pytest.mark.parametrize( |
| 90 | + "seed, n, win_len, min_obs, frac_nan, indexer_t", |
| 91 | + [ |
| 92 | + (42, 1000, 80, 15, 0.3, CustomLengthWindowIndexer), |
| 93 | + (52, 1000, 80, 15, 0.3, ArbitraryWindowIndexer), |
| 94 | + (1984, 1000, 40, 25, 0.3, None), |
| 95 | + ], |
| 96 | + ) |
| 97 | + def test_minmax( |
| 98 | + self, is_max, engine, seed, n, win_len, min_obs, frac_nan, indexer_t, cache |
| 99 | + ): |
| 100 | + if seed is not None and isinstance(seed, np.random._generator.Generator): |
| 101 | + rng = np.random.default_rng(seed) |
| 102 | + rng.bit_generator.state = seed.bit_generator.state |
| 103 | + else: |
| 104 | + rng = np.random.default_rng(seed) |
| 105 | + |
| 106 | + if seed is None or isinstance(seed, np.random._generator.Generator): |
| 107 | + rng_state_for_key = ( |
| 108 | + rng.bit_generator.state["bit_generator"], |
| 109 | + rng.bit_generator.state["state"]["state"], |
| 110 | + rng.bit_generator.state["state"]["inc"], |
| 111 | + rng.bit_generator.state["has_uint32"], |
| 112 | + rng.bit_generator.state["uinteger"], |
| 113 | + ) |
| 114 | + else: |
| 115 | + rng_state_for_key = seed |
| 116 | + self.last_rng_state = rng.bit_generator.state |
| 117 | + vals = pd.DataFrame({"Data": rng.random(n)}) |
| 118 | + if frac_nan > 0: |
| 119 | + is_nan = rng.random(len(vals)) < frac_nan |
| 120 | + vals.Data = np.where(is_nan, np.nan, vals.Data) |
| 121 | + |
| 122 | + ind_obj = ( |
| 123 | + indexer_t(rng, len(vals), win_len) |
| 124 | + if indexer_t |
| 125 | + else StandardWindowIndexer(len(vals), win_len) |
| 126 | + ) |
| 127 | + ind_param = ind_obj if indexer_t else win_len |
| 128 | + |
| 129 | + (start, end) = ind_obj.get_window_bounds() |
| 130 | + ctrl_key = (is_max, rng_state_for_key, n, win_len, min_obs, frac_nan, indexer_t) |
| 131 | + if ctrl_key in cache.ctrl: |
| 132 | + ctrl = cache.ctrl[ctrl_key] |
| 133 | + else: |
| 134 | + # This is brute force calculation, and may get expensive when n is |
| 135 | + # large, so we cache it. |
| 136 | + ctrl = calc_minmax_control(vals.Data, start, end, min_obs, is_max) |
| 137 | + cache.ctrl[ctrl_key] = ctrl |
| 138 | + |
| 139 | + r = vals.rolling(ind_param, min_periods=min_obs) |
| 140 | + f = r.max if is_max else r.min |
| 141 | + test = f(engine=engine) |
| 142 | + tm.assert_series_equal(test.Data, ctrl.Data) |
| 143 | + |
| 144 | + # @pytest.mark.parametrize("engine", ["python", "cython", "numba"]) |
| 145 | + @pytest.mark.parametrize("engine", ["python", "cython"]) |
| 146 | + @pytest.mark.parametrize( |
| 147 | + "seed, n, win_len, indexer_t", |
| 148 | + [ |
| 149 | + (42, 15, 7, ArbitraryWindowIndexer), |
| 150 | + ], |
| 151 | + ) |
| 152 | + def test_wrong_order(self, engine, seed, n, win_len, indexer_t): |
| 153 | + rng = np.random.default_rng(seed) |
| 154 | + vals = pd.DataFrame({"Data": rng.random(n)}) |
| 155 | + |
| 156 | + ind_obj = indexer_t(rng, len(vals), win_len) |
| 157 | + ind_obj._end[[14, 7]] = ind_obj._end[[7, 14]] |
| 158 | + |
| 159 | + f = vals.rolling(ind_obj).max |
| 160 | + with pytest.raises( |
| 161 | + ValueError, match="Start/End ordering requirement is violated at index 8" |
| 162 | + ): |
| 163 | + f(engine=engine) |
| 164 | + |
| 165 | + |
| 166 | +def calc_minmax_control(vals, start, end, min_periods, ismax): |
| 167 | + func = np.nanmax if ismax else np.nanmin |
| 168 | + outp = np.full(vals.shape, np.nan) |
| 169 | + for i in range(len(start)): |
| 170 | + if start[i] >= end[i]: |
| 171 | + outp[i] = np.nan |
| 172 | + else: |
| 173 | + rng = vals[start[i] : end[i]] |
| 174 | + non_nan_cnt = np.count_nonzero(~np.isnan(rng)) |
| 175 | + if non_nan_cnt >= min_periods: |
| 176 | + outp[i] = func(rng) |
| 177 | + else: |
| 178 | + outp[i] = np.nan |
| 179 | + return pd.DataFrame({"Data": outp}) |
0 commit comments