Commit 3dbfae2

Merge pull request #4214 from tybug/automate-shrinking-benchmark
Automate shrinking benchmark more
2 parents 2530e74 + 79fef72 commit 3dbfae2

File tree

8 files changed: +604 −194 lines

hypothesis-python/benchmark/README.md

Lines changed: 8 additions & 10 deletions
@@ -1,14 +1,12 @@
-This directory contains code for benchmarking Hypothesis' shrinking. This was written for [pull/3962](https://github.com/HypothesisWorks/hypothesis/pull/3962) and is a manual process at the moment, though we may eventually integrate it more closely with ci for automated benchmarking.
+This directory contains plotting code for our shrinker benchmarking. The code for collecting the data is in `conftest.py`; this directory handles plotting the results.
 
-To run a benchmark:
-
-* Add the contents of `conftest.py` to the bottom of `hypothesis-python/tests/conftest.py`
-* In `hypothesis-python/tests/common/debug.py`, change `derandomize=True` to `derandomize=False` (if you are running more than one trial)
-* Run the tests: `pytest hypothesis-python/tests/`
-* Note that the benchmarking script does not currently support xdist, so do not use `-n 8` or similar.
+The plotting script (though not benchmark data collection) requires additional dependencies: `pip install scipy vl-convert-python`.
 
-When pytest finishes the output will contain a dictionary of the benchmarking results. Add that as a new entry in `data.json`. Repeat for however many trials you want; n=5 seems reasonable.
+To run a benchmark:
 
-Also repeat for both your baseline ("old") and your comparison ("new") code.
+- `pytest tests/ --hypothesis-benchmark-shrinks new --hypothesis-benchmark-output data.json` (starting on the newer version)
+- `pytest tests/ --hypothesis-benchmark-shrinks old --hypothesis-benchmark-output data.json` (after switching to the old version)
+- Use the same `data.json` path; the benchmark will append to it. You can add `-k ...` to both commands to run a subset of the benchmark.
+- `python benchmark/graph.py data.json shrinking.png`
 
-Then run `python graph.py` to generate a graph comparing the old and new results.
+This hooks any `minimal()` calls and reports the number of shrinks. The default (and currently unchangeable) number of iterations is 5 per test.
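
A note on the data file: this diff does not document the layout of `data.json`, but the new `graph.py` further down reads `data["old"]["calls"]` and `data["new"]["calls"]` as mappings from pytest node id to a list of per-iteration shrink-call counts. A minimal sketch of that inferred shape (hypothetical node id, illustrative numbers only):

# Shape inferred from graph.py's _process_benchmark_data; not spelled out in this diff.
# Each list holds one entry per benchmark iteration (5 by default), counting shrink calls.
example_data = {
    "old": {"calls": {"tests/test_foo.py::test_shrink_list": [31, 28, 30, 29, 33]}},
    "new": {"calls": {"tests/test_foo.py::test_shrink_list": [12, 11, 13, 12, 12]}},
}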

hypothesis-python/benchmark/conftest.py

Lines changed: 0 additions & 71 deletions
This file was deleted.

hypothesis-python/benchmark/data.json

Lines changed: 0 additions & 4 deletions
This file was deleted.

hypothesis-python/benchmark/graph.py

Lines changed: 111 additions & 107 deletions
@@ -9,113 +9,117 @@
 # obtain one at https://mozilla.org/MPL/2.0/.
 
 import json
+import math
 import statistics
 from pathlib import Path
 
-import matplotlib.pyplot as plt
-import seaborn as sns
-
-data_path = Path(__file__).parent / "data.json"
-with open(data_path) as f:
-    data = json.loads(f.read())
-
-old_runs = data["old"]
-new_runs = data["new"]
-all_runs = old_runs + new_runs
-
-# every run should involve the same functions
-names = set()
-for run in all_runs:
-    names.add(frozenset(run.keys()))
-
-intersection = frozenset.intersection(*names)
-diff = frozenset.union(*[intersection.symmetric_difference(n) for n in names])
-
-print(f"skipping these tests which were not present in all runs: {', '.join(diff)}")
-names = list(intersection)
-
-# the similar invariant for number of minimal calls per run is not true: functions
-# may make a variable number of minimal() calls.
-# it would be nice to compare identically just the ones which don't vary, to get
-# a very fine grained comparison instead of averaging.
-# sizes = []
-# for run in all_runs:
-#     sizes.append(tuple(len(value) for value in run.values()))
-# assert len(set(sizes)) == 1
-
-new_names = []
-for name in names:
-    if all(all(x == 0 for x in run[name]) for run in all_runs):
-        print(f"no shrinks for {name}, skipping")
-        continue
-    new_names.append(name)
-names = new_names
-
-# either "time" or "calls"
-statistic = "time"
-# name : average calls
-old_values = {}
-new_values = {}
-for name in names:
-
-    # mean across the different minimal() calls in a single test function, then
-    # median across the n iterations we ran that for to reduce error
-    old_vals = [statistics.mean(r[statistic] for r in run[name]) for run in old_runs]
-    new_vals = [statistics.mean(r[statistic] for r in run[name]) for run in new_runs]
-    old_values[name] = statistics.median(old_vals)
-    new_values[name] = statistics.median(new_vals)
-
-# name : (absolute difference, times difference)
-diffs = {}
-for name in names:
-    old = old_values[name]
-    new = new_values[name]
-    diff = old - new
-    if old == 0:
-        diff_times = 0
-    else:
-        diff_times = (old - new) / old
-        if 0 < diff_times < 1:
-            diff_times = (1 / (1 - diff_times)) - 1
-    diffs[name] = (diff, diff_times)
-
-    print(f"{name} {diff} ({old} -> {new}, {round(diff_times, 1)}✕)")
-
-diffs = dict(sorted(diffs.items(), key=lambda kv: kv[1][0]))
-diffs_value = [v[0] for v in diffs.values()]
-diffs_percentage = [v[1] for v in diffs.values()]
-
-print(f"mean: {statistics.mean(diffs_value)}, median: {statistics.median(diffs_value)}")
-
-
-# https://stackoverflow.com/a/65824524
-def align_axes(ax1, ax2):
-    ax1_ylims = ax1.axes.get_ylim()
-    ax1_yratio = ax1_ylims[0] / ax1_ylims[1]
-
-    ax2_ylims = ax2.axes.get_ylim()
-    ax2_yratio = ax2_ylims[0] / ax2_ylims[1]
-
-    if ax1_yratio < ax2_yratio:
-        ax2.set_ylim(bottom=ax2_ylims[1] * ax1_yratio)
-    else:
-        ax1.set_ylim(bottom=ax1_ylims[1] * ax2_yratio)
-
-
-ax1 = sns.barplot(diffs_value, color="b", alpha=0.7, label="absolute change")
-ax2 = plt.twinx()
-sns.barplot(diffs_percentage, color="r", alpha=0.7, ax=ax2, label="n✕ change")
-
-ax1.set_title(
-    "old shrinks - new shrinks (aka shrinks saved, higher is better)"
-    if statistic == "calls"
-    else "old time - new time in seconds (aka time saved, higher is better)"
-)
-ax1.set_xticks([])
-align_axes(ax1, ax2)
-legend1 = ax1.legend(loc="upper left")
-legend1.legend_handles[0].set_color("b")
-legend2 = ax2.legend(loc="lower right")
-legend2.legend_handles[0].set_color("r")
-
-plt.show()
+import click
+
+
+def plot_vega(vega_spec, data, *, to, parameters=None):
+    import vl_convert
+
+    parameters = parameters or {}
+
+    spec = json.loads(vega_spec.read_text())
+    spec["data"].insert(0, {"name": "source", "values": data})
+    if "signals" not in spec:
+        spec["signals"] = []
+
+    for key, value in parameters.items():
+        spec["signals"].append({"name": key, "value": value})
+
+    with open(to, "wb") as f:
+        # default ppi is 72, which is somewhat blurry.
+        f.write(vl_convert.vega_to_png(spec, ppi=200))
+
+
+def _mean_difference_ci(n1, n2, *, confidence):
+    from scipy import stats
+
+    var1 = statistics.variance(n1)
+    var2 = statistics.variance(n2)
+    df = len(n1) + len(n2) - 2
+    # this assumes equal variances between the populations of n1 and n2. This
+    # is not necessarily true (new might be more consistent than old), but it's
+    # good enough.
+    pooled_std = math.sqrt(((len(n1) - 1) * var1 + (len(n2) - 1) * var2) / df)
+    se = pooled_std * math.sqrt(1 / len(n1) + 1 / len(n2))
+    t_crit = stats.t.ppf((1 + confidence) / 2, df)
+    return t_crit * se
+
+
+def _process_benchmark_data(data):
+    assert set(data) == {"old", "new"}
+    old_calls = data["old"]["calls"]
+    new_calls = data["new"]["calls"]
+    assert set(old_calls) == set(new_calls), set(old_calls).symmetric_difference(
+        set(new_calls)
+    )
+
+    graph_data = []
+
+    def _diff_times(old, new):
+        if old == 0 and new == 0:
+            return 0
+        if old == 0:
+            # there aren't any great options here, but 0 is more reasonable than inf.
+            return 0
+        v = (old - new) / old
+        if 0 < v < 1:
+            v = (1 / (1 - v)) - 1
+        return v
+
+    sums = {"old": 0, "new": 0}
+    for node_id in old_calls:
+        old = old_calls[node_id]
+        new = new_calls[node_id]
+        if set(old) | set(new) == {0} or len(old) != len(new):
+            print(f"skipping {node_id}")
+            continue
+
+        sums["old"] += statistics.mean(old)
+        sums["new"] += statistics.mean(new)
+        diffs = [n_old - n_new for n_old, n_new in zip(old, new)]
+        diffs_times = [_diff_times(n_old, n_new) for n_old, n_new in zip(old, new)]
+        ci_shrink = (
+            _mean_difference_ci(old, new, confidence=0.95) if len(old) > 1 else 0
+        )
+
+        graph_data.append(
+            {
+                "node_id": node_id,
+                "absolute": statistics.mean(diffs),
+                "absolute_ci_lower": ci_shrink,
+                "absolute_ci_upper": ci_shrink,
+                "nx": statistics.mean(diffs_times),
+                "nx_ci_lower": 0,
+                "nx_ci_upper": 0,
+            }
+        )
+
+    graph_data = sorted(graph_data, key=lambda d: d["absolute"])
+    return graph_data, sums
+
+
+@click.command()
+@click.argument("data", type=click.Path(exists=True, path_type=Path))
+@click.argument("out", type=click.Path(path_type=Path))
+def plot(data, out):
+    data = json.loads(data.read_text())
+    data, sums = _process_benchmark_data(data)
+    plot_vega(
+        Path(__file__).parent / "spec.json",
+        data=data,
+        to=out,
+        parameters={
+            "title": "Shrinking benchmark (calls)",
+            "sum_old": sums["old"],
+            "sum_new": sums["new"],
+            "absolute_axis_title": ("shrink call change (old - new, larger is good)"),
+        },
+    )
+
+
+if __name__ == "__main__":
+    plot()
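
One subtlety worth calling out: the per-test "nx" statistic in the new code is not a raw percentage. For improvements (`new < old`), `_diff_times` remaps the fractional saving `(old - new) / old` to `old / new - 1`, so halving the shrink calls reads as 1 and a threefold reduction reads as 2, while regressions stay as a negative fraction. A small illustrative check of that behaviour (a standalone restatement of the same logic, not part of the commit):

import math


def diff_times(old, new):
    # same remapping as graph.py's _diff_times (ignoring its old == new == 0 guard)
    if old == 0:
        return 0
    v = (old - new) / old
    if 0 < v < 1:
        v = (1 / (1 - v)) - 1
    return v


assert math.isclose(diff_times(30, 15), 1.0)  # halving the calls -> 1x
assert math.isclose(diff_times(30, 10), 2.0)  # a threefold reduction -> 2x
assert math.isclose(diff_times(30, 40), -1 / 3)  # a regression stays a negative fraction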

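Likewise, `_mean_difference_ci` is the textbook pooled-variance two-sample t interval. In the code's notation, with iteration counts $n_1$, $n_2$ and sample variances $s_1^2$, $s_2^2$, it returns the half-width

$$t_{(1+c)/2,\;\nu}\; s_p \sqrt{\tfrac{1}{n_1} + \tfrac{1}{n_2}}, \qquad s_p = \sqrt{\frac{(n_1 - 1)\,s_1^2 + (n_2 - 1)\,s_2^2}{\nu}}, \qquad \nu = n_1 + n_2 - 2,$$

which is plotted as a symmetric error bar (`absolute_ci_lower` / `absolute_ci_upper`) around the mean per-test difference; the code's own comment notes the equal-variance assumption is a deliberate simplification.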