Skip to content

Commit 9acdcee

Browse files
committed
optimizations
1 parent db28d3a commit 9acdcee

File tree

5 files changed

+83
-92
lines changed

5 files changed

+83
-92
lines changed

hypothesis-python/src/hypothesis/internal/conjecture/data.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -492,8 +492,7 @@ def __len__(self) -> int:
492492
return self.__length
493493

494494
def __getitem__(self, i: int) -> Example:
495-
assert isinstance(i, int)
496-
n = len(self)
495+
n = self.__length
497496
if i < -n or i >= n:
498497
raise IndexError(f"Index {i} out of range [-{n}, {n})")
499498
if i < 0:

hypothesis-python/src/hypothesis/internal/conjecture/junkdrawer.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
anything that lives here, please move it."""
1414

1515
import array
16+
from array import ArrayType
1617
import gc
1718
import sys
1819
import time
@@ -22,7 +23,6 @@
2223
Any,
2324
Callable,
2425
Generic,
25-
List,
2626
Literal,
2727
Optional,
2828
TypeVar,
@@ -41,7 +41,7 @@
4141

4242
def array_or_list(
4343
code: str, contents: Iterable[int]
44-
) -> "Union[List[int], array.ArrayType[int]]":
44+
) -> Union[list[int], "ArrayType[int]"]:
4545
if code == "O":
4646
return list(contents)
4747
return array.array(code, contents)
@@ -82,7 +82,7 @@ class IntList(Sequence[int]):
8282

8383
__slots__ = ("__underlying",)
8484

85-
__underlying: "Union[List[int], array.ArrayType[int]]"
85+
__underlying: Union[list[int], "ArrayType[int]"]
8686

8787
def __init__(self, values: Sequence[int] = ()):
8888
for code in ARRAY_CODES:
@@ -116,11 +116,13 @@ def __len__(self) -> int:
116116
def __getitem__(self, i: int) -> int: ... # pragma: no cover
117117

118118
@overload
119-
def __getitem__(self, i: slice) -> "IntList": ... # pragma: no cover
119+
def __getitem__(
120+
self, i: slice
121+
) -> Union[list[int], "ArrayType[int]"]: ... # pragma: no cover
120122

121-
def __getitem__(self, i: Union[int, slice]) -> "Union[int, IntList]":
122-
if isinstance(i, slice):
123-
return IntList(self.__underlying[i])
123+
def __getitem__(
124+
self, i: Union[int, slice]
125+
) -> Union[int, list[int], "ArrayType[int]"]:
124126
return self.__underlying[i]
125127

126128
def __delitem__(self, i: Union[int, slice]) -> None:

hypothesis-python/src/hypothesis/internal/conjecture/utils.py

+69-64
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from collections import OrderedDict, abc
1717
from collections.abc import Sequence
1818
from functools import lru_cache
19-
from typing import TYPE_CHECKING, List, Optional, TypeVar, Union
19+
from typing import TYPE_CHECKING, Optional, TypeVar, Union
2020

2121
from hypothesis.errors import InvalidArgument
2222
from hypothesis.internal.compat import int_from_bytes
@@ -87,6 +87,73 @@ def check_sample(
8787
return tuple(values)
8888

8989

90+
@lru_cache(64)
91+
def compute_sampler_table(weights: tuple[float, ...]) -> list[tuple[int, int, float]]:
92+
n = len(weights)
93+
table: list[list[int | float | None]] = [[i, None, None] for i in range(n)]
94+
total = sum(weights)
95+
num_type = type(total)
96+
97+
zero = num_type(0) # type: ignore
98+
one = num_type(1) # type: ignore
99+
100+
small: list[int] = []
101+
large: list[int] = []
102+
103+
probabilities = [w / total for w in weights]
104+
scaled_probabilities: list[float] = []
105+
106+
for i, alternate_chance in enumerate(probabilities):
107+
scaled = alternate_chance * n
108+
scaled_probabilities.append(scaled)
109+
if scaled == 1:
110+
table[i][2] = zero
111+
elif scaled < 1:
112+
small.append(i)
113+
else:
114+
large.append(i)
115+
heapq.heapify(small)
116+
heapq.heapify(large)
117+
118+
while small and large:
119+
lo = heapq.heappop(small)
120+
hi = heapq.heappop(large)
121+
122+
assert lo != hi
123+
assert scaled_probabilities[hi] > one
124+
assert table[lo][1] is None
125+
table[lo][1] = hi
126+
table[lo][2] = one - scaled_probabilities[lo]
127+
scaled_probabilities[hi] = (
128+
scaled_probabilities[hi] + scaled_probabilities[lo]
129+
) - one
130+
131+
if scaled_probabilities[hi] < 1:
132+
heapq.heappush(small, hi)
133+
elif scaled_probabilities[hi] == 1:
134+
table[hi][2] = zero
135+
else:
136+
heapq.heappush(large, hi)
137+
while large:
138+
table[large.pop()][2] = zero
139+
while small:
140+
table[small.pop()][2] = zero
141+
142+
new_table: list[tuple[int, int, float]] = []
143+
for base, alternate, alternate_chance in table:
144+
assert isinstance(base, int)
145+
assert isinstance(alternate, int) or alternate is None
146+
assert alternate_chance is not None
147+
if alternate is None:
148+
new_table.append((base, base, alternate_chance))
149+
elif alternate < base:
150+
new_table.append((alternate, base, one - alternate_chance))
151+
else:
152+
new_table.append((base, alternate, alternate_chance))
153+
new_table.sort()
154+
return new_table
155+
156+
90157
class Sampler:
91158
"""Sampler based on Vose's algorithm for the alias method. See
92159
http://www.keithschwarz.com/darts-dice-coins/ for a good explanation.
@@ -109,69 +176,7 @@ class Sampler:
109176

110177
def __init__(self, weights: Sequence[float], *, observe: bool = True):
111178
self.observe = observe
112-
113-
n = len(weights)
114-
table: "list[list[int | float | None]]" = [[i, None, None] for i in range(n)]
115-
total = sum(weights)
116-
num_type = type(total)
117-
118-
zero = num_type(0) # type: ignore
119-
one = num_type(1) # type: ignore
120-
121-
small: "List[int]" = []
122-
large: "List[int]" = []
123-
124-
probabilities = [w / total for w in weights]
125-
scaled_probabilities: "List[float]" = []
126-
127-
for i, alternate_chance in enumerate(probabilities):
128-
scaled = alternate_chance * n
129-
scaled_probabilities.append(scaled)
130-
if scaled == 1:
131-
table[i][2] = zero
132-
elif scaled < 1:
133-
small.append(i)
134-
else:
135-
large.append(i)
136-
heapq.heapify(small)
137-
heapq.heapify(large)
138-
139-
while small and large:
140-
lo = heapq.heappop(small)
141-
hi = heapq.heappop(large)
142-
143-
assert lo != hi
144-
assert scaled_probabilities[hi] > one
145-
assert table[lo][1] is None
146-
table[lo][1] = hi
147-
table[lo][2] = one - scaled_probabilities[lo]
148-
scaled_probabilities[hi] = (
149-
scaled_probabilities[hi] + scaled_probabilities[lo]
150-
) - one
151-
152-
if scaled_probabilities[hi] < 1:
153-
heapq.heappush(small, hi)
154-
elif scaled_probabilities[hi] == 1:
155-
table[hi][2] = zero
156-
else:
157-
heapq.heappush(large, hi)
158-
while large:
159-
table[large.pop()][2] = zero
160-
while small:
161-
table[small.pop()][2] = zero
162-
163-
self.table: "list[tuple[int, int, float]]" = []
164-
for base, alternate, alternate_chance in table:
165-
assert isinstance(base, int)
166-
assert isinstance(alternate, int) or alternate is None
167-
assert alternate_chance is not None
168-
if alternate is None:
169-
self.table.append((base, base, alternate_chance))
170-
elif alternate < base:
171-
self.table.append((alternate, base, one - alternate_chance))
172-
else:
173-
self.table.append((base, alternate, alternate_chance))
174-
self.table.sort()
179+
self.table = compute_sampler_table(tuple(weights))
175180

176181
def sample(
177182
self,

hypothesis-python/tests/conjecture/test_junkdrawer.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,9 @@ def test_int_list_extend():
169169

170170
def test_int_list_slice():
171171
x = IntList([1, 2])
172-
assert x[:1] == IntList([1])
173-
assert x[0:2] == IntList([1, 2])
174-
assert x[1:] == IntList([2])
172+
assert list(x[:1]) == [1]
173+
assert list(x[0:2]) == [1, 2]
174+
assert list(x[1:]) == [2]
175175

176176

177177
def test_int_list_del():

hypothesis-python/tests/nocover/test_conjecture_int_list.py

+1-16
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,6 @@ def valid_index(draw):
2323
return draw(st.integers(0, len(machine.model) - 1))
2424

2525

26-
@st.composite
27-
def valid_slice(draw):
28-
machine = draw(st.runner())
29-
result = [
30-
draw(st.integers(0, max(3, len(machine.model) * 2 - 1))) for _ in range(2)
31-
]
32-
result.sort()
33-
return slice(*result)
34-
35-
3626
class IntListRules(RuleBasedStateMachine):
3727
@initialize(ls=st.lists(INTEGERS))
3828
def starting_lists(self, ls):
@@ -52,16 +42,11 @@ def append(self, n):
5242
self.model.append(n)
5343
self.target.append(n)
5444

55-
@rule(i=valid_index() | valid_slice())
45+
@rule(i=valid_index())
5646
def delete(self, i):
5747
del self.model[i]
5848
del self.target[i]
5949

60-
@rule(sl=valid_slice())
61-
def slice(self, sl):
62-
self.model = self.model[sl]
63-
self.target = self.target[sl]
64-
6550
@rule(i=valid_index())
6651
def agree_on_values(self, i):
6752
assert self.model[i] == self.target[i]

0 commit comments

Comments
 (0)