Skip to content

Commit 2f37b69

Browse files
authored
Merge pull request #4323 from tybug/xdist-shrinking-benchmark
Support xdist in shrinking benchmark (via the filesystem)
2 parents 5827513 + 2fe0538 commit 2f37b69

File tree

3 files changed

+53
-20
lines changed

3 files changed

+53
-20
lines changed

hypothesis-python/benchmark/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ The plotting script (but not collecting benchmark data) requires additional depe
44

55
To run a benchmark:
66

7-
- `pytest tests/ --hypothesis-benchmark-shrinks new --hypothesis-benchmark-output data.json` (starting on the newer version)
8-
- `pytest tests/ --hypothesis-benchmark-shrinks old --hypothesis-benchmark-output data.json` (after switching to the old version)
7+
- `pytest tests/ -n auto --hypothesis-benchmark-shrinks new --hypothesis-benchmark-output data.json` (starting on the newer version)
8+
- `pytest tests/ -n auto --hypothesis-benchmark-shrinks old --hypothesis-benchmark-output data.json` (after switching to the old version)
99
- Use the same `data.json` path, the benchmark will append data. You can append `-k ...` for both commands to subset the benchmark.
1010
- `python benchmark/graph.py data.json shrinking.png`
1111

hypothesis-python/benchmark/graph.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,11 @@ def _diff_times(old, new):
7474
for node_id in old_calls:
7575
old = old_calls[node_id]
7676
new = new_calls[node_id]
77-
if set(old) | set(new) == {0} or len(old) != len(new):
77+
if (
78+
set(old) | set(new) == {0}
79+
or len(old) != len(new)
80+
or len(old) == len(new) == 0
81+
):
7882
print(f"skipping {node_id}")
7983
continue
8084

hypothesis-python/tests/conftest.py

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
import gc
1212
import inspect
1313
import json
14+
import os
1415
import random
1516
import sys
1617
import time as time_module
17-
from collections import defaultdict
1818
from functools import wraps
1919
from pathlib import Path
2020

@@ -61,9 +61,6 @@ def pytest_configure(config):
6161
# be enough: https://github.com/pytest-dev/pytest-xdist/issues/271.
6262
# Need a lockfile or equivalent.
6363

64-
assert not hasattr(
65-
config, "workerinput"
66-
), "--hypothesis-benchmark-shrinks does not currently support xdist. Run without -n"
6764
assert config.getoption(
6865
"--hypothesis-benchmark-output"
6966
), "must specify shrinking output file"
@@ -222,27 +219,35 @@ def pytest_runtest_call(item):
222219
)
223220

224221

225-
shrink_calls = defaultdict(list)
226-
shrink_time = defaultdict(list)
227222
timer = time_module.process_time
228223

229224

230-
def _benchmark_shrinks(item):
225+
def _worker_path(session: pytest.Session) -> Path:
226+
return (
227+
Path(session.config.getoption("--hypothesis-benchmark-output")).parent
228+
# https://pytest-xdist.readthedocs.io/en/stable/how-to.html#envvar-PYTEST_XDIST_WORKER
229+
/ f"shrinking_results_{os.environ['PYTEST_XDIST_WORKER']}.json"
230+
)
231+
232+
233+
def _benchmark_shrinks(item: pytest.Function) -> None:
231234
from hypothesis.internal.conjecture.shrinker import Shrinker
232235

233236
# this isn't perfect, but it is cheap!
234237
if "minimal(" not in inspect.getsource(item.function):
235238
pytest.skip("(probably) does not call minimal()")
236239

237240
actual_shrink = Shrinker.shrink
241+
shrink_calls = []
242+
shrink_time = []
238243

239244
def shrink(self, *args, **kwargs):
245+
nonlocal shrink_calls
246+
nonlocal shrink_time
240247
start_t = timer()
241248
result = actual_shrink(self, *args, **kwargs)
242-
# remove leading hypothesis-python/tests/...
243-
nodeid = item.nodeid.rsplit("/", 1)[1]
244-
shrink_calls[nodeid].append(self.engine.call_count - self.initial_calls)
245-
shrink_time[nodeid].append(timer() - start_t)
249+
shrink_calls.append(self.engine.call_count - self.initial_calls)
250+
shrink_time.append(timer() - start_t)
246251
return result
247252

248253
monkeypatch = MonkeyPatch()
@@ -256,15 +261,39 @@ def shrink(self, *args, **kwargs):
256261

257262
monkeypatch.undo()
258263

264+
# remove leading hypothesis-python/tests/...
265+
nodeid = item.nodeid.rsplit("/", 1)[1]
266+
267+
results_p = _worker_path(item.session)
268+
if not results_p.exists():
269+
results_p.write_text(json.dumps({"calls": {}, "time": {}}))
270+
271+
data = json.loads(results_p.read_text())
272+
data["calls"][nodeid] = shrink_calls
273+
data["time"][nodeid] = shrink_time
274+
results_p.write_text(json.dumps(data))
275+
259276

260277
def pytest_sessionfinish(session, exitstatus):
261278
if not (mode := session.config.getoption("--hypothesis-benchmark-shrinks")):
262279
return
263-
p = Path(session.config.getoption("--hypothesis-benchmark-output"))
264-
results = {mode: {"calls": shrink_calls, "time": shrink_time}}
265-
if not p.exists():
266-
p.write_text(json.dumps(results))
280+
# only run on the controller process, not the workers
281+
if hasattr(session.config, "workerinput"):
282+
return
283+
284+
results = {"calls": {}, "time": {}}
285+
output_p = Path(session.config.getoption("--hypothesis-benchmark-output"))
286+
for p in output_p.parent.iterdir():
287+
if p.name.startswith("shrinking_results_"):
288+
worker_results = json.loads(p.read_text())
289+
results["calls"] |= worker_results["calls"]
290+
results["time"] |= worker_results["time"]
291+
p.unlink()
292+
293+
results = {mode: results}
294+
if not output_p.exists():
295+
output_p.write_text(json.dumps(results))
267296
else:
268-
data = json.loads(p.read_text())
297+
data = json.loads(output_p.read_text())
269298
data[mode] = results[mode]
270-
p.write_text(json.dumps(data))
299+
output_p.write_text(json.dumps(data))

0 commit comments

Comments
 (0)