Skip to content

Commit d603bd4

Browse files
authored
Merge pull request #3987 from tybug/lazy-sequence-copy-pop
Implement `LazySequenceCopy.pop(i)`
2 parents c42bbc7 + 9368da7 commit d603bd4

File tree

8 files changed

+83
-50
lines changed

8 files changed

+83
-50
lines changed

hypothesis-python/RELEASE.rst

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
RELEASE_TYPE: patch
2+
3+
This patch improves our shrinking of unique collections, such as :func:`~hypothesis.strategies.dictionaries`,
4+
:func:`~hypothesis.strategies.sets`, and :func:`~hypothesis.strategies.lists` with ``unique=True``.

hypothesis-python/src/hypothesis/internal/conjecture/choicetree.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from random import Random
1313
from typing import Callable, Dict, Iterable, List, Optional, Sequence
1414

15-
from hypothesis.internal.conjecture.junkdrawer import LazySequenceCopy, pop_random
15+
from hypothesis.internal.conjecture.junkdrawer import LazySequenceCopy
1616

1717

1818
def prefix_selection_order(
@@ -41,7 +41,8 @@ def random_selection_order(random: Random) -> Callable[[int, int], Iterable[int]
4141
def selection_order(depth: int, n: int) -> Iterable[int]:
4242
pending = LazySequenceCopy(range(n))
4343
while pending:
44-
yield pop_random(random, pending)
44+
i = random.randrange(0, len(pending))
45+
yield pending.pop(i)
4546

4647
return selection_order
4748

hypothesis-python/src/hypothesis/internal/conjecture/junkdrawer.py

+39-20
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
overload,
3232
)
3333

34+
from sortedcontainers import SortedList
35+
3436
from hypothesis.errors import HypothesisWarning
3537

3638
ARRAY_CODES = ["B", "H", "I", "L", "Q", "O"]
@@ -199,46 +201,71 @@ class LazySequenceCopy:
199201
in O(1) time. The full list API is not supported yet but there's no reason
200202
in principle it couldn't be."""
201203

202-
__mask: Optional[Dict[int, int]]
203-
204204
def __init__(self, values: Sequence[int]):
205205
self.__values = values
206206
self.__len = len(values)
207-
self.__mask = None
207+
self.__mask: Optional[Dict[int, int]] = None
208+
self.__popped_indices: Optional[SortedList] = None
208209

209210
def __len__(self) -> int:
210-
return self.__len
211+
if self.__popped_indices is None:
212+
return self.__len
213+
return self.__len - len(self.__popped_indices)
211214

212-
def pop(self) -> int:
215+
def pop(self, i: int = -1) -> int:
213216
if len(self) == 0:
214217
raise IndexError("Cannot pop from empty list")
215-
result = self[-1]
216-
self.__len -= 1
218+
i = self.__underlying_index(i)
219+
220+
v = None
217221
if self.__mask is not None:
218-
self.__mask.pop(self.__len, None)
219-
return result
222+
v = self.__mask.pop(i, None)
223+
if v is None:
224+
v = self.__values[i]
225+
226+
if self.__popped_indices is None:
227+
self.__popped_indices = SortedList()
228+
self.__popped_indices.add(i)
229+
return v
220230

221231
def __getitem__(self, i: int) -> int:
222-
i = self.__check_index(i)
232+
i = self.__underlying_index(i)
233+
223234
default = self.__values[i]
224235
if self.__mask is None:
225236
return default
226237
else:
227238
return self.__mask.get(i, default)
228239

229240
def __setitem__(self, i: int, v: int) -> None:
230-
i = self.__check_index(i)
241+
i = self.__underlying_index(i)
231242
if self.__mask is None:
232243
self.__mask = {}
233244
self.__mask[i] = v
234245

235-
def __check_index(self, i: int) -> int:
246+
def __underlying_index(self, i: int) -> int:
236247
n = len(self)
237248
if i < -n or i >= n:
238249
raise IndexError(f"Index {i} out of range [0, {n})")
239250
if i < 0:
240251
i += n
241252
assert 0 <= i < n
253+
254+
if self.__popped_indices is not None:
255+
# Given an index i into the list as it appears after pops, compute the
256+
# corresponding index in the underlying (unpopped) list. For example, given
257+
# l = [1, 4, 2, 10, 188]
258+
# l.pop(3)
259+
# l.pop(1)
260+
# assert l == [1, 2, 188]
261+
#
262+
# we want l[i] == self.__values[f(i)], where f is this function.
263+
assert len(self.__popped_indices) <= len(self.__values)
264+
265+
for idx in self.__popped_indices:
266+
if idx > i:
267+
break
268+
i += 1
242269
return i
243270

244271

@@ -345,14 +372,6 @@ def find_integer(f: Callable[[int], bool]) -> int:
345372
return lo
346373

347374

348-
def pop_random(random: Random, seq: LazySequenceCopy) -> int:
349-
"""Remove and return a random element of seq. This runs in O(1) but leaves
350-
the sequence in an arbitrary order."""
351-
i = random.randrange(0, len(seq))
352-
swap(seq, i, len(seq) - 1)
353-
return seq.pop()
354-
355-
356375
class NotFound(Exception):
357376
pass
358377

hypothesis-python/src/hypothesis/internal/conjecture/pareto.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,7 @@ def add(self, data):
179179
failures = 0
180180
while i + 1 < len(front) and failures < 10:
181181
j = self.__random.randrange(i + 1, len(front))
182-
swap(front, j, len(front) - 1)
183-
candidate = front.pop()
182+
candidate = front.pop(j)
184183
dom = dominance(data, candidate)
185184
assert dom != DominanceRelation.RIGHT_DOMINATES
186185
if dom == DominanceRelation.LEFT_DOMINATES:

hypothesis-python/src/hypothesis/strategies/_internal/collections.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -287,12 +287,8 @@ def do_draw(self, data):
287287
remaining = LazySequenceCopy(self.element_strategy.elements)
288288

289289
while remaining and should_draw.more():
290-
i = len(remaining) - 1
291-
j = data.draw_integer(0, i)
292-
if j != i:
293-
remaining[i], remaining[j] = remaining[j], remaining[i]
294-
value = self.element_strategy._transform(remaining.pop())
295-
290+
j = data.draw_integer(0, len(remaining) - 1)
291+
value = self.element_strategy._transform(remaining.pop(j))
296292
if value is not filter_not_satisfied and all(
297293
key(value) not in seen for key, seen in zip(self.keys, seen_sets)
298294
):

hypothesis-python/tests/conjecture/test_junkdrawer.py

+26-16
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
99
# obtain one at https://mozilla.org/MPL/2.0/.
1010

11+
import copy
1112
import inspect
1213

1314
import pytest
@@ -79,22 +80,31 @@ def test_clamp(lower, value, upper):
7980
assert clamped == upper
8081

8182

82-
def test_pop_without_mask():
83-
y = [1, 2, 3]
84-
x = LazySequenceCopy(y)
85-
x.pop()
86-
assert list(x) == [1, 2]
87-
assert y == [1, 2, 3]
88-
89-
90-
def test_pop_with_mask():
91-
y = [1, 2, 3]
92-
x = LazySequenceCopy(y)
93-
x[-1] = 5
94-
t = x.pop()
95-
assert t == 5
96-
assert list(x) == [1, 2]
97-
assert y == [1, 2, 3]
83+
# This would be more robust as a stateful test, where each rule is a list operation
84+
# applied to both (1) the canonical Python list and (2) its LazySequenceCopy. We would
85+
# assert that the return values and list contents match after each rule, and that the
86+
# original list is left unmodified.
87+
@pytest.mark.parametrize("should_mask", [True, False])
88+
@given(lst=st.lists(st.integers(), min_size=1), data=st.data())
89+
def test_pop_sequence_copy(lst, data, should_mask):
90+
original = copy.copy(lst)
91+
pop_i = data.draw(st.integers(0, len(lst) - 1))
92+
if should_mask:
93+
mask_i = data.draw(st.integers(0, len(lst) - 1))
94+
mask_value = data.draw(st.integers())
95+
96+
def pop(l):
97+
if should_mask:
98+
l[mask_i] = mask_value
99+
return l.pop(pop_i)
100+
101+
expected = copy.copy(lst)
102+
l = LazySequenceCopy(lst)
103+
104+
assert pop(expected) == pop(l)
105+
assert list(l) == expected
106+
# modifications to the LazySequenceCopy should not modify the original list
107+
assert original == lst
98108

99109

100110
def test_assignment():

hypothesis-python/tests/nocover/test_sampled_from.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,10 @@ def test_flags_minimize_to_first_named_flag():
124124

125125

126126
def test_flags_minimizes_bit_count():
127-
shrunk = minimal(st.sampled_from(LargeFlag), lambda f: bit_count(f.value) > 1)
128-
# Ideal would be (bit0 | bit1), but:
129-
# minimal(st.sets(st.sampled_from(range(10)), min_size=3)) == {0, 8, 9} # not {0, 1, 2}
130-
assert shrunk == LargeFlag.bit0 | LargeFlag.bit63 # documents actual behaviour
127+
assert (
128+
minimal(st.sampled_from(LargeFlag), lambda f: bit_count(f.value) > 1)
129+
== LargeFlag.bit0 | LargeFlag.bit1
130+
)
131131

132132

133133
def test_flags_finds_all_bits_set():

hypothesis-python/tests/quality/test_shrink_quality.py

+4
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ def test_minimize_sets_of_sets():
107107
assert any(s != t and t.issubset(s) for t in set_of_sets)
108108

109109

110+
def test_minimize_sets_sampled_from():
111+
assert minimal(st.sets(st.sampled_from(range(10)), min_size=3)) == {0, 1, 2}
112+
113+
110114
def test_can_simplify_flatmap_with_bounded_left_hand_size():
111115
assert (
112116
minimal(booleans().flatmap(lambda x: lists(just(x))), lambda x: len(x) >= 10)

0 commit comments

Comments
 (0)