Skip to content

Commit 7145c74

Browse files
authored
Merge pull request #4138 from tybug/integer-weights-simple
Fold integer endpoint upweighting into `weights=`
2 parents 760373e + 84dbaee commit 7145c74

File tree

12 files changed

+146
-147
lines changed

12 files changed

+146
-147
lines changed

hypothesis-python/RELEASE.rst

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
RELEASE_TYPE: patch
2+
3+
This release improves integer shrinking by folding the endpoint upweighting for :func:`~hypothesis.strategies.integers` into the ``weights`` parameter of our IR (:issue:`3921`).
4+
5+
If you maintain an alternative backend as part of our (for now explicitly unstable) :ref:`alternative-backends`, this release changes the type of the ``weights`` parameter to ``draw_integer`` and may be a breaking change for you.

hypothesis-python/src/hypothesis/internal/conjecture/data.py

+37-36
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def wrapper(tp):
8888
class IntegerKWargs(TypedDict):
8989
min_value: Optional[int]
9090
max_value: Optional[int]
91-
weights: Optional[Sequence[float]]
91+
weights: Optional[dict[int, float]]
9292
shrink_towards: int
9393

9494

@@ -1287,7 +1287,7 @@ def draw_integer(
12871287
max_value: Optional[int] = None,
12881288
*,
12891289
# weights are for choosing an element index from a bounded range
1290-
weights: Optional[Sequence[float]] = None,
1290+
weights: Optional[dict[int, float]] = None,
12911291
shrink_towards: int = 0,
12921292
forced: Optional[int] = None,
12931293
fake_forced: bool = False,
@@ -1456,8 +1456,7 @@ def draw_integer(
14561456
min_value: Optional[int] = None,
14571457
max_value: Optional[int] = None,
14581458
*,
1459-
# weights are for choosing an element index from a bounded range
1460-
weights: Optional[Sequence[float]] = None,
1459+
weights: Optional[dict[int, float]] = None,
14611460
shrink_towards: int = 0,
14621461
forced: Optional[int] = None,
14631462
fake_forced: bool = False,
@@ -1475,22 +1474,31 @@ def draw_integer(
14751474
assert min_value is not None
14761475
assert max_value is not None
14771476

1478-
sampler = Sampler(weights, observe=False)
1479-
gap = max_value - shrink_towards
1480-
1481-
forced_idx = None
1482-
if forced is not None:
1483-
if forced >= shrink_towards:
1484-
forced_idx = forced - shrink_towards
1485-
else:
1486-
forced_idx = shrink_towards + gap - forced
1487-
idx = sampler.sample(self._cd, forced=forced_idx, fake_forced=fake_forced)
1477+
# format of weights is a mapping of ints to p, where sum(p) < 1.
1478+
# The remaining probability mass is uniformly distributed over
1479+
# *all* ints (not just the unmapped ones; this is somewhat undesirable,
1480+
# but simplifies things).
1481+
#
1482+
# We assert that sum(p) is strictly less than 1 because it simplifies
1483+
# handling forced values when we can force into the unmapped probability
1484+
# mass. We should eventually remove this restriction.
1485+
sampler = Sampler(
1486+
[1 - sum(weights.values()), *weights.values()], observe=False
1487+
)
1488+
# if we're forcing, it's easiest to force into the unmapped probability
1489+
# mass and then force the drawn value after.
1490+
idx = sampler.sample(
1491+
self._cd, forced=None if forced is None else 0, fake_forced=fake_forced
1492+
)
14881493

1489-
# For range -2..2, interpret idx = 0..4 as [0, 1, 2, -1, -2]
1490-
if idx <= gap:
1491-
return shrink_towards + idx
1492-
else:
1493-
return shrink_towards - (idx - gap)
1494+
return self._draw_bounded_integer(
1495+
min_value,
1496+
max_value,
1497+
# implicit reliance on dicts being sorted for determinism
1498+
forced=forced if idx == 0 else list(weights)[idx - 1],
1499+
center=shrink_towards,
1500+
fake_forced=fake_forced,
1501+
)
14941502

14951503
if min_value is None and max_value is None:
14961504
return self._draw_unbounded_integer(forced=forced, fake_forced=fake_forced)
@@ -2116,8 +2124,7 @@ def draw_integer(
21162124
min_value: Optional[int] = None,
21172125
max_value: Optional[int] = None,
21182126
*,
2119-
# weights are for choosing an element index from a bounded range
2120-
weights: Optional[Sequence[float]] = None,
2127+
weights: Optional[dict[int, float]] = None,
21212128
shrink_towards: int = 0,
21222129
forced: Optional[int] = None,
21232130
fake_forced: bool = False,
@@ -2127,9 +2134,14 @@ def draw_integer(
21272134
if weights is not None:
21282135
assert min_value is not None
21292136
assert max_value is not None
2130-
width = max_value - min_value + 1
2131-
assert width <= 255 # arbitrary practical limit
2132-
assert len(weights) == width
2137+
assert len(weights) <= 255 # arbitrary practical limit
2138+
# We can and should eventually support total weights. But this
2139+
# complicates shrinking as we can no longer assume we can force
2140+
# a value to the unmapped probability mass if that mass might be 0.
2141+
assert sum(weights.values()) < 1
2142+
# similarly, things get simpler if we assume every value is possible.
2143+
# we'll want to drop this restriction eventually.
2144+
assert all(w != 0 for w in weights.values())
21332145

21342146
if forced is not None and (min_value is None or max_value is None):
21352147
# We draw `forced=forced - shrink_towards` here internally, after clamping.
@@ -2365,18 +2377,7 @@ def _pooled_kwargs(self, ir_type, kwargs):
23652377
if self.provider.avoid_realization:
23662378
return kwargs
23672379

2368-
key = []
2369-
for k, v in kwargs.items():
2370-
if ir_type == "float" and k in ["min_value", "max_value"]:
2371-
# handle -0.0 vs 0.0, etc.
2372-
v = float_to_int(v)
2373-
elif ir_type == "integer" and k == "weights":
2374-
# make hashable
2375-
v = v if v is None else tuple(v)
2376-
key.append((k, v))
2377-
2378-
key = (ir_type, *sorted(key))
2379-
2380+
key = (ir_type, *ir_kwargs_key(ir_type, kwargs))
23802381
try:
23812382
return POOLED_KWARGS_CACHE[key]
23822383
except KeyError:

hypothesis-python/src/hypothesis/internal/conjecture/shrinker.py

+25-36
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
ConjectureData,
2424
ConjectureResult,
2525
Status,
26-
bits_to_bytes,
2726
ir_value_equal,
2827
ir_value_key,
2928
ir_value_permitted,
@@ -681,7 +680,7 @@ def greedy_shrink(self):
681680
"reorder_examples",
682681
"minimize_duplicated_nodes",
683682
"minimize_individual_nodes",
684-
"redistribute_block_pairs",
683+
"redistribute_integer_pairs",
685684
"lower_blocks_together",
686685
]
687686
)
@@ -1227,42 +1226,32 @@ def minimize_duplicated_nodes(self, chooser):
12271226
self.minimize_nodes(nodes)
12281227

12291228
@defines_shrink_pass()
1230-
def redistribute_block_pairs(self, chooser):
1229+
def redistribute_integer_pairs(self, chooser):
12311230
"""If there is a sum of generated integers that we need their sum
12321231
to exceed some bound, lowering one of them requires raising the
12331232
other. This pass enables that."""
1233+
# TODO_SHRINK let's extend this to floats as well.
12341234

1235-
node = chooser.choose(
1235+
# look for a pair of nodes (node1, node2) which are both integers and
1236+
# aren't separated by too many other nodes. We'll decrease node1 and
1237+
# increase node2 (note that the other way around doesn't make sense as
1238+
# it's strictly worse in the ordering).
1239+
node1 = chooser.choose(
12361240
self.nodes, lambda node: node.ir_type == "integer" and not node.trivial
12371241
)
1242+
node2 = chooser.choose(
1243+
self.nodes,
1244+
lambda node: node.ir_type == "integer"
1245+
# Note that it's fine for node2 to be trivial, because we're going to
1246+
# explicitly make it *not* trivial by adding to its value.
1247+
and not node.was_forced
1248+
# to avoid quadratic behavior, scan ahead only a small amount for
1249+
# the related node.
1250+
and node1.index < node.index <= node1.index + 4,
1251+
)
12381252

1239-
# The preconditions for this pass are that the two integer draws are only
1240-
# separated by non-integer nodes, and have the same size value in bytes.
1241-
#
1242-
# This isn't particularly principled. For instance, this wouldn't reduce
1243-
# e.g. @given(integers(), integers(), integers()) where the sum property
1244-
# involves the first and last integers.
1245-
#
1246-
# A better approach may be choosing *two* such integer nodes arbitrarily
1247-
# from the list, instead of conditionally scanning forward.
1248-
1249-
for j in range(node.index + 1, len(self.nodes)):
1250-
next_node = self.nodes[j]
1251-
if next_node.ir_type == "integer" and bits_to_bytes(
1252-
node.value.bit_length()
1253-
) == bits_to_bytes(next_node.value.bit_length()):
1254-
break
1255-
else:
1256-
return
1257-
1258-
if next_node.was_forced:
1259-
# avoid modifying a forced node. Note that it's fine for next_node
1260-
# to be trivial, because we're going to explicitly make it *not*
1261-
# trivial by adding to its value.
1262-
return
1263-
1264-
m = node.value
1265-
n = next_node.value
1253+
m = node1.value
1254+
n = node2.value
12661255

12671256
def boost(k):
12681257
if k > m:
@@ -1272,11 +1261,11 @@ def boost(k):
12721261
next_node_value = n + k
12731262

12741263
return self.consider_new_tree(
1275-
self.nodes[: node.index]
1276-
+ [node.copy(with_value=node_value)]
1277-
+ self.nodes[node.index + 1 : next_node.index]
1278-
+ [next_node.copy(with_value=next_node_value)]
1279-
+ self.nodes[next_node.index + 1 :]
1264+
self.nodes[: node1.index]
1265+
+ [node1.copy(with_value=node_value)]
1266+
+ self.nodes[node1.index + 1 : node2.index]
1267+
+ [node2.copy(with_value=next_node_value)]
1268+
+ self.nodes[node2.index + 1 :]
12801269
)
12811270

12821271
find_integer(boost)

hypothesis-python/src/hypothesis/strategies/_internal/numbers.py

+8-11
Original file line numberDiff line numberDiff line change
@@ -66,24 +66,21 @@ def __repr__(self):
6666

6767
def do_draw(self, data):
6868
# For bounded integers, make the bounds and near-bounds more likely.
69-
forced = None
69+
weights = None
7070
if (
7171
self.end is not None
7272
and self.start is not None
7373
and self.end - self.start > 127
7474
):
75-
bits = data.draw_integer(0, 127)
76-
forced = {
77-
122: self.start,
78-
123: self.start,
79-
124: self.end,
80-
125: self.end,
81-
126: self.start + 1,
82-
127: self.end - 1,
83-
}.get(bits)
75+
weights = {
76+
self.start: (2 / 128),
77+
self.start + 1: (1 / 128),
78+
self.end - 1: (1 / 128),
79+
self.end: (2 / 128),
80+
}
8481

8582
return data.draw_integer(
86-
min_value=self.start, max_value=self.end, forced=forced
83+
min_value=self.start, max_value=self.end, weights=weights
8784
)
8885

8986
def filter(self, condition):

hypothesis-python/tests/conjecture/common.py

+15-34
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,9 @@ def integer_kwargs(
144144
draw(st.booleans()) if (use_min_value and use_max_value) else False
145145
)
146146

147-
# this generation is complicated to deal with maintaining any combination of
148-
# the following invariants, depending on which parameters are passed:
149-
#
147+
# Invariants:
150148
# (1) min_value <= forced <= max_value
151-
# (2) max_value - min_value + 1 == len(weights)
149+
# (2) sum(weights.values()) < 1
152150
# (3) len(weights) <= 255
153151

154152
if use_shrink_towards:
@@ -158,39 +156,22 @@ def integer_kwargs(
158156
if use_weights:
159157
assert use_max_value
160158
assert use_min_value
161-
# handle the weights case entirely independently from the non-weights case.
162-
# We'll treat the weights as our "key" draw and base all other draws on that.
163159

164-
# weights doesn't play well with super small floats, so exclude <.01
160+
min_value = draw(st.integers(max_value=forced))
161+
min_val = max(min_value, forced) if forced is not None else min_value
162+
max_value = draw(st.integers(min_value=min_val))
163+
164+
# Sampler doesn't play well with super small floats, so exclude them
165165
weights = draw(
166-
st.lists(st.just(0) | st.floats(0.01, 1), min_size=1, max_size=255)
166+
st.dictionaries(st.integers(), st.floats(0.001, 1), max_size=255)
167167
)
168-
# zero is allowed, but it can't be all zeroes
169-
assume(sum(weights) > 0)
170-
171-
# we additionally pick a central value (if not forced), and then the index
172-
# into the weights at which it can be found - aka the min-value offset.
173-
center = forced if use_forced else draw(st.integers())
174-
min_value = center - draw(st.integers(0, len(weights) - 1))
175-
max_value = min_value + len(weights) - 1
176-
177-
if use_forced:
178-
# can't force a 0-weight index.
179-
# we avoid clamping the returned shrink_towards to maximize
180-
# bug-finding power.
181-
_shrink_towards = clamped_shrink_towards(
182-
{
183-
"shrink_towards": shrink_towards,
184-
"min_value": min_value,
185-
"max_value": max_value,
186-
}
187-
)
188-
forced_idx = (
189-
forced - _shrink_towards
190-
if forced >= _shrink_towards
191-
else max_value - forced
192-
)
193-
assume(weights[forced_idx] > 0)
168+
# invalid to have a weighting that disallows all possibilities
169+
assume(sum(weights.values()) != 0)
170+
target = draw(st.floats(0.001, 0.999))
171+
# re-normalize probabilities to sum to some arbitrary value < 1
172+
weights = {k: v / target for k, v in weights.items()}
173+
# float rounding error can cause this to fail.
174+
assume(sum(weights.values()) == target)
194175
else:
195176
if use_min_value:
196177
min_value = draw(st.integers(max_value=forced))

hypothesis-python/tests/conjecture/test_alt_backend.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
import math
1212
import sys
13-
from collections.abc import Sequence
1413
from contextlib import contextmanager
1514
from random import Random
1615
from typing import Optional
@@ -67,8 +66,7 @@ def draw_integer(
6766
min_value: Optional[int] = None,
6867
max_value: Optional[int] = None,
6968
*,
70-
# weights are for choosing an element index from a bounded range
71-
weights: Optional[Sequence[float]] = None,
69+
weights: Optional[dict[int, float]] = None,
7270
shrink_towards: int = 0,
7371
forced: Optional[int] = None,
7472
fake_forced: bool = False,

hypothesis-python/tests/conjecture/test_forced.py

+26-2
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def test_forced_many(data):
6868
"min_value": -1,
6969
"max_value": 1,
7070
"shrink_towards": 1,
71-
"weights": [0.1] * 3,
71+
"weights": {-1: 0.2, 0: 0.2, 1: 0.2},
7272
"forced": 0,
7373
},
7474
)
@@ -80,11 +80,35 @@ def test_forced_many(data):
8080
"min_value": -1,
8181
"max_value": 1,
8282
"shrink_towards": -1,
83-
"weights": [0.1] * 3,
83+
"weights": {-1: 0.2, 0: 0.2, 1: 0.2},
8484
"forced": 0,
8585
},
8686
)
8787
)
88+
@example(
89+
(
90+
"integer",
91+
{
92+
"min_value": 10,
93+
"max_value": 1_000,
94+
"shrink_towards": 17,
95+
"weights": {20: 0.1},
96+
"forced": 15,
97+
},
98+
)
99+
)
100+
@example(
101+
(
102+
"integer",
103+
{
104+
"min_value": -1_000,
105+
"max_value": -10,
106+
"shrink_towards": -17,
107+
"weights": {-20: 0.1},
108+
"forced": -15,
109+
},
110+
)
111+
)
88112
@example(("float", {"forced": 0.0}))
89113
@example(("float", {"forced": -0.0}))
90114
@example(("float", {"forced": 1.0}))

0 commit comments

Comments
 (0)