Skip to content

Commit 075a46b

Browse files
committed
mark invalid on forced misalignment, add ir-specific comparison methods
1 parent 9de1810 commit 075a46b

File tree

2 files changed

+134
-16
lines changed

2 files changed

+134
-16
lines changed

hypothesis-python/src/hypothesis/internal/conjecture/data.py

+64-7
Original file line numberDiff line numberDiff line change
@@ -910,7 +910,7 @@ def draw_boolean(
910910
pass
911911

912912

913-
@attr.s(slots=True, repr=False)
913+
@attr.s(slots=True, repr=False, eq=False)
914914
class IRNode:
915915
ir_type: IRTypeName = attr.ib()
916916
value: IRType = attr.ib()
@@ -928,6 +928,17 @@ def copy(self, *, with_value: IRType) -> "IRNode":
928928
was_forced=self.was_forced,
929929
)
930930

931+
def __eq__(self, other):
    """Return whether ``other`` is an ``IRNode`` equal to this one.

    The ir type and the forced flag are compared with ``==``; the value
    and kwargs comparisons are delegated to ``ir_value_equal`` and
    ``ir_kwargs_equal``. Anything that is not an ``IRNode`` yields
    ``NotImplemented`` so Python can fall back to its default handling.
    """
    if not isinstance(other, IRNode):
        return NotImplemented

    # Checks run (and bail out) in the same order as the fields are listed:
    # ir_type, value, kwargs, was_forced.
    if self.ir_type != other.ir_type:
        return False
    if not ir_value_equal(self.ir_type, self.value, other.value):
        return False
    if not ir_kwargs_equal(self.ir_type, self.kwargs, other.kwargs):
        return False
    return self.was_forced == other.was_forced
941+
931942
def __repr__(self):
932943
# repr to avoid "BytesWarning: str() on a bytes instance" for bytes nodes
933944
forced_marker = " [forced]" if self.was_forced else ""
@@ -967,6 +978,24 @@ def ir_value_permitted(value, ir_type, kwargs):
967978
raise NotImplementedError(f"unhandled type {type(value)} of ir value {value}")
968979

969980

981+
def ir_value_equal(ir_type, v1, v2):
    """Return whether two ir values of type ``ir_type`` are equal.

    Values of the ``"float"`` ir type are compared through ``float_to_int``
    instead of ``==``; every other ir type uses plain ``==``.
    """
    if ir_type == "float":
        # NOTE(review): presumably this compares the floats' integer
        # representations so that nan equals nan and 0.0 differs from -0.0;
        # confirm against float_to_int.
        return float_to_int(v1) == float_to_int(v2)
    return v1 == v2
985+
986+
987+
def ir_kwargs_equal(ir_type, kwargs1, kwargs2):
    """Return whether two kwargs dicts for ``ir_type`` are equal.

    For the ``"float"`` ir type the ``min_value`` and ``max_value`` entries
    are compared through ``float_to_int``, while ``allow_nan`` and
    ``smallest_nonzero_magnitude`` are compared with ``==``. Every other ir
    type compares the two dicts directly with ``==``.
    """
    if ir_type == "float":
        # bounds first (min, then max), then the remaining plain fields
        bounds_match = all(
            float_to_int(kwargs1[bound]) == float_to_int(kwargs2[bound])
            for bound in ("min_value", "max_value")
        )
        return (
            bounds_match
            and kwargs1["allow_nan"] == kwargs2["allow_nan"]
            and kwargs1["smallest_nonzero_magnitude"]
            == kwargs2["smallest_nonzero_magnitude"]
        )
    return kwargs1 == kwargs2
997+
998+
970999
@dataclass_transform()
9711000
@attr.s(slots=True)
9721001
class ConjectureResult:
@@ -1880,7 +1909,7 @@ def draw_integer(
18801909
)
18811910

18821911
if self.ir_tree_nodes is not None and observe:
1883-
node = self._pop_ir_tree_node("integer", kwargs)
1912+
node = self._pop_ir_tree_node("integer", kwargs, forced=forced)
18841913
assert isinstance(node.value, int)
18851914
forced = node.value
18861915
fake_forced = not node.was_forced
@@ -1936,7 +1965,7 @@ def draw_float(
19361965
)
19371966

19381967
if self.ir_tree_nodes is not None and observe:
1939-
node = self._pop_ir_tree_node("float", kwargs)
1968+
node = self._pop_ir_tree_node("float", kwargs, forced=forced)
19401969
assert isinstance(node.value, float)
19411970
forced = node.value
19421971
fake_forced = not node.was_forced
@@ -1977,7 +2006,7 @@ def draw_string(
19772006
},
19782007
)
19792008
if self.ir_tree_nodes is not None and observe:
1980-
node = self._pop_ir_tree_node("string", kwargs)
2009+
node = self._pop_ir_tree_node("string", kwargs, forced=forced)
19812010
assert isinstance(node.value, str)
19822011
forced = node.value
19832012
fake_forced = not node.was_forced
@@ -2012,7 +2041,7 @@ def draw_bytes(
20122041
kwargs: BytesKWargs = self._pooled_kwargs("bytes", {"size": size})
20132042

20142043
if self.ir_tree_nodes is not None and observe:
2015-
node = self._pop_ir_tree_node("bytes", kwargs)
2044+
node = self._pop_ir_tree_node("bytes", kwargs, forced=forced)
20162045
assert isinstance(node.value, bytes)
20172046
forced = node.value
20182047
fake_forced = not node.was_forced
@@ -2053,7 +2082,7 @@ def draw_boolean(
20532082
kwargs: BooleanKWargs = self._pooled_kwargs("boolean", {"p": p})
20542083

20552084
if self.ir_tree_nodes is not None and observe:
2056-
node = self._pop_ir_tree_node("boolean", kwargs)
2085+
node = self._pop_ir_tree_node("boolean", kwargs, forced=forced)
20572086
assert isinstance(node.value, bool)
20582087
forced = node.value
20592088
fake_forced = not node.was_forced
@@ -2093,7 +2122,9 @@ def _pooled_kwargs(self, ir_type, kwargs):
20932122
POOLED_KWARGS_CACHE[key] = kwargs
20942123
return kwargs
20952124

2096-
def _pop_ir_tree_node(self, ir_type: IRTypeName, kwargs: IRKWargsType) -> IRNode:
2125+
def _pop_ir_tree_node(
2126+
self, ir_type: IRTypeName, kwargs: IRKWargsType, *, forced: Optional[IRType]
2127+
) -> IRNode:
20972128
assert self.ir_tree_nodes is not None
20982129

20992130
if self.ir_tree_nodes == []:
@@ -2120,6 +2151,32 @@ def _pop_ir_tree_node(self, ir_type: IRTypeName, kwargs: IRKWargsType) -> IRNode
21202151
if not ir_value_permitted(node.value, node.ir_type, kwargs):
21212152
self.mark_invalid() # pragma: no cover # FIXME @tybug
21222153

2154+
if forced is not None:
2155+
# if we expected a forced node but are instead returning a non-forced
2156+
# node, something has gone terribly wrong. If we allowed this combination,
2157+
# we risk violating core invariants that rely on forced draws being,
2158+
# well, forced to a particular value.
2159+
#
2160+
# In particular, this can manifest while shrinking. Consider the tree
2161+
# [boolean True [forced] {"p": 0.5}]
2162+
# [boolean False {"p": 0.5}]
2163+
#
2164+
# and the shrinker tries to reorder these to
2165+
# [boolean False {"p": 0.5}]
2166+
# [boolean True [forced] {"p": 0.5}].
2167+
#
2168+
# However, maybe we got lucky and the non-forced node is returning
2169+
# the same value that was expected from the forced draw. We lucked
2170+
# into an aligned tree in this case and can let it slide.
2171+
if not node.was_forced and not ir_value_equal(ir_type, forced, node.value):
2172+
self.mark_invalid()
2173+
2174+
# similarly, if we expected a forced node with a certain value, and
2175+
# are returning a forced node with a different value, this is an
2176+
# equally bad misalignment.
2177+
if node.was_forced and not ir_value_equal(ir_type, forced, node.value):
2178+
self.mark_invalid()
2179+
21232180
return node
21242181

21252182
def as_result(self) -> Union[ConjectureResult, _Overrun]:

hypothesis-python/tests/conjecture/test_ir.py

+70-9
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,17 @@
99
# obtain one at https://mozilla.org/MPL/2.0/.
1010

1111
import math
12+
from copy import deepcopy
1213

1314
import pytest
1415

15-
from hypothesis import assume, example, given, strategies as st
16+
from hypothesis import HealthCheck, assume, example, given, settings, strategies as st
1617
from hypothesis.errors import StopTest
1718
from hypothesis.internal.conjecture.data import (
1819
ConjectureData,
1920
IRNode,
2021
Status,
22+
ir_value_equal,
2123
ir_value_permitted,
2224
)
2325
from hypothesis.internal.conjecture.datatree import (
@@ -328,13 +330,12 @@ def test_ir_nodes(random):
328330

329331

330332
@st.composite
def ir_nodes(draw, *, was_forced=None):
    """Strategy building an ``IRNode`` from a drawn ir type and kwargs.

    When ``was_forced`` is ``None`` the forced flag is itself drawn from
    ``st.booleans()``; otherwise the supplied value is used as-is.
    """
    ir_type, kwargs = draw(ir_types_and_kwargs())
    value = draw_value(ir_type, kwargs)
    # only consume a boolean draw when the caller did not pin the flag
    if was_forced is None:
        was_forced = draw(st.booleans())

    return IRNode(
        ir_type=ir_type,
        value=value,
        kwargs=kwargs,
        was_forced=was_forced,
    )
338339

339340

340341
@given(ir_nodes())
@@ -343,11 +344,17 @@ def test_copy_ir_node(node):
343344

344345
assume(not node.was_forced)
345346
new_value = draw_value(node.ir_type, node.kwargs)
346-
# if we drew the same value as before, the node should still be equal (unless nan)
347-
assume(
348-
node.ir_type != "float" or not (math.isnan(new_value) or math.isnan(node.value))
347+
# if we drew the same value as before, the node should still be equal
348+
assert (node.copy(with_value=new_value) == node) is (
349+
ir_value_equal(node.ir_type, new_value, node.value)
349350
)
350-
assert (node.copy(with_value=new_value) == node) is (new_value == node.value)
351+
352+
353+
@given(ir_nodes())
def test_ir_node_equality(node):
    """An ir node equals itself and differs from a non-``IRNode`` object."""
    assert node == node
    # comparing against a non-IRNode exercises the NotImplemented return,
    # which is mostly here for coverage.
    not_a_node = 42
    assert node != not_a_node
351358

352359

353360
def test_data_with_empty_ir_tree_is_overrun():
@@ -378,6 +385,60 @@ def test_data_with_misaligned_ir_tree_is_invalid(data):
378385
assert data.status is Status.INVALID
379386

380387

388+
@given(st.data())
def test_data_with_changed_was_forced_is_invalid(data):
    """Forcing a different value onto a non-forced node is invalid.

    ir tree: v1 [was_forced=False]
    drawing:    [forced=v2]
    """
    node = data.draw(ir_nodes(was_forced=False))
    conjecture_data = ConjectureData.for_ir_tree([node])

    draw_func = getattr(conjecture_data, "draw_" + node.ir_type)
    draw_kwargs = deepcopy(node.kwargs)
    forced_value = draw_value(node.ir_type, node.kwargs)
    draw_kwargs["forced"] = forced_value
    # the scenario only makes sense when v2 really differs from v1
    assume(not ir_value_equal(node.ir_type, forced_value, node.value))

    with pytest.raises(StopTest):
        draw_func(**draw_kwargs)

    assert conjecture_data.status is Status.INVALID
405+
406+
407+
@given(st.data())
@settings(suppress_health_check=[HealthCheck.too_slow])
def test_data_with_changed_forced_value_is_invalid(data):
    """Forcing a different value onto an already-forced node is invalid.

    ir tree: v1 [was_forced=True]
    drawing:    [forced=v2]
    """
    node = data.draw(ir_nodes(was_forced=True))
    conjecture_data = ConjectureData.for_ir_tree([node])

    draw_func = getattr(conjecture_data, "draw_" + node.ir_type)
    draw_kwargs = deepcopy(node.kwargs)
    forced_value = draw_value(node.ir_type, node.kwargs)
    draw_kwargs["forced"] = forced_value
    # the scenario only makes sense when v2 really differs from v1
    assume(not ir_value_equal(node.ir_type, forced_value, node.value))

    with pytest.raises(StopTest):
        draw_func(**draw_kwargs)

    assert conjecture_data.status is Status.INVALID
425+
426+
427+
@given(st.data())
def test_data_with_same_forced_value_is_valid(data):
    """Re-forcing a forced node to its own value draws without raising.

    ir tree: v1 [was_forced=True]
    drawing:    [forced=v1]
    """
    node = data.draw(ir_nodes(was_forced=True))
    conjecture_data = ConjectureData.for_ir_tree([node])
    draw_func = getattr(conjecture_data, "draw_" + node.ir_type)

    # same kwargs as the node, plus a forced value identical to the node's
    draw_kwargs = deepcopy(node.kwargs)
    draw_kwargs["forced"] = node.value
    draw_func(**draw_kwargs)
440+
441+
381442
@given(ir_types_and_kwargs())
382443
def test_all_children_are_permitted_values(ir_type_and_kwargs):
383444
(ir_type, kwargs) = ir_type_and_kwargs

0 commit comments

Comments
 (0)