16
16
from collections import OrderedDict , abc
17
17
from collections .abc import Sequence
18
18
from functools import lru_cache
19
- from typing import TYPE_CHECKING , List , Optional , TypeVar , Union
19
+ from typing import TYPE_CHECKING , Optional , TypeVar , Union
20
20
21
21
from hypothesis .errors import InvalidArgument
22
22
from hypothesis .internal .compat import int_from_bytes
@@ -87,6 +87,73 @@ def check_sample(
87
87
return tuple (values )
88
88
89
89
90
@lru_cache(64)
def _cached_sampler_table(
    weights: tuple[float, ...],
) -> tuple[tuple[int, int, float], ...]:
    """Build the alias table for ``weights`` (Vose's alias method).

    The result is an immutable, sorted tuple of
    ``(base, alternate, alternate_chance)`` rows, one per weight, so that
    the cached value cannot be modified by any caller.
    """
    n = len(weights)
    # One mutable working row per index: [base, alternate, alternate_chance].
    table: list[list[int | float | None]] = [[i, None, None] for i in range(n)]
    total = sum(weights)
    # Build 0 and 1 in the same numeric type as the weights' sum, so the
    # arithmetic below stays in that type.
    num_type = type(total)

    zero = num_type(0)  # type: ignore
    one = num_type(1)  # type: ignore

    small: list[int] = []
    large: list[int] = []

    # Each weight is normalised and then scaled by n, so an "average"
    # entry has scaled probability exactly 1.
    scaled_probabilities: list[float] = []
    for i, weight in enumerate(weights):
        scaled = (weight / total) * n
        scaled_probabilities.append(scaled)
        if scaled == 1:
            table[i][2] = zero
        elif scaled < 1:
            small.append(i)
        else:
            large.append(i)
    # Heaps make the pairing order depend only on the indices.
    heapq.heapify(small)
    heapq.heapify(large)

    # Repeatedly top up an under-full entry from an over-full one.
    while small and large:
        lo = heapq.heappop(small)
        hi = heapq.heappop(large)

        assert lo != hi
        assert scaled_probabilities[hi] > one
        assert table[lo][1] is None
        table[lo][1] = hi
        table[lo][2] = one - scaled_probabilities[lo]
        # ``hi`` donated exactly the mass that ``lo`` was missing.
        scaled_probabilities[hi] = (
            scaled_probabilities[hi] + scaled_probabilities[lo]
        ) - one

        if scaled_probabilities[hi] < 1:
            heapq.heappush(small, hi)
        elif scaled_probabilities[hi] == 1:
            table[hi][2] = zero
        else:
            heapq.heappush(large, hi)
    # Whatever is left in either heap could not be paired; such entries
    # never defer to an alternate.
    for leftover in (large, small):
        while leftover:
            table[leftover.pop()][2] = zero

    rows: list[tuple[int, int, float]] = []
    for base, alternate, alternate_chance in table:
        assert isinstance(base, int)
        assert isinstance(alternate, int) or alternate is None
        assert alternate_chance is not None
        if alternate is None:
            rows.append((base, base, alternate_chance))
        elif alternate < base:
            # Normalise so the smaller index always comes first; swapping
            # the pair means the chance has to be complemented.
            rows.append((alternate, base, one - alternate_chance))
        else:
            rows.append((base, alternate, alternate_chance))
    rows.sort()
    return tuple(rows)


def compute_sampler_table(weights: tuple[float, ...]) -> list[tuple[int, int, float]]:
    """Return the alias table for ``weights`` as a sorted list of rows.

    Each row is ``(base, alternate, alternate_chance)`` with
    ``base <= alternate``. The computation is memoised on ``weights``
    (which must therefore be hashable), but every call returns a *new*
    list: handing out the cached list itself would let one caller's
    in-place modification silently corrupt the table seen by every other
    caller with equal weights.

    Raises ``ZeroDivisionError`` if ``weights`` is non-empty and sums to
    zero.
    """
    return list(_cached_sampler_table(weights))


# Keep the usual lru_cache introspection hooks reachable from the public name.
compute_sampler_table.cache_info = _cached_sampler_table.cache_info  # type: ignore[attr-defined]
compute_sampler_table.cache_clear = _cached_sampler_table.cache_clear  # type: ignore[attr-defined]
155
+
156
+
90
157
class Sampler :
91
158
"""Sampler based on Vose's algorithm for the alias method. See
92
159
http://www.keithschwarz.com/darts-dice-coins/ for a good explanation.
@@ -109,69 +176,7 @@ class Sampler:
109
176
110
177
def __init__ (self , weights : Sequence [float ], * , observe : bool = True ):
111
178
self .observe = observe
112
-
113
- n = len (weights )
114
- table : "list[list[int | float | None]]" = [[i , None , None ] for i in range (n )]
115
- total = sum (weights )
116
- num_type = type (total )
117
-
118
- zero = num_type (0 ) # type: ignore
119
- one = num_type (1 ) # type: ignore
120
-
121
- small : "List[int]" = []
122
- large : "List[int]" = []
123
-
124
- probabilities = [w / total for w in weights ]
125
- scaled_probabilities : "List[float]" = []
126
-
127
- for i , alternate_chance in enumerate (probabilities ):
128
- scaled = alternate_chance * n
129
- scaled_probabilities .append (scaled )
130
- if scaled == 1 :
131
- table [i ][2 ] = zero
132
- elif scaled < 1 :
133
- small .append (i )
134
- else :
135
- large .append (i )
136
- heapq .heapify (small )
137
- heapq .heapify (large )
138
-
139
- while small and large :
140
- lo = heapq .heappop (small )
141
- hi = heapq .heappop (large )
142
-
143
- assert lo != hi
144
- assert scaled_probabilities [hi ] > one
145
- assert table [lo ][1 ] is None
146
- table [lo ][1 ] = hi
147
- table [lo ][2 ] = one - scaled_probabilities [lo ]
148
- scaled_probabilities [hi ] = (
149
- scaled_probabilities [hi ] + scaled_probabilities [lo ]
150
- ) - one
151
-
152
- if scaled_probabilities [hi ] < 1 :
153
- heapq .heappush (small , hi )
154
- elif scaled_probabilities [hi ] == 1 :
155
- table [hi ][2 ] = zero
156
- else :
157
- heapq .heappush (large , hi )
158
- while large :
159
- table [large .pop ()][2 ] = zero
160
- while small :
161
- table [small .pop ()][2 ] = zero
162
-
163
- self .table : "list[tuple[int, int, float]]" = []
164
- for base , alternate , alternate_chance in table :
165
- assert isinstance (base , int )
166
- assert isinstance (alternate , int ) or alternate is None
167
- assert alternate_chance is not None
168
- if alternate is None :
169
- self .table .append ((base , base , alternate_chance ))
170
- elif alternate < base :
171
- self .table .append ((alternate , base , one - alternate_chance ))
172
- else :
173
- self .table .append ((base , alternate , alternate_chance ))
174
- self .table .sort ()
179
+ self .table = compute_sampler_table (tuple (weights ))
175
180
176
181
def sample (
177
182
self ,
0 commit comments