Skip to content

Commit d107e19

Browse files
authored
Merge pull request #4299 from tybug/shrink-more-sortkey
Use `sort_key` to avoid shrinking work in more places
2 parents f62ec1e + c9c9464 commit d107e19

File tree

3 files changed

+53
-40
lines changed

3 files changed

+53
-40
lines changed

hypothesis-python/RELEASE.rst

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
RELEASE_TYPE: patch
2+
3+
Improve how the shrinker checks for unnecessary work, leading to 10% less time spent shrinking on average, with no reduction in quality.

hypothesis-python/src/hypothesis/internal/conjecture/shrinker.py

+36-40
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import math
1212
from collections import defaultdict
1313
from collections.abc import Sequence
14-
from typing import TYPE_CHECKING, Callable, Optional, Union, cast
14+
from typing import TYPE_CHECKING, Callable, Literal, Optional, Union, cast
1515

1616
import attr
1717

@@ -136,7 +136,7 @@ class Shrinker:
136136
manage the associated state of a particular shrink problem. That is, we
137137
have some initial ConjectureData object and some property of interest
138138
that it satisfies, and we want to find a ConjectureData object with a
139-
shortlex (see sort_key above) smaller buffer that exhibits the same
139+
shortlex (see sort_key above) smaller choice sequence that exhibits the same
140140
property.
141141
142142
Currently the only property of interest we use is that the status is
@@ -160,7 +160,7 @@ class Shrinker:
160160
=======================
161161
162162
Generally a shrink pass is just any function that calls
163-
cached_test_function and/or incorporate_new_buffer a number of times,
163+
cached_test_function and/or consider_new_nodes a number of times,
164164
but there are a couple of useful things to bear in mind.
165165
166166
A shrink pass *makes progress* if running it changes self.shrink_target
@@ -202,22 +202,22 @@ class Shrinker:
202202
are carefully designed to do the right thing in the case that no
203203
shrinks occurred and try to adapt to any changes to do a reasonable
204204
job. e.g. say we wanted to write a shrink pass that tried deleting
205-
each individual byte (this isn't an especially good choice,
205+
each individual choice (this isn't an especially good pass,
206206
but it leads to a simple illustrative example), we might do it
207-
by iterating over the buffer like so:
207+
by iterating over the choice sequence like so:
208208
209209
.. code-block:: python
210210
211211
i = 0
212-
while i < len(self.shrink_target.buffer):
213-
if not self.incorporate_new_buffer(
214-
self.shrink_target.buffer[:i] + self.shrink_target.buffer[i + 1 :]
212+
while i < len(self.shrink_target.nodes):
213+
if not self.consider_new_nodes(
214+
self.shrink_target.nodes[:i] + self.shrink_target.nodes[i + 1 :]
215215
):
216216
i += 1
217217
218218
The reason for writing the loop this way is that i is always a
219-
valid index into the current buffer, even if the current buffer
220-
changes as a result of our actions. When the buffer changes,
219+
valid index into the current choice sequence, even if the current sequence
220+
changes as a result of our actions. When the choice sequence changes,
221221
we leave the index where it is rather than restarting from the
222222
beginning, and carry on. This means that the number of steps we
223223
run in this case is always bounded above by the number of steps
@@ -308,10 +308,8 @@ def __init__(
308308
self.__predicate = predicate or (lambda data: True)
309309
self.__allow_transition = allow_transition or (lambda source, destination: True)
310310
self.__derived_values: dict = {}
311-
self.__pending_shrink_explanation = None
312311

313312
self.initial_size = len(initial.choices)
314-
315313
# We keep track of the current best example on the shrink_target
316314
# attribute.
317315
self.shrink_target = initial
@@ -331,7 +329,7 @@ def __init__(
331329

332330
# Because the shrinker is also used to `pareto_optimise` in the target phase,
333331
# we sometimes want to allow extending the choice sequence instead of aborting at the end.
334-
self.__extend = "full" if in_target_phase else 0
332+
self.__extend: Union[Literal["full"], int] = "full" if in_target_phase else 0
335333
self.should_explain = explain
336334

337335
@derived_value # type: ignore
@@ -383,32 +381,32 @@ def check_calls(self) -> None:
383381
if self.calls - self.calls_at_last_shrink >= self.max_stall:
384382
raise StopShrinking
385383

386-
def cached_test_function(self, nodes):
384+
def cached_test_function(
385+
self, nodes: Sequence[ChoiceNode]
386+
) -> tuple[bool, Optional[Union[ConjectureResult, _Overrun]]]:
387+
nodes = nodes[: len(self.nodes)]
388+
389+
if startswith(nodes, self.nodes):
390+
return (True, None)
391+
392+
if sort_key(self.nodes) < sort_key(nodes):
393+
return (False, None)
394+
387395
# sometimes our shrinking passes try obviously invalid things. We handle
388396
# discarding them in one place here.
389-
for node in nodes:
390-
if not choice_permitted(node.value, node.kwargs):
391-
return None
397+
if any(not choice_permitted(node.value, node.kwargs) for node in nodes):
398+
return (False, None)
392399

393400
result = self.engine.cached_test_function(
394401
[n.value for n in nodes], extend=self.__extend
395402
)
403+
previous = self.shrink_target
396404
self.incorporate_test_data(result)
397405
self.check_calls()
398-
return result
406+
return (previous is not self.shrink_target, result)
399407

400408
def consider_new_nodes(self, nodes: Sequence[ChoiceNode]) -> bool:
401-
nodes = nodes[: len(self.nodes)]
402-
403-
if startswith(nodes, self.nodes):
404-
return True
405-
406-
if sort_key(self.nodes) < sort_key(nodes):
407-
return False
408-
409-
previous = self.shrink_target
410-
self.cached_test_function(nodes)
411-
return previous is not self.shrink_target
409+
return self.cached_test_function(nodes)[0]
412410

413411
def incorporate_test_data(self, data):
414412
"""Takes a ConjectureData or Overrun object and updates the current
@@ -458,8 +456,8 @@ def s(n):
458456
"Shrink pass profiling\n"
459457
"---------------------\n\n"
460458
f"Shrinking made a total of {calls} call{s(calls)} of which "
461-
f"{self.shrinks} shrank and {misaligned} were misaligned. This deleted {total_deleted} choices out "
462-
f"of {self.initial_size}."
459+
f"{self.shrinks} shrank and {misaligned} were misaligned. This "
460+
f"deleted {total_deleted} choices out of {self.initial_size}."
463461
)
464462
for useful in [True, False]:
465463
self.debug("")
@@ -700,7 +698,7 @@ def reduce_each_alternative(self):
700698
# previous values to no longer be valid in its position.
701699
zero_attempt = self.cached_test_function(
702700
nodes[:i] + (nodes[i].copy(with_value=0),) + nodes[i + 1 :]
703-
)
701+
)[1]
704702
if (
705703
zero_attempt is not self.shrink_target
706704
and zero_attempt is not None
@@ -731,10 +729,9 @@ def try_lower_node_as_alternative(self, i, v):
731729
while rerandomising and attempting to repair any subsequent
732730
changes to the shape of the test case that this causes."""
733731
nodes = self.shrink_target.nodes
734-
initial_attempt = self.cached_test_function(
732+
if self.consider_new_nodes(
735733
nodes[:i] + (nodes[i].copy(with_value=v),) + nodes[i + 1 :]
736-
)
737-
if initial_attempt is self.shrink_target:
734+
):
738735
return True
739736

740737
prefix = nodes[:i] + (nodes[i].copy(with_value=v),)
@@ -1090,7 +1087,7 @@ def try_shrinking_nodes(self, nodes, n):
10901087
[(node.index, node.index + 1, [node.copy(with_value=n)]) for node in nodes],
10911088
)
10921089

1093-
attempt = self.cached_test_function(initial_attempt)
1090+
attempt = self.cached_test_function(initial_attempt)[1]
10941091

10951092
if attempt is None:
10961093
return False
@@ -1149,8 +1146,7 @@ def try_shrinking_nodes(self, nodes, n):
11491146
# attempts which increase min_size tend to overrun rather than
11501147
# be misaligned, making a covering case difficult.
11511148
return False # pragma: no cover
1152-
# the size decreased in our attempt. Try again, but replace with
1153-
# the min_size that we would have gotten, and truncate the value
1149+
# the size decreased in our attempt. Try again, but truncate the value
11541150
# to that size by removing any elements past min_size.
11551151
return self.consider_new_nodes(
11561152
initial_attempt[: node.index]
@@ -1534,7 +1530,7 @@ def try_trivial_spans(self, chooser):
15341530
]
15351531
)
15361532
suffix = nodes[ex.end :]
1537-
attempt = self.cached_test_function(prefix + replacement + suffix)
1533+
attempt = self.cached_test_function(prefix + replacement + suffix)[1]
15381534

15391535
if self.shrink_target is not prev:
15401536
return
@@ -1598,7 +1594,7 @@ def minimize_individual_choices(self, chooser):
15981594
+ (node.copy(with_value=node.value - 1),)
15991595
+ self.nodes[node.index + 1 :]
16001596
)
1601-
attempt = self.cached_test_function(lowered)
1597+
attempt = self.cached_test_function(lowered)[1]
16021598
if (
16031599
attempt is None
16041600
or attempt.status < Status.VALID

hypothesis-python/tests/conjecture/test_shrinker.py

+14
Original file line numberDiff line numberDiff line change
@@ -655,3 +655,17 @@ def shrinker(data: ConjectureData):
655655

656656
shrinker.fixate_shrink_passes(["lower_duplicated_characters"])
657657
assert shrinker.choices == (expected[0],) + (0,) * gap + (expected[1],)
658+
659+
660+
def test_shrinking_one_of_with_same_shape():
661+
# This is a covering test for our one_of shrinking logic for the case when
662+
# the choice sequence *doesn't* change shape in the newly chosen one_of branch.
663+
@shrinking_from([1, 0])
664+
def shrinker(data: ConjectureData):
665+
n = data.draw_integer(0, 1)
666+
data.draw_integer()
667+
if n == 1:
668+
data.mark_interesting()
669+
670+
shrinker.initial_coarse_reduction()
671+
assert shrinker.choices == (1, 0)

0 commit comments

Comments
 (0)