Skip to content

Commit ee7cef4

Browse files
authored
Merge pull request #4086 from tybug/realign-shrinking
'Realign' IR trees via the intermediary buffer
2 parents b38fe7a + c945158 commit ee7cef4

File tree

16 files changed

+141
-385
lines changed

16 files changed

+141
-385
lines changed

hypothesis-python/RELEASE.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
RELEASE_TYPE: patch
2+
3+
This patch improves shrinking in cases involving 'slips' from one strategy to another. Highly composite strategies are the most likely to benefit from this change.
4+
5+
This patch also reduces the range of :class:`python:datetime.datetime` generated by :func:`~hypothesis.extra.django.from_model` in order to avoid https://code.djangoproject.com/ticket/35683.

hypothesis-python/benchmark/conftest.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import inspect
1212
import json
13+
import time
1314
from collections import defaultdict
1415

1516
import pytest
@@ -19,6 +20,7 @@
1920
# be enough: https://github.com/pytest-dev/pytest-xdist/issues/271. need a lockfile
2021
# or equivalent.
2122
shrink_calls = defaultdict(list)
23+
timer = time.process_time
2224

2325

2426
def pytest_collection_modifyitems(config, items):
@@ -51,8 +53,11 @@ def record_shrink_calls(calls):
5153
old_shrink = Shrinker.shrink
5254

5355
def shrink(self, *args, **kwargs):
56+
t = timer()
5457
v = old_shrink(self, *args, **kwargs)
55-
record_shrink_calls(self.engine.call_count - self.initial_calls)
58+
time = timer() - t
59+
calls = self.engine.call_count - self.initial_calls
60+
record_shrink_calls({"calls": calls, "time": time})
5661
return v
5762

5863
monkeypatch.setattr(Shrinker, "shrink", shrink)

hypothesis-python/benchmark/graph.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,17 @@
5151
new_names.append(name)
5252
names = new_names
5353

54-
54+
# either "time" or "calls"
55+
statistic = "time"
5556
# name : average calls
5657
old_values = {}
5758
new_values = {}
5859
for name in names:
5960

6061
# mean across the different minimal() calls in a single test function, then
6162
# median across the n iterations we ran that for to reduce error
62-
old_vals = [statistics.mean(run[name]) for run in old_runs]
63-
new_vals = [statistics.mean(run[name]) for run in new_runs]
63+
old_vals = [statistics.mean(r[statistic] for r in run[name]) for run in old_runs]
64+
new_vals = [statistics.mean(r[statistic] for r in run[name]) for run in new_runs]
6465
old_values[name] = statistics.median(old_vals)
6566
new_values[name] = statistics.median(new_vals)
6667

@@ -70,20 +71,21 @@
7071
old = old_values[name]
7172
new = new_values[name]
7273
diff = old - new
73-
diff_times = (old - new) / old
74+
if old == 0:
75+
diff_times = 0
76+
else:
77+
diff_times = (old - new) / old
7478
if 0 < diff_times < 1:
7579
diff_times = (1 / (1 - diff_times)) - 1
7680
diffs[name] = (diff, diff_times)
7781

78-
print(f"{name} {int(diff)} ({int(old)} -> {int(new)}, {round(diff_times, 1)}✕)")
82+
print(f"{name} {diff} ({old} -> {new}, {round(diff_times, 1)}✕)")
7983

8084
diffs = dict(sorted(diffs.items(), key=lambda kv: kv[1][0]))
8185
diffs_value = [v[0] for v in diffs.values()]
8286
diffs_percentage = [v[1] for v in diffs.values()]
8387

84-
print(
85-
f"mean: {int(statistics.mean(diffs_value))}, median: {int(statistics.median(diffs_value))}"
86-
)
88+
print(f"mean: {statistics.mean(diffs_value)}, median: {statistics.median(diffs_value)}")
8789

8890

8991
# https://stackoverflow.com/a/65824524
@@ -100,15 +102,20 @@ def align_axes(ax1, ax2):
100102
ax1.set_ylim(bottom=ax1_ylims[1] * ax2_yratio)
101103

102104

103-
ax1 = sns.barplot(diffs_value, color="b", alpha=0.7, label="shrink call change")
105+
ax1 = sns.barplot(diffs_value, color="b", alpha=0.7, label="absolute change")
104106
ax2 = plt.twinx()
105-
sns.barplot(diffs_percentage, color="r", alpha=0.7, label=r"n✕ change", ax=ax2)
107+
sns.barplot(diffs_percentage, color="r", alpha=0.7, ax=ax2, label="n✕ change")
106108

107-
ax1.set_title("old shrinks - new shrinks (aka shrinks saved, higher is better)")
109+
ax1.set_title(
110+
"old shrinks - new shrinks (aka shrinks saved, higher is better)"
111+
if statistic == "calls"
112+
else "old time - new time in seconds (aka time saved, higher is better)"
113+
)
108114
ax1.set_xticks([])
109115
align_axes(ax1, ax2)
110-
legend = ax1.legend(labels=["shrink call change", "n✕ change"])
111-
legend.legend_handles[0].set_color("b")
112-
legend.legend_handles[1].set_color("r")
116+
legend1 = ax1.legend(loc="upper left")
117+
legend1.legend_handles[0].set_color("b")
118+
legend2 = ax2.legend(loc="lower right")
119+
legend2.legend_handles[0].set_color("r")
113120

114121
plt.show()

hypothesis-python/src/hypothesis/extra/django/_fields.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import re
1212
import string
13-
from datetime import timedelta
13+
from datetime import datetime, timedelta
1414
from decimal import Decimal
1515
from functools import lru_cache
1616
from typing import Any, Callable, Dict, Type, TypeVar, Union
@@ -115,7 +115,12 @@ def inner(func):
115115
@register_for(df.DateTimeField)
116116
def _for_datetime(field):
117117
if getattr(django.conf.settings, "USE_TZ", False):
118-
return st.datetimes(timezones=timezones())
118+
# avoid https://code.djangoproject.com/ticket/35683
119+
return st.datetimes(
120+
min_value=datetime.min + timedelta(days=1),
121+
max_value=datetime.max - timedelta(days=1),
122+
timezones=timezones(),
123+
)
119124
return st.datetimes()
120125

121126

hypothesis-python/src/hypothesis/internal/conjecture/data.py

Lines changed: 84 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,8 @@ class BooleanKWargs(TypedDict):
127127
IntegerKWargs, FloatKWargs, StringKWargs, BytesKWargs, BooleanKWargs
128128
]
129129
IRTypeName: TypeAlias = Literal["integer", "string", "boolean", "float", "bytes"]
130-
# ir_type, kwargs, forced
131-
InvalidAt: TypeAlias = Tuple[IRTypeName, IRKWargsType, Optional[IRType]]
130+
# index, ir_type, kwargs, forced
131+
MisalignedAt: TypeAlias = Tuple[int, IRTypeName, IRKWargsType, Optional[IRType]]
132132

133133

134134
class ExtraInformation:
@@ -954,9 +954,6 @@ def draw_boolean(
954954
) -> None:
955955
pass
956956

957-
def mark_invalid(self, invalid_at: InvalidAt) -> None:
958-
pass
959-
960957

961958
@attr.s(slots=True, repr=False, eq=False)
962959
class IRNode:
@@ -1169,7 +1166,7 @@ class ConjectureResult:
11691166
examples: Examples = attr.ib(repr=False, eq=False)
11701167
arg_slices: Set[Tuple[int, int]] = attr.ib(repr=False)
11711168
slice_comments: Dict[Tuple[int, int], str] = attr.ib(repr=False)
1172-
invalid_at: Optional[InvalidAt] = attr.ib(repr=False)
1169+
misaligned_at: Optional[MisalignedAt] = attr.ib(repr=False)
11731170

11741171
index: int = attr.ib(init=False)
11751172

@@ -2060,7 +2057,7 @@ def __init__(
20602057
self.extra_information = ExtraInformation()
20612058

20622059
self.ir_tree_nodes = ir_tree_prefix
2063-
self.invalid_at: Optional[InvalidAt] = None
2060+
self.misaligned_at: Optional[MisalignedAt] = None
20642061
self._node_index = 0
20652062
self.start_example(TOP_LABEL)
20662063

@@ -2144,10 +2141,10 @@ def draw_integer(
21442141
)
21452142

21462143
if self.ir_tree_nodes is not None and observe:
2147-
node = self._pop_ir_tree_node("integer", kwargs, forced=forced)
2144+
node_value = self._pop_ir_tree_node("integer", kwargs, forced=forced)
21482145
if forced is None:
2149-
assert isinstance(node.value, int)
2150-
forced = node.value
2146+
assert isinstance(node_value, int)
2147+
forced = node_value
21512148
fake_forced = True
21522149

21532150
value = self.provider.draw_integer(
@@ -2201,10 +2198,10 @@ def draw_float(
22012198
)
22022199

22032200
if self.ir_tree_nodes is not None and observe:
2204-
node = self._pop_ir_tree_node("float", kwargs, forced=forced)
2201+
node_value = self._pop_ir_tree_node("float", kwargs, forced=forced)
22052202
if forced is None:
2206-
assert isinstance(node.value, float)
2207-
forced = node.value
2203+
assert isinstance(node_value, float)
2204+
forced = node_value
22082205
fake_forced = True
22092206

22102207
value = self.provider.draw_float(
@@ -2243,10 +2240,10 @@ def draw_string(
22432240
},
22442241
)
22452242
if self.ir_tree_nodes is not None and observe:
2246-
node = self._pop_ir_tree_node("string", kwargs, forced=forced)
2243+
node_value = self._pop_ir_tree_node("string", kwargs, forced=forced)
22472244
if forced is None:
2248-
assert isinstance(node.value, str)
2249-
forced = node.value
2245+
assert isinstance(node_value, str)
2246+
forced = node_value
22502247
fake_forced = True
22512248

22522249
value = self.provider.draw_string(
@@ -2279,10 +2276,10 @@ def draw_bytes(
22792276
kwargs: BytesKWargs = self._pooled_kwargs("bytes", {"size": size})
22802277

22812278
if self.ir_tree_nodes is not None and observe:
2282-
node = self._pop_ir_tree_node("bytes", kwargs, forced=forced)
2279+
node_value = self._pop_ir_tree_node("bytes", kwargs, forced=forced)
22832280
if forced is None:
2284-
assert isinstance(node.value, bytes)
2285-
forced = node.value
2281+
assert isinstance(node_value, bytes)
2282+
forced = node_value
22862283
fake_forced = True
22872284

22882285
value = self.provider.draw_bytes(
@@ -2320,10 +2317,10 @@ def draw_boolean(
23202317
kwargs: BooleanKWargs = self._pooled_kwargs("boolean", {"p": p})
23212318

23222319
if self.ir_tree_nodes is not None and observe:
2323-
node = self._pop_ir_tree_node("boolean", kwargs, forced=forced)
2320+
node_value = self._pop_ir_tree_node("boolean", kwargs, forced=forced)
23242321
if forced is None:
2325-
assert isinstance(node.value, bool)
2326-
forced = node.value
2322+
assert isinstance(node_value, bool)
2323+
forced = node_value
23272324
fake_forced = True
23282325

23292326
value = self.provider.draw_boolean(
@@ -2367,41 +2364,57 @@ def _pooled_kwargs(self, ir_type, kwargs):
23672364

23682365
def _pop_ir_tree_node(
23692366
self, ir_type: IRTypeName, kwargs: IRKWargsType, *, forced: Optional[IRType]
2370-
) -> IRNode:
2367+
) -> IRType:
2368+
from hypothesis.internal.conjecture.engine import BUFFER_SIZE
2369+
23712370
assert self.ir_tree_nodes is not None
23722371

23732372
if self._node_index == len(self.ir_tree_nodes):
23742373
self.mark_overrun()
23752374

23762375
node = self.ir_tree_nodes[self._node_index]
2377-
# If we're trying to draw a different ir type at the same location, then
2378-
# this ir tree has become badly misaligned. We don't have many good/simple
2379-
# options here for realigning beyond giving up.
2376+
value = node.value
2377+
# If we're trying to:
2378+
# * draw a different ir type at the same location
2379+
# * draw the same ir type with different kwargs
2380+
#
2381+
# then we call this a misalignment, because the choice sequence has
2382+
# slipped from what we expected at some point. An easy misalignment is
2383+
#
2384+
# st.one_of(st.integers(0, 100), st.integers(101, 200))
23802385
#
2381-
# This is more of an issue for ir nodes while shrinking than it was for
2382-
# buffers: misaligned buffers are still usually valid, just interpreted
2383-
# differently. This would be somewhat like drawing a random value for
2384-
# the new ir type here. For what it's worth, misaligned buffers are
2385-
# rather unlikely to be *useful* buffers, so giving up isn't a big downgrade.
2386-
# (in fact, it is possible that giving up early here results in more time
2387-
# for useful shrinks to run).
2388-
if node.ir_type != ir_type:
2389-
invalid_at = (ir_type, kwargs, forced)
2390-
self.invalid_at = invalid_at
2391-
self.observer.mark_invalid(invalid_at)
2392-
self.mark_invalid(f"(internal) want a {ir_type} but have a {node.ir_type}")
2393-
2394-
# if a node has different kwargs (and so is misaligned), but has a value
2395-
# that is allowed by the expected kwargs, then we can coerce this node
2396-
# into an aligned one by using its value. It's unclear how useful this is.
2397-
if not ir_value_permitted(node.value, node.ir_type, kwargs):
2398-
invalid_at = (ir_type, kwargs, forced)
2399-
self.invalid_at = invalid_at
2400-
self.observer.mark_invalid(invalid_at)
2401-
self.mark_invalid(f"(internal) got a {ir_type} but outside the valid range")
2386+
# where the choice sequence [0, 100] has kwargs {min_value: 0, max_value: 100}
2387+
# at position 2, but [0, 101] has kwargs {min_value: 101, max_value: 200} at
2388+
# position 2.
2389+
#
2390+
# When we see a misalignment, we can't offer up the stored node value as-is.
2391+
# We need to make it appropriate for the requested kwargs and ir type.
2392+
# Right now we do that by using bytes as the intermediary to convert between
2393+
# ir types/kwargs. In the future we'll probably use the index into a custom
2394+
# ordering for an (ir_type, kwargs) pair.
2395+
if node.ir_type != ir_type or not ir_value_permitted(
2396+
node.value, node.ir_type, kwargs
2397+
):
2398+
# only track first misalignment for now.
2399+
if self.misaligned_at is None:
2400+
self.misaligned_at = (self._node_index, ir_type, kwargs, forced)
2401+
(_value, buffer) = ir_to_buffer(
2402+
node.ir_type, node.kwargs, forced=node.value
2403+
)
2404+
try:
2405+
value = buffer_to_ir(
2406+
ir_type, kwargs, buffer=buffer + bytes(BUFFER_SIZE - len(buffer))
2407+
)
2408+
except StopTest:
2409+
# must have been an overrun.
2410+
#
2411+
# maybe we should fall back to an arbitrary small value here
2412+
# instead? eg
2413+
# buffer_to_ir(ir_type, kwargs, buffer=bytes(BUFFER_SIZE))
2414+
self.mark_overrun()
24022415

24032416
self._node_index += 1
2404-
return node
2417+
return value
24052418

24062419
def as_result(self) -> Union[ConjectureResult, _Overrun]:
24072420
"""Convert the result of running this test into
@@ -2429,7 +2442,7 @@ def as_result(self) -> Union[ConjectureResult, _Overrun]:
24292442
forced_indices=frozenset(self.forced_indices),
24302443
arg_slices=self.arg_slices,
24312444
slice_comments=self.slice_comments,
2432-
invalid_at=self.invalid_at,
2445+
misaligned_at=self.misaligned_at,
24332446
)
24342447
assert self.__result is not None
24352448
self.blocks.transfer_ownership(self.__result)
@@ -2578,38 +2591,9 @@ def freeze(self) -> None:
25782591
self.stop_example()
25792592

25802593
self.__example_record.freeze()
2581-
25822594
self.frozen = True
2583-
25842595
self.buffer = bytes(self.buffer)
2585-
2586-
# if we were invalid because of a misalignment in the tree, we don't
2587-
# want to tell the DataTree that. Doing so would lead to inconsistent behavior.
2588-
# Given an empty DataTree
2589-
# ┌──────┐
2590-
# │ root │
2591-
# └──────┘
2592-
# and supposing the very first draw is misaligned, concluding here would
2593-
# tell the datatree that the *only* possibility at the root node is Status.INVALID:
2594-
# ┌──────┐
2595-
# │ root │
2596-
# └──┬───┘
2597-
# ┌───────────┴───────────────┐
2598-
# │ Conclusion(Status.INVALID)│
2599-
# └───────────────────────────┘
2600-
# when in fact this is only the case when we try to draw a misaligned node.
2601-
# For instance, suppose we come along in the second test case and try a
2602-
# valid node as the first draw from the root. The DataTree thinks this
2603-
# is flaky (because root must lead to Status.INVALID in the tree) while
2604-
# in fact nothing in the test function has changed and the only change
2605-
# is in the ir tree prefix we are supplying.
2606-
#
2607-
# From the perspective of DataTree, it is safe to not conclude here. This
2608-
# tells the datatree that we don't know what happens after this node - which
2609-
# is true! We are aborting early here because the ir tree became misaligned,
2610-
# which is a semantically different invalidity than an assume or filter failing.
2611-
if self.invalid_at is None:
2612-
self.observer.conclude_test(self.status, self.interesting_origin)
2596+
self.observer.conclude_test(self.status, self.interesting_origin)
26132597

26142598
def choice(
26152599
self,
@@ -2716,3 +2700,24 @@ def bits_to_bytes(n: int) -> int:
27162700
Equivalent to (n + 7) // 8, but slightly faster. This really is
27172701
called enough times that that matters."""
27182702
return (n + 7) >> 3
2703+
2704+
2705+
def ir_to_buffer(ir_type, kwargs, *, forced=None, random=None):
2706+
from hypothesis.internal.conjecture.engine import BUFFER_SIZE
2707+
2708+
if forced is None:
2709+
assert random is not None
2710+
2711+
cd = ConjectureData(
2712+
max_length=BUFFER_SIZE,
2713+
# buffer doesn't matter if forced is passed since we're forcing the sole draw
2714+
prefix=b"" if forced is None else bytes(BUFFER_SIZE),
2715+
random=random,
2716+
)
2717+
value = getattr(cd.provider, f"draw_{ir_type}")(**kwargs, forced=forced)
2718+
return (value, cd.buffer)
2719+
2720+
2721+
def buffer_to_ir(ir_type, kwargs, *, buffer):
2722+
cd = ConjectureData.for_buffer(buffer)
2723+
return getattr(cd.provider, f"draw_{ir_type}")(**kwargs)

0 commit comments

Comments
 (0)