Skip to content

Commit 822e39d

Browse files
authored
Merge pull request #3962 from tybug/shrinker-ir
Migrate most shrinker functions to the ir
2 parents a30c0ef + 3814b54 commit 822e39d

34 files changed

+970
-1325
lines changed

hypothesis-python/RELEASE.rst

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
RELEASE_TYPE: minor
2+
3+
This release migrates the shrinker to our new internal representation, called the IR layer (:pull:`3962`). This improves the shrinker's performance in the majority of cases. For example, on the Hypothesis test suite, shrinking is a median of 1.38x faster.
4+
5+
It is possible this release regresses performance while shrinking certain strategies. If you encounter strategies which reliably shrink more slowly than they used to (or shrink slowly at all), please open an issue!
6+
7+
You can read more about the IR layer at :issue:`3921`.

hypothesis-python/benchmark/README.md

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
This directory contains code for benchmarking Hypothesis' shrinking. This was written for [pull/3962](https://github.com/HypothesisWorks/hypothesis/pull/3962) and is a manual process at the moment, though we may eventually integrate it more closely with CI for automated benchmarking.
2+
3+
To run a benchmark:
4+
5+
* Add the contents of `conftest.py` to the bottom of `hypothesis-python/tests/conftest.py`
6+
* In `hypothesis-python/tests/common/debug.py`, change `derandomize=True` to `derandomize=False` (if you are running more than one trial)
7+
* Run the tests: `pytest hypothesis-python/tests/`
8+
* Note that the benchmarking script does not currently support xdist, so do not use `-n 8` or similar.
9+
10+
When pytest finishes the output will contain a dictionary of the benchmarking results. Add that as a new entry in `data.json`. Repeat for however many trials you want; n=5 seems reasonable.
11+
12+
Also repeat for both your baseline ("old") and your comparison ("new") code.
13+
14+
Then run `python graph.py` to generate a graph comparing the old and new results.
+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# This file is part of Hypothesis, which may be found at
2+
# https://github.com/HypothesisWorks/hypothesis/
3+
#
4+
# Copyright the Hypothesis Authors.
5+
# Individual contributors are listed in AUTHORS.rst and the git log.
6+
#
7+
# This Source Code Form is subject to the terms of the Mozilla Public License,
8+
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
9+
# obtain one at https://mozilla.org/MPL/2.0/.
10+
11+
import inspect
12+
import json
13+
from collections import defaultdict
14+
15+
import pytest
16+
from _pytest.monkeypatch import MonkeyPatch
17+
18+
# we'd like to support xdist here for parallelism, but a session-scope fixture won't
# be enough: https://github.com/pytest-dev/pytest-xdist/issues/271. need a lockfile
# or equivalent.
#
# Maps "<test filename>::<test function>" to a list of shrink-call counts, one
# entry per Shrinker.shrink() invocation observed while that test ran. Filled in
# by the _benchmark_shrinks fixture and dumped as JSON at session finish.
shrink_calls = defaultdict(list)
22+
23+
24+
def pytest_collection_modifyitems(config, items):
    """Deselect every collected test that never calls ``minimal()``.

    Only tests exercising shrinking are interesting for this benchmark, so
    anything whose source text lacks a ``minimal(`` call gets a skip marker.
    """
    skip_marker = pytest.mark.skip(reason="Does not call minimal()")
    for test_item in items:
        # is this perfect? no. but it is cheap!
        source = inspect.getsource(test_item.obj)
        if " minimal(" not in source:
            test_item.add_marker(skip_marker)
31+
32+
33+
@pytest.fixture(scope="function", autouse=True)
def _benchmark_shrinks():
    """Autouse fixture that wraps ``Shrinker.shrink`` to record, per test,
    how many test-function calls each shrink pass consumed.

    Results are accumulated into the module-level ``shrink_calls`` mapping,
    keyed by ``"<test filename>::<test function>"``.
    """
    # Imported lazily so merely collecting this conftest doesn't import
    # hypothesis internals.
    from hypothesis.internal.conjecture.shrinker import Shrinker

    monkeypatch = MonkeyPatch()

    def record_shrink_calls(calls):
        # Walk the stack to find the innermost-to-outermost test_* frame;
        # note this keeps the LAST match found, i.e. the outermost test frame.
        name = None
        for frame in inspect.stack():
            if frame.function.startswith("test_"):
                name = f"{frame.filename.split('/')[-1]}::{frame.function}"

        # some minimal calls happen at collection-time outside of a test context
        # (maybe something we should fix/look into)
        if name is None:
            return

        shrink_calls[name].append(calls)

    old_shrink = Shrinker.shrink

    def shrink(self, *args, **kwargs):
        # Delegate to the real shrink, then record how many engine calls the
        # shrink phase added (total calls minus calls made before shrinking).
        v = old_shrink(self, *args, **kwargs)
        record_shrink_calls(self.engine.call_count - self.initial_calls)
        return v

    monkeypatch.setattr(Shrinker, "shrink", shrink)
    yield

    # start teardown
    # NOTE(review): restores the attribute directly rather than via
    # monkeypatch.undo() — equivalent here, but confirm that was intentional.
    Shrinker.shrink = old_shrink
63+
64+
65+
def pytest_sessionfinish(session, exitstatus):
    """Dump the accumulated shrink-call counts as JSON when the run ends."""
    payload = json.dumps(shrink_calls)
    print(f"\nshrinker profiling:\n{payload}")

hypothesis-python/benchmark/data.json

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"old": [],
3+
"new": []
4+
}

hypothesis-python/benchmark/graph.py

+114
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# This file is part of Hypothesis, which may be found at
2+
# https://github.com/HypothesisWorks/hypothesis/
3+
#
4+
# Copyright the Hypothesis Authors.
5+
# Individual contributors are listed in AUTHORS.rst and the git log.
6+
#
7+
# This Source Code Form is subject to the terms of the Mozilla Public License,
8+
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
9+
# obtain one at https://mozilla.org/MPL/2.0/.
10+
11+
import json
12+
import statistics
13+
from pathlib import Path
14+
15+
import matplotlib.pyplot as plt
16+
import seaborn as sns
17+
18+
# Load the benchmark results recorded by the conftest instrumentation.
data_path = Path(__file__).parent / "data.json"
with open(data_path) as f:
    data = json.loads(f.read())

old_runs = data["old"]
new_runs = data["new"]
all_runs = old_runs + new_runs

# every run should involve the same functions
names = set()
for run in all_runs:
    names.add(frozenset(run.keys()))

# Keep only tests present in every run; report the ones dropped.
intersection = frozenset.intersection(*names)
diff = frozenset.union(*[intersection.symmetric_difference(n) for n in names])

print(f"skipping these tests which were not present in all runs: {', '.join(diff)}")
names = list(intersection)

# the similar invariant for number of minimal calls per run is not true: functions
# may make a variable number of minimal() calls.
# it would be nice to compare identically just the ones which don't vary, to get
# a very fine grained comparison instead of averaging.
# sizes = []
# for run in all_runs:
#     sizes.append(tuple(len(value) for value in run.values()))
# assert len(set(sizes)) == 1

# Drop tests whose shrink counts were zero in every run — nothing to compare.
new_names = []
for name in names:
    if all(all(x == 0 for x in run[name]) for run in all_runs):
        print(f"no shrinks for {name}, skipping")
        continue
    new_names.append(name)
names = new_names


# name : average calls
old_values = {}
new_values = {}
for name in names:
    # mean across the different minimal() calls in a single test function, then
    # median across the n iterations we ran that for to reduce error
    old_vals = [statistics.mean(run[name]) for run in old_runs]
    new_vals = [statistics.mean(run[name]) for run in new_runs]
    old_values[name] = statistics.median(old_vals)
    new_values[name] = statistics.median(new_vals)

# name : (absolute difference, times difference)
diffs = {}
for name in names:
    old = old_values[name]
    new = new_values[name]
    diff = old - new
    # NOTE(review): divides by the old median — raises ZeroDivisionError if a
    # test shrank only in the new runs (old == 0). Confirm that cannot happen
    # given the all-zero filter above, or guard it.
    diff_times = (old - new) / old
    # Map a fractional speedup (0 < x < 1) onto an "n-times" scale so e.g.
    # a 50% reduction reads as 1x faster rather than 0.5.
    if 0 < diff_times < 1:
        diff_times = (1 / (1 - diff_times)) - 1
    diffs[name] = (diff, diff_times)

    print(f"{name} {int(diff)} ({int(old)} -> {int(new)}, {round(diff_times, 1)}✕)")

# Sort by absolute shrink-call change so the bar chart reads left-to-right.
diffs = dict(sorted(diffs.items(), key=lambda kv: kv[1][0]))
diffs_value = [v[0] for v in diffs.values()]
diffs_percentage = [v[1] for v in diffs.values()]

print(
    f"mean: {int(statistics.mean(diffs_value))}, median: {int(statistics.median(diffs_value))}"
)
87+
88+
89+
# https://stackoverflow.com/a/65824524
90+
def align_axes(ax1, ax2):
    """Align the zero lines of two overlaid y-axes.

    Computes each axis' bottom/top ratio and lowers the bottom limit of
    whichever axis has the shallower (larger) ratio so that both axes share
    the same proportional baseline. Only ``set_ylim(bottom=...)`` is touched;
    the top limits are left alone.
    """
    bottom1, top1 = ax1.axes.get_ylim()
    bottom2, top2 = ax2.axes.get_ylim()

    ratio1 = bottom1 / top1
    ratio2 = bottom2 / top2

    # Stretch the axis with the shallower ratio down to match the other.
    if ratio1 < ratio2:
        ax2.set_ylim(bottom=top2 * ratio1)
    else:
        ax1.set_ylim(bottom=top1 * ratio2)
101+
102+
103+
ax1 = sns.barplot(diffs_value, color="b", alpha=0.7, label="shrink call change")
104+
ax2 = plt.twinx()
105+
sns.barplot(diffs_percentage, color="r", alpha=0.7, label=r"n✕ change", ax=ax2)
106+
107+
ax1.set_title("old shrinks - new shrinks (aka shrinks saved, higher is better)")
108+
ax1.set_xticks([])
109+
align_axes(ax1, ax2)
110+
legend = ax1.legend(labels=["shrink call change", "n✕ change"])
111+
legend.legend_handles[0].set_color("b")
112+
legend.legend_handles[1].set_color("r")
113+
114+
plt.show()

hypothesis-python/src/hypothesis/internal/conjecture/data.py

+28-11
Original file line numberDiff line numberDiff line change
@@ -779,10 +779,6 @@ def end(self, i: int) -> int:
779779
"""Equivalent to self[i].end."""
780780
return self.endpoints[i]
781781

782-
def bounds(self, i: int) -> Tuple[int, int]:
783-
"""Equivalent to self[i].bounds."""
784-
return (self.start(i), self.end(i))
785-
786782
def all_bounds(self) -> Iterable[Tuple[int, int]]:
787783
"""Equivalent to [(b.start, b.end) for b in self]."""
788784
prev = 0
@@ -970,7 +966,12 @@ class IRNode:
970966
was_forced: bool = attr.ib()
971967
index: Optional[int] = attr.ib(default=None)
972968

973-
def copy(self, *, with_value: IRType) -> "IRNode":
969+
def copy(
970+
self,
971+
*,
972+
with_value: Optional[IRType] = None,
973+
with_kwargs: Optional[IRKWargsType] = None,
974+
) -> "IRNode":
974975
# we may want to allow this combination in the future, but for now it's
975976
# a footgun.
976977
assert not self.was_forced, "modifying a forced node doesn't make sense"
@@ -979,8 +980,8 @@ def copy(self, *, with_value: IRType) -> "IRNode":
979980
# after copying.
980981
return IRNode(
981982
ir_type=self.ir_type,
982-
value=with_value,
983-
kwargs=self.kwargs,
983+
value=self.value if with_value is None else with_value,
984+
kwargs=self.kwargs if with_kwargs is None else with_kwargs,
984985
was_forced=self.was_forced,
985986
)
986987

@@ -1071,9 +1072,17 @@ def __repr__(self):
10711072

10721073
def ir_value_permitted(value, ir_type, kwargs):
10731074
if ir_type == "integer":
1074-
if kwargs["min_value"] is not None and value < kwargs["min_value"]:
1075+
min_value = kwargs["min_value"]
1076+
max_value = kwargs["max_value"]
1077+
shrink_towards = kwargs["shrink_towards"]
1078+
if min_value is not None and value < min_value:
10751079
return False
1076-
if kwargs["max_value"] is not None and value > kwargs["max_value"]:
1080+
if max_value is not None and value > max_value:
1081+
return False
1082+
1083+
if (max_value is None or min_value is None) and (
1084+
value - shrink_towards
1085+
).bit_length() >= 128:
10771086
return False
10781087

10791088
return True
@@ -1144,14 +1153,22 @@ class ConjectureResult:
11441153
status: Status = attr.ib()
11451154
interesting_origin: Optional[InterestingOrigin] = attr.ib()
11461155
buffer: bytes = attr.ib()
1147-
blocks: Blocks = attr.ib()
1156+
# some ConjectureDatas pass through the ir and some pass through buffers.
1157+
# the ir does not drive its result through the buffer, which means blocks/examples
1158+
# may differ (I think for forced values?) even when the buffer is the same.
1159+
# I don't *think* anything was relying on anything but .buffer for result equality,
1160+
# though that assumption may be leaning on flakiness detection invariants.
1161+
#
1162+
# If we consider blocks or examples in equality checks, multiple semantically equal
1163+
# results get stored in e.g. the pareto front.
1164+
blocks: Blocks = attr.ib(eq=False)
11481165
output: str = attr.ib()
11491166
extra_information: Optional[ExtraInformation] = attr.ib()
11501167
has_discards: bool = attr.ib()
11511168
target_observations: TargetObservations = attr.ib()
11521169
tags: FrozenSet[StructuralCoverageTag] = attr.ib()
11531170
forced_indices: FrozenSet[int] = attr.ib(repr=False)
1154-
examples: Examples = attr.ib(repr=False)
1171+
examples: Examples = attr.ib(repr=False, eq=False)
11551172
arg_slices: Set[Tuple[int, int]] = attr.ib(repr=False)
11561173
slice_comments: Dict[Tuple[int, int], str] = attr.ib(repr=False)
11571174
invalid_at: Optional[InvalidAt] = attr.ib(repr=False)

hypothesis-python/src/hypothesis/internal/conjecture/engine.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def _cache_key_ir(
342342
for node in nodes + extension
343343
)
344344

345-
def _cache(self, data: Union[ConjectureData, ConjectureResult]) -> None:
345+
def _cache(self, data: ConjectureData) -> None:
346346
result = data.as_result()
347347
# when we shrink, we try out of bounds things, which can lead to the same
348348
# data.buffer having multiple outcomes. eg data.buffer=b'' is Status.OVERRUN
@@ -357,8 +357,25 @@ def _cache(self, data: Union[ConjectureData, ConjectureResult]) -> None:
357357
# write to the buffer cache here as we move more things to the ir cache.
358358
if data.invalid_at is None:
359359
self.__data_cache[data.buffer] = result
360-
key = self._cache_key_ir(data=data)
361-
self.__data_cache_ir[key] = result
360+
361+
# interesting buffer-based data can mislead the shrinker if we cache them.
362+
#
363+
# @given(st.integers())
364+
# def f(n):
365+
# assert n < 100
366+
#
367+
# may generate two counterexamples, n=101 and n=m > 101, in that order,
368+
# where the buffer corresponding to n is large due to eg failed probes.
369+
# We shrink m and eventually try n=101, but it is cached to a large buffer
370+
# and so the best we can do is n=102, a non-ideal shrink.
371+
#
372+
# We can cache ir-based buffers fine, which always correspond to the
373+
# smallest buffer via forced=. The overhead here is small because almost
374+
# all interesting data are ir-based via the shrinker (and that overhead
375+
# will tend towards zero as we move generation to the ir).
376+
if data.ir_tree_nodes is not None or data.status < Status.INTERESTING:
377+
key = self._cache_key_ir(data=data)
378+
self.__data_cache_ir[key] = result
362379

363380
def cached_test_function_ir(
364381
self, nodes: List[IRNode]
@@ -1218,7 +1235,7 @@ def shrink_interesting_examples(self) -> None:
12181235
self.interesting_examples.values(), key=lambda d: sort_key(d.buffer)
12191236
):
12201237
assert prev_data.status == Status.INTERESTING
1221-
data = self.new_conjecture_data_for_buffer(prev_data.buffer)
1238+
data = self.new_conjecture_data_ir(prev_data.examples.ir_tree_nodes)
12221239
self.test_function(data)
12231240
if data.status != Status.INTERESTING:
12241241
self.exit_with(ExitReason.flaky)

0 commit comments

Comments
 (0)