Skip to content

Commit bf31b5e

Browse files
committed
improve redistribute_integer_pairs
1 parent 0df70ec commit bf31b5e

File tree

3 files changed

+41
-39
lines changed

3 files changed

+41
-39
lines changed

hypothesis-python/src/hypothesis/internal/conjecture/shrinker.py

Lines changed: 25 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
ConjectureData,
2424
ConjectureResult,
2525
Status,
26-
bits_to_bytes,
2726
ir_value_equal,
2827
ir_value_key,
2928
ir_value_permitted,
@@ -681,7 +680,7 @@ def greedy_shrink(self):
681680
"reorder_examples",
682681
"minimize_duplicated_nodes",
683682
"minimize_individual_nodes",
684-
"redistribute_block_pairs",
683+
"redistribute_integer_pairs",
685684
"lower_blocks_together",
686685
]
687686
)
@@ -1227,42 +1226,32 @@ def minimize_duplicated_nodes(self, chooser):
12271226
self.minimize_nodes(nodes)
12281227

12291228
@defines_shrink_pass()
1230-
def redistribute_block_pairs(self, chooser):
1229+
def redistribute_integer_pairs(self, chooser):
12311230
"""If there is a sum of generated integers that we need their sum
12321231
to exceed some bound, lowering one of them requires raising the
12331232
other. This pass enables that."""
1233+
# TODO_SHRINK let's extend this to floats as well.
12341234

1235-
node = chooser.choose(
1235+
# look for a pair of nodes (node1, node2) which are both integers and
1236+
# aren't separated by too many other nodes. We'll decrease node1 and
1237+
# increase node2 (note that the other way around doesn't make sense as
1238+
# it's strictly worse in the ordering).
1239+
node1 = chooser.choose(
12361240
self.nodes, lambda node: node.ir_type == "integer" and not node.trivial
12371241
)
1242+
node2 = chooser.choose(
1243+
self.nodes,
1244+
lambda node: node.ir_type == "integer"
1245+
# Note that it's fine for node2 to be trivial, because we're going to
1246+
# explicitly make it *not* trivial by adding to its value.
1247+
and not node.was_forced
1248+
# to avoid quadratic behavior, scan ahead only a small amount for
1249+
# the related node.
1250+
and node1.index < node.index <= node1.index + 4,
1251+
)
12381252

1239-
# The preconditions for this pass are that the two integer draws are only
1240-
# separated by non-integer nodes, and have the same size value in bytes.
1241-
#
1242-
# This isn't particularly principled. For instance, this wouldn't reduce
1243-
# e.g. @given(integers(), integers(), integers()) where the sum property
1244-
# involves the first and last integers.
1245-
#
1246-
# A better approach may be choosing *two* such integer nodes arbitrarily
1247-
# from the list, instead of conditionally scanning forward.
1248-
1249-
for j in range(node.index + 1, len(self.nodes)):
1250-
next_node = self.nodes[j]
1251-
if next_node.ir_type == "integer" and bits_to_bytes(
1252-
node.value.bit_length()
1253-
) == bits_to_bytes(next_node.value.bit_length()):
1254-
break
1255-
else:
1256-
return
1257-
1258-
if next_node.was_forced:
1259-
# avoid modifying a forced node. Note that it's fine for next_node
1260-
# to be trivial, because we're going to explicitly make it *not*
1261-
# trivial by adding to its value.
1262-
return
1263-
1264-
m = node.value
1265-
n = next_node.value
1253+
m = node1.value
1254+
n = node2.value
12661255

12671256
def boost(k):
12681257
if k > m:
@@ -1272,11 +1261,11 @@ def boost(k):
12721261
next_node_value = n + k
12731262

12741263
return self.consider_new_tree(
1275-
self.nodes[: node.index]
1276-
+ [node.copy(with_value=node_value)]
1277-
+ self.nodes[node.index + 1 : next_node.index]
1278-
+ [next_node.copy(with_value=next_node_value)]
1279-
+ self.nodes[next_node.index + 1 :]
1264+
self.nodes[: node1.index]
1265+
+ [node1.copy(with_value=node_value)]
1266+
+ self.nodes[node1.index + 1 : node2.index]
1267+
+ [node2.copy(with_value=next_node_value)]
1268+
+ self.nodes[node2.index + 1 :]
12801269
)
12811270

12821271
find_integer(boost)

hypothesis-python/tests/conjecture/test_shrinker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ def shrinker(data):
562562
assert list(shrinker.buffer) == [1, 0] + [0] * n_gap + [0, 1]
563563

564564

565-
def test_redistribute_block_pairs_with_forced_node():
565+
def test_redistribute_integer_pairs_with_forced_node():
566566
@run_to_buffer
567567
def buf(data):
568568
data.draw_integer(0, 100, forced=15)
@@ -576,8 +576,8 @@ def shrinker(data):
576576
if n1 + n2 > 20:
577577
data.mark_interesting()
578578

579-
shrinker.fixate_shrink_passes(["redistribute_block_pairs"])
580-
# redistribute_block_pairs shouldn't try modifying forced nodes while
579+
shrinker.fixate_shrink_passes(["redistribute_integer_pairs"])
580+
# redistribute_integer_pairs shouldn't try modifying forced nodes while
581581
# shrinking. Since the second draw is forced, this isn't possible to shrink
582582
# with just this pass.
583583
assert shrinker.buffer == buf

hypothesis-python/tests/quality/test_shrink_quality.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,19 @@ def test_sum_of_pair():
349349
) == (1, 1000)
350350

351351

352+
def test_sum_of_pair_separated():
353+
@st.composite
354+
def separated_sum(draw):
355+
n1 = draw(st.integers(0, 1000))
356+
draw(st.text())
357+
draw(st.booleans())
358+
draw(st.integers())
359+
n2 = draw(st.integers(0, 1000))
360+
return (n1, n2)
361+
362+
assert minimal(separated_sum(), lambda x: sum(x) > 1000) == (1, 1000)
363+
364+
352365
def test_calculator_benchmark():
353366
"""This test comes from
354367
https://github.com/jlink/shrinking-challenge/blob/main/challenges/calculator.md,

0 commit comments

Comments
 (0)