Skip to content

Commit 66d84b5

Browse files
authored
Merge pull request #4001 from tybug/shrinker-adjacent-improvements
Shrinker-adjacent improvements
2 parents 259813a + ff523f7 commit 66d84b5

File tree

8 files changed

+97
-42
lines changed

8 files changed

+97
-42
lines changed

hypothesis-python/RELEASE.rst

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
RELEASE_TYPE: patch
2+
3+
This patch fixes one of our shrinking passes getting into a rare ``O(n)`` case instead of ``O(log(n))``.

hypothesis-python/src/hypothesis/internal/conjecture/data.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ class BooleanKWargs(TypedDict):
136136
IntegerKWargs, FloatKWargs, StringKWargs, BytesKWargs, BooleanKWargs
137137
]
138138
IRTypeName: TypeAlias = Literal["integer", "string", "boolean", "float", "bytes"]
139-
InvalidAt: TypeAlias = Tuple[IRTypeName, IRKWargsType]
139+
# ir_type, kwargs, forced
140+
InvalidAt: TypeAlias = Tuple[IRTypeName, IRKWargsType, Optional[IRType]]
140141

141142

142143
class ExtraInformation:
@@ -2084,7 +2085,7 @@ def draw_integer(
20842085
)
20852086

20862087
if self.ir_tree_nodes is not None and observe:
2087-
node = self._pop_ir_tree_node("integer", kwargs)
2088+
node = self._pop_ir_tree_node("integer", kwargs, forced=forced)
20882089
if forced is None:
20892090
assert isinstance(node.value, int)
20902091
forced = node.value
@@ -2141,7 +2142,7 @@ def draw_float(
21412142
)
21422143

21432144
if self.ir_tree_nodes is not None and observe:
2144-
node = self._pop_ir_tree_node("float", kwargs)
2145+
node = self._pop_ir_tree_node("float", kwargs, forced=forced)
21452146
if forced is None:
21462147
assert isinstance(node.value, float)
21472148
forced = node.value
@@ -2183,7 +2184,7 @@ def draw_string(
21832184
},
21842185
)
21852186
if self.ir_tree_nodes is not None and observe:
2186-
node = self._pop_ir_tree_node("string", kwargs)
2187+
node = self._pop_ir_tree_node("string", kwargs, forced=forced)
21872188
if forced is None:
21882189
assert isinstance(node.value, str)
21892190
forced = node.value
@@ -2219,7 +2220,7 @@ def draw_bytes(
22192220
kwargs: BytesKWargs = self._pooled_kwargs("bytes", {"size": size})
22202221

22212222
if self.ir_tree_nodes is not None and observe:
2222-
node = self._pop_ir_tree_node("bytes", kwargs)
2223+
node = self._pop_ir_tree_node("bytes", kwargs, forced=forced)
22232224
if forced is None:
22242225
assert isinstance(node.value, bytes)
22252226
forced = node.value
@@ -2261,7 +2262,7 @@ def draw_boolean(
22612262
kwargs: BooleanKWargs = self._pooled_kwargs("boolean", {"p": p})
22622263

22632264
if self.ir_tree_nodes is not None and observe:
2264-
node = self._pop_ir_tree_node("boolean", kwargs)
2265+
node = self._pop_ir_tree_node("boolean", kwargs, forced=forced)
22652266
if forced is None:
22662267
assert isinstance(node.value, bool)
22672268
forced = node.value
@@ -2302,7 +2303,9 @@ def _pooled_kwargs(self, ir_type, kwargs):
23022303
POOLED_KWARGS_CACHE[key] = kwargs
23032304
return kwargs
23042305

2305-
def _pop_ir_tree_node(self, ir_type: IRTypeName, kwargs: IRKWargsType) -> IRNode:
2306+
def _pop_ir_tree_node(
2307+
self, ir_type: IRTypeName, kwargs: IRKWargsType, *, forced: Optional[IRType]
2308+
) -> IRNode:
23062309
assert self.ir_tree_nodes is not None
23072310

23082311
if self._node_index == len(self.ir_tree_nodes):
@@ -2321,7 +2324,7 @@ def _pop_ir_tree_node(self, ir_type: IRTypeName, kwargs: IRKWargsType) -> IRNode
23212324
# (in fact, it is possible that giving up early here results in more time
23222325
# for useful shrinks to run).
23232326
if node.ir_type != ir_type:
2324-
invalid_at = (ir_type, kwargs)
2327+
invalid_at = (ir_type, kwargs, forced)
23252328
self.invalid_at = invalid_at
23262329
self.observer.mark_invalid(invalid_at)
23272330
self.mark_invalid(f"(internal) want a {ir_type} but have a {node.ir_type}")
@@ -2330,7 +2333,7 @@ def _pop_ir_tree_node(self, ir_type: IRTypeName, kwargs: IRKWargsType) -> IRNode
23302333
# that is allowed by the expected kwargs, then we can coerce this node
23312334
# into an aligned one by using its value. It's unclear how useful this is.
23322335
if not ir_value_permitted(node.value, node.ir_type, kwargs):
2333-
invalid_at = (ir_type, kwargs)
2336+
invalid_at = (ir_type, kwargs, forced)
23342337
self.invalid_at = invalid_at
23352338
self.observer.mark_invalid(invalid_at)
23362339
self.mark_invalid(f"(internal) got a {ir_type} but outside the valid range")

hypothesis-python/src/hypothesis/internal/conjecture/datatree.py

+10-12
Original file line numberDiff line numberDiff line change
@@ -555,16 +555,13 @@ def _repr_pretty_(self, p, cycle):
555555
p.text(_node_pretty(ir_type, value, kwargs, forced=i in self.forced))
556556
indent += 2
557557

558-
if isinstance(self.transition, Branch):
558+
with p.indent(indent):
559559
if len(self.values) > 0:
560560
p.break_()
561-
p.pretty(self.transition)
562-
563-
if isinstance(self.transition, (Killed, Conclusion)):
564-
with p.indent(indent):
565-
if len(self.values) > 0:
566-
p.break_()
561+
if self.transition is not None:
567562
p.pretty(self.transition)
563+
else:
564+
p.text("unknown")
568565

569566

570567
class DataTree:
@@ -843,8 +840,8 @@ def simulate_test_function(self, data):
843840
tree. This will likely change in future."""
844841
node = self.root
845842

846-
def draw(ir_type, kwargs, *, forced=None):
847-
if ir_type == "float" and forced is not None:
843+
def draw(ir_type, kwargs, *, forced=None, convert_forced=True):
844+
if ir_type == "float" and forced is not None and convert_forced:
848845
forced = int_to_float(forced)
849846

850847
draw_func = getattr(data, f"draw_{ir_type}")
@@ -869,9 +866,9 @@ def draw(ir_type, kwargs, *, forced=None):
869866
data.conclude_test(t.status, t.interesting_origin)
870867
elif node.transition is None:
871868
if node.invalid_at is not None:
872-
(ir_type, kwargs) = node.invalid_at
869+
(ir_type, kwargs, forced) = node.invalid_at
873870
try:
874-
draw(ir_type, kwargs)
871+
draw(ir_type, kwargs, forced=forced, convert_forced=False)
875872
except StopTest:
876873
if data.invalid_at is not None:
877874
raise
@@ -1021,7 +1018,8 @@ def draw_boolean(
10211018
self.draw_value("boolean", value, was_forced=was_forced, kwargs=kwargs)
10221019

10231020
def mark_invalid(self, invalid_at: InvalidAt) -> None:
1024-
self.__current_node.invalid_at = invalid_at
1021+
if self.__current_node.transition is None:
1022+
self.__current_node.invalid_at = invalid_at
10251023

10261024
def draw_value(
10271025
self,

hypothesis-python/src/hypothesis/internal/conjecture/engine.py

+3
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ def __init__(
207207
self.shrinks: int = 0
208208
self.finish_shrinking_deadline: Optional[float] = None
209209
self.call_count: int = 0
210+
self.misaligned_count: int = 0
210211
self.valid_examples: int = 0
211212
self.random: Random = random or Random(getrandbits(128))
212213
self.database_key: Optional[bytes] = database_key
@@ -418,6 +419,8 @@ def test_function(self, data: ConjectureData) -> None:
418419
}
419420
self.stats_per_test_case.append(call_stats)
420421
self._cache(data)
422+
if data.invalid_at is not None: # pragma: no branch # coverage bug?
423+
self.misaligned_count += 1
421424

422425
self.debug_data(data)
423426

hypothesis-python/src/hypothesis/internal/conjecture/shrinker.py

+15-13
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ def __init__(
306306
# it's time to stop shrinking.
307307
self.max_stall = 200
308308
self.initial_calls = self.engine.call_count
309+
self.initial_misaligned = self.engine.misaligned_count
309310
self.calls_at_last_shrink = self.initial_calls
310311

311312
self.passes_by_name: Dict[str, ShrinkPass] = {}
@@ -383,6 +384,10 @@ def calls(self):
383384
test function."""
384385
return self.engine.call_count
385386

387+
@property
388+
def misaligned(self):
389+
return self.engine.misaligned_count
390+
386391
def check_calls(self):
387392
if self.calls - self.calls_at_last_shrink >= self.max_stall:
388393
raise StopShrinking
@@ -501,13 +506,14 @@ def s(n):
501506

502507
total_deleted = self.initial_size - len(self.shrink_target.buffer)
503508
calls = self.engine.call_count - self.initial_calls
509+
misaligned = self.engine.misaligned_count - self.initial_misaligned
504510

505511
self.debug(
506512
"---------------------\n"
507513
"Shrink pass profiling\n"
508514
"---------------------\n\n"
509515
f"Shrinking made a total of {calls} call{s(calls)} of which "
510-
f"{self.shrinks} shrank. This deleted {total_deleted} bytes out "
516+
f"{self.shrinks} shrank and {misaligned} were misaligned. This deleted {total_deleted} bytes out "
511517
f"of {self.initial_size}."
512518
)
513519
for useful in [True, False]:
@@ -527,16 +533,9 @@ def s(n):
527533
continue
528534

529535
self.debug(
530-
" * %s made %d call%s of which "
531-
"%d shrank, deleting %d byte%s."
532-
% (
533-
p.name,
534-
p.calls,
535-
s(p.calls),
536-
p.shrinks,
537-
p.deletions,
538-
s(p.deletions),
539-
)
536+
f" * {p.name} made {p.calls} call{s(p.calls)} of which "
537+
f"{p.shrinks} shrank and {p.misaligned} were misaligned, "
538+
f"deleting {p.deletions} byte{s(p.deletions)}."
540539
)
541540
self.debug("")
542541
self.explain()
@@ -1321,7 +1320,7 @@ def boost(k):
13211320

13221321
@defines_shrink_pass()
13231322
def lower_blocks_together(self, chooser):
1324-
block = chooser.choose(self.blocks, lambda b: not b.all_zero)
1323+
block = chooser.choose(self.blocks, lambda b: not b.trivial)
13251324

13261325
# Choose the next block to be up to eight blocks onwards. We don't
13271326
# want to go too far (to avoid quadratic time) but it's worth a
@@ -1330,7 +1329,7 @@ def lower_blocks_together(self, chooser):
13301329
next_block = self.blocks[
13311330
chooser.choose(
13321331
range(block.index + 1, min(len(self.blocks), block.index + 9)),
1333-
lambda j: not self.blocks[j].all_zero,
1332+
lambda j: not self.blocks[j].trivial,
13341333
)
13351334
]
13361335

@@ -1623,6 +1622,7 @@ class ShrinkPass:
16231622
last_prefix = attr.ib(default=())
16241623
successes = attr.ib(default=0)
16251624
calls = attr.ib(default=0)
1625+
misaligned = attr.ib(default=0)
16261626
shrinks = attr.ib(default=0)
16271627
deletions = attr.ib(default=0)
16281628

@@ -1633,6 +1633,7 @@ def step(self, *, random_order=False):
16331633

16341634
initial_shrinks = self.shrinker.shrinks
16351635
initial_calls = self.shrinker.calls
1636+
initial_misaligned = self.shrinker.misaligned
16361637
size = len(self.shrinker.shrink_target.buffer)
16371638
self.shrinker.engine.explain_next_call_as(self.name)
16381639

@@ -1648,6 +1649,7 @@ def step(self, *, random_order=False):
16481649
)
16491650
finally:
16501651
self.calls += self.shrinker.calls - initial_calls
1652+
self.misaligned += self.shrinker.misaligned - initial_misaligned
16511653
self.shrinks += self.shrinker.shrinks - initial_shrinks
16521654
self.deletions += size - len(self.shrinker.shrink_target.buffer)
16531655
self.shrinker.engine.clear_call_explanation()

hypothesis-python/tests/conjecture/test_data_tree.py

+43-3
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,10 @@ def test_datatree_repr(bool_kwargs, int_kwargs):
598598
observer.draw_boolean(False, was_forced=True, kwargs=bool_kwargs)
599599
observer.conclude_test(Status.INTERESTING, interesting_origin=origin)
600600

601+
observer = tree.new_observer()
602+
observer.draw_boolean(False, was_forced=False, kwargs=bool_kwargs)
603+
observer.draw_integer(5, was_forced=False, kwargs=int_kwargs)
604+
601605
assert (
602606
pretty.pretty(tree)
603607
== textwrap.dedent(
@@ -610,13 +614,15 @@ def test_datatree_repr(bool_kwargs, int_kwargs):
610614
integer 0 {int_kwargs}
611615
boolean False [forced] {bool_kwargs}
612616
Conclusion (Status.INTERESTING, {origin})
617+
integer 5 {int_kwargs}
618+
unknown
613619
"""
614620
).strip()
615621
)
616622

617623

618-
def _draw(data, node):
619-
return getattr(data, f"draw_{node.ir_type}")(**node.kwargs)
624+
def _draw(data, node, *, forced=None):
625+
return getattr(data, f"draw_{node.ir_type}")(**node.kwargs, forced=forced)
620626

621627

622628
@given(ir_nodes(), ir_nodes())
@@ -635,7 +641,7 @@ def test_misaligned_nodes_after_valid_draw(node, misaligned_node):
635641
tree.simulate_test_function(data)
636642
assert data.status is Status.INVALID
637643

638-
assert data.invalid_at == (node.ir_type, node.kwargs)
644+
assert data.invalid_at == (node.ir_type, node.kwargs, None)
639645

640646

641647
@given(ir_nodes(was_forced=False), ir_nodes(was_forced=False))
@@ -703,3 +709,37 @@ def test_simulate_non_invalid_conclude_is_unseen_behavior(node, misaligned_node)
703709
tree.simulate_test_function(data)
704710

705711
assert data.status is Status.OVERRUN
712+
713+
714+
@given(ir_nodes(), ir_nodes())
715+
@settings(suppress_health_check=[HealthCheck.too_slow])
716+
def test_simulating_inherits_invalid_forced_status(node, misaligned_node):
717+
assume(misaligned_node.ir_type != node.ir_type)
718+
719+
# we have some logic in DataTree.simulate_test_function to "peek ahead" and
720+
# make sure it simulates invalid nodes correctly. But if it does so without
721+
# respecting whether the invalid node was forced or not, and this simulation
722+
# is observed by an observer, this can cause flaky errors later due to a node
723+
# going from unforced to forced.
724+
725+
tree = DataTree()
726+
727+
def test_function(ir_nodes):
728+
data = ConjectureData.for_ir_tree(ir_nodes, observer=tree.new_observer())
729+
_draw(data, node)
730+
_draw(data, node, forced=node.value)
731+
732+
# (1) set up a misaligned node at index 1
733+
with pytest.raises(StopTest):
734+
test_function([node, misaligned_node])
735+
736+
# (2) simulate an aligned tree. the datatree peeks ahead here using invalid_at
737+
# due to (1).
738+
data = ConjectureData.for_ir_tree([node, node], observer=tree.new_observer())
739+
with pytest.raises(PreviouslyUnseenBehaviour):
740+
tree.simulate_test_function(data)
741+
742+
# (3) run the same aligned tree without simulating. this uses the actual test
743+
# function's draw and forced value. This would flaky error if it did not match
744+
# what the datatree peeked ahead with in (2).
745+
test_function([node, node])

hypothesis-python/tests/conjecture/test_engine.py

+9
Original file line numberDiff line numberDiff line change
@@ -1661,3 +1661,12 @@ def test(data):
16611661
runner.tree.simulate_test_function(ConjectureData.for_ir_tree([node_0]))
16621662
runner.cached_test_function_ir([node_0])
16631663
assert runner.call_count == 3
1664+
1665+
1666+
def test_mildly_complicated_strategy():
1667+
# there are some code paths in engine.py that are easily covered by any mildly
1668+
# compliated strategy and aren't worth testing explicitly for. This covers
1669+
# those.
1670+
n = 5
1671+
s = st.lists(st.integers(), min_size=n)
1672+
assert minimal(s, lambda x: sum(x) >= 2 * n) == [0, 0, 0, 0, n * 2]

hypothesis-python/tests/cover/test_simple_collections.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,8 @@ def test_dictionaries_of_fixed_length(n):
122122

123123
@pytest.mark.parametrize("n", range(10))
124124
def test_lists_of_lower_bounded_length(n):
125-
x = minimal(lists(integers(), min_size=n), lambda x: sum(x) >= 2 * n)
126-
assert n <= len(x) <= 2 * n
127-
assert all(t >= 0 for t in x)
128-
assert len(x) == n or all(t > 0 for t in x)
129-
assert sum(x) == 2 * n
125+
l = minimal(lists(integers(), min_size=n), lambda x: sum(x) >= 2 * n)
126+
assert l == [] if n == 0 else [0] * (n - 1) + [n * 2]
130127

131128

132129
@flaky(min_passes=1, max_runs=3)

0 commit comments

Comments
 (0)