Skip to content

Commit fa0d507

Browse files
committed
Improve patches for np.array
1 parent 18cd732 commit fa0d507

File tree

4 files changed

+89
-12
lines changed

4 files changed

+89
-12
lines changed

hypothesis-python/scripts/other-tests.sh

+6-5
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,6 @@ pip uninstall -y lark
5151
if [ "$(python -c $'import platform, sys; print(sys.version_info.releaselevel == \'final\' and platform.python_implementation() not in ("PyPy", "GraalVM"))')" = "True" ] ; then
5252
pip install ".[codemods,cli]"
5353
$PYTEST tests/codemods/
54-
pip install "$(grep -E 'black(==| @)' ../requirements/coverage.txt)"
55-
if [ "$(python -c 'import sys; print(sys.version_info[:2] >= (3, 9))')" = "True" ] ; then
56-
$PYTEST tests/patching/
57-
fi
58-
pip uninstall -y libcst
5954

6055
if [ "$(python -c 'import sys; print(sys.version_info[:2] == (3, 8))')" = "True" ] ; then
6156
# Per NEP-29, this is the last version to support Python 3.8
@@ -64,6 +59,12 @@ if [ "$(python -c $'import platform, sys; print(sys.version_info.releaselevel ==
6459
pip install "$(grep 'numpy==' ../requirements/coverage.txt)"
6560
fi
6661

62+
pip install "$(grep -E 'black(==| @)' ../requirements/coverage.txt)"
63+
if [ "$(python -c 'import sys; print(sys.version_info[:2] >= (3, 9))')" = "True" ] ; then
64+
$PYTEST tests/patching/
65+
fi
66+
pip uninstall -y libcst
67+
6768
$PYTEST tests/ghostwriter/
6869
pip uninstall -y black
6970

hypothesis-python/src/hypothesis/extra/_patching.py

+52-3
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@
1818
during fuzzing.
1919
"""
2020

21+
import ast
2122
import difflib
2223
import hashlib
2324
import inspect
2425
import re
2526
import sys
27+
import types
2628
from ast import literal_eval
2729
from contextlib import suppress
2830
from datetime import date, datetime, timedelta, timezone
@@ -31,6 +33,7 @@
3133
import libcst as cst
3234
from libcst import matchers as m
3335
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
36+
from libcst.metadata import ExpressionContext, ExpressionContextProvider
3437

3538
from hypothesis.configuration import storage_directory
3639
from hypothesis.version import __version__
@@ -148,18 +151,64 @@ def get_patch_for(func, failing_examples, *, strip_via=()):
148151
except Exception:
149152
return None
150153

154+
modules_in_test_scope = sorted(
155+
(
156+
(k, v)
157+
for (k, v) in module.__dict__.items()
158+
if isinstance(v, types.ModuleType)
159+
),
160+
key=lambda kv: len(kv[1].__name__),
161+
)
162+
151163
# The printed examples might include object reprs which are invalid syntax,
152164
# so we parse here and skip over those. If _none_ are valid, there's no patch.
153165
call_nodes = []
154166
for ex, via in set(failing_examples):
155167
with suppress(Exception):
156-
node = cst.parse_expression(ex)
157-
assert isinstance(node, cst.Call), node
168+
node = cst.parse_module(ex)
169+
the_call = node.body[0].body[0].value
170+
assert isinstance(the_call, cst.Call), the_call
158171
# Check for st.data(), which doesn't support explicit examples
159172
data = m.Arg(m.Call(m.Name("data"), args=[m.Arg(m.Ellipsis())]))
160-
if m.matches(node, m.Call(args=[m.ZeroOrMore(), data, m.ZeroOrMore()])):
173+
if m.matches(the_call, m.Call(args=[m.ZeroOrMore(), data, m.ZeroOrMore()])):
161174
return None
175+
176+
# Many reprs use the unqualified name of the type, e.g. np.array()
177+
# -> array([...]), so here we find undefined names and look them up
178+
# on each module which was in the test's global scope.
179+
names = {}
180+
for anode in ast.walk(ast.parse(ex, "eval")):
181+
if (
182+
isinstance(anode, ast.Name)
183+
and isinstance(anode.ctx, ast.Load)
184+
and anode.id not in names
185+
and anode.id not in module.__dict__
186+
):
187+
for k, v in modules_in_test_scope:
188+
if anode.id in v.__dict__:
189+
names[anode.id] = cst.parse_expression(f"{k}.{anode.id}")
190+
break
191+
192+
# LibCST doesn't track Load()/Store() state of names by default, so we have
193+
# to do a bit of a dance here, *and* explicitly handle keyword arguments
194+
# which are treated as Load() context - but even if that's fixed later
195+
# we'll still want to support older versions.
196+
with suppress(Exception):
197+
wrapper = cst.metadata.MetadataWrapper(node)
198+
kwarg_names = {
199+
a.keyword for a in m.findall(wrapper, m.Arg(keyword=m.Name()))
200+
}
201+
node = m.replace(
202+
wrapper,
203+
m.Name(value=m.MatchIfTrue(names.__contains__))
204+
& m.MatchMetadata(ExpressionContextProvider, ExpressionContext.LOAD)
205+
& m.MatchIfTrue(lambda n, k=kwarg_names: n not in k),
206+
replacement=lambda node, _, ns=names: ns[node.value],
207+
)
208+
node = node.body[0].body[0].value
209+
assert isinstance(node, cst.Call), node
162210
call_nodes.append((node, via))
211+
163212
if not call_nodes:
164213
return None
165214

hypothesis-python/tests/patching/callables.py

+8
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212

1313
from pathlib import Path
1414

15+
import numpy as np
16+
1517
from hypothesis import example, given, strategies as st
18+
from hypothesis.extra import numpy as npst
1619

1720
WHERE = Path(__file__).relative_to(Path.cwd())
1821

@@ -36,4 +39,9 @@ def covered(x):
3639
"""A test function with a removable explicit example."""
3740

3841

42+
@given(npst.arrays(np.int8, 1))
43+
def undef_name(array):
44+
assert sum(array) < 100
45+
46+
3947
# TODO: test function for insertion-order logic, once I get that set up.

hypothesis-python/tests/patching/test_patching.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
)
2424
from hypothesis.internal.compat import WINDOWS
2525

26-
from .callables import WHERE, Cases, covered, fn
26+
from .callables import WHERE, Cases, covered, fn, undef_name
2727
from .toplevel import WHERE_TOP, fn_top
2828

2929
SIMPLE = (
@@ -52,6 +52,11 @@
5252
+ "\n"
5353
+ indent('@example(x=0).via("covering example")', prefix="+"),
5454
)
55+
UNDEF_NAME = (
56+
undef_name,
57+
("undef_name(\n array=array([100], dtype=int8),\n)", FAIL_MSG),
58+
'+@example(array=np.array([100], dtype=np.int8)).via("discovered failure")',
59+
)
5560

5661

5762
def strip_trailing_whitespace(s):
@@ -76,7 +81,7 @@ def test_adds_simple_patch(tst, example, expected):
7681
SIMPLE_PATCH_BODY = f'''\
7782
--- ./{WHERE}
7883
+++ ./{WHERE}
79-
@@ -18,6 +18,7 @@
84+
@@ -21,6 +21,7 @@
8085
8186
8287
@given(st.integers())
@@ -88,7 +93,7 @@ def fn(x):
8893
CASES_PATCH_BODY = f'''\
8994
--- ./{WHERE}
9095
+++ ./{WHERE}
91-
@@ -25,6 +25,9 @@
96+
@@ -28,6 +28,9 @@
9297
class Cases:
9398
@example(n=0, label="whatever")
9499
@given(st.integers(), st.text())
@@ -111,7 +116,7 @@ def fn_top(x):
111116
COVERING_PATCH_BODY = f'''\
112117
--- ./{WHERE}
113118
+++ ./{WHERE}
114-
@@ -31,7 +31,7 @@
119+
@@ -34,7 +34,7 @@
115120
116121
@given(st.integers())
117122
@example(x=2).via("not a literal when repeated " * 2)
@@ -121,6 +126,19 @@ def covered(x):
121126
122127
'''
123128

129+
UNDEF_NAME_PATCH_BODY = f"""\
130+
--- ./{WHERE}
131+
+++ ./{WHERE}
132+
@@ -40,6 +40,7 @@
133+
134+
135+
@given(npst.arrays(np.int8, 1))
136+
{{0}}
137+
def undef_name(array):
138+
assert sum(array) < 100
139+
140+
"""
141+
124142

125143
@pytest.mark.parametrize(
126144
"tst, example, expected, body, remove",
@@ -131,6 +149,7 @@ def covered(x):
131149
pytest.param(
132150
*COVERING, COVERING_PATCH_BODY, ("covering example",), id="covering"
133151
),
152+
pytest.param(*UNDEF_NAME, UNDEF_NAME_PATCH_BODY, (), id="undef_name"),
134153
],
135154
)
136155
def test_make_full_patch(tst, example, expected, body, remove):

0 commit comments

Comments
 (0)