|
18 | 18 | during fuzzing.
|
19 | 19 | """
|
20 | 20 |
|
| 21 | +import ast |
21 | 22 | import difflib
|
22 | 23 | import hashlib
|
23 | 24 | import inspect
|
24 | 25 | import re
|
25 | 26 | import sys
|
| 27 | +import types |
26 | 28 | from ast import literal_eval
|
27 | 29 | from contextlib import suppress
|
28 | 30 | from datetime import date, datetime, timedelta, timezone
|
|
31 | 33 | import libcst as cst
|
32 | 34 | from libcst import matchers as m
|
33 | 35 | from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
|
| 36 | +from libcst.metadata import ExpressionContext, ExpressionContextProvider |
34 | 37 |
|
35 | 38 | from hypothesis.configuration import storage_directory
|
36 | 39 | from hypothesis.version import __version__
|
@@ -148,18 +151,64 @@ def get_patch_for(func, failing_examples, *, strip_via=()):
|
148 | 151 | except Exception:
|
149 | 152 | return None
|
150 | 153 |
|
| 154 | + modules_in_test_scope = sorted( |
| 155 | + ( |
| 156 | + (k, v) |
| 157 | + for (k, v) in module.__dict__.items() |
| 158 | + if isinstance(v, types.ModuleType) |
| 159 | + ), |
| 160 | + key=lambda kv: len(kv[1].__name__), |
| 161 | + ) |
| 162 | + |
151 | 163 | # The printed examples might include object reprs which are invalid syntax,
|
152 | 164 | # so we parse here and skip over those. If _none_ are valid, there's no patch.
|
153 | 165 | call_nodes = []
|
154 | 166 | for ex, via in set(failing_examples):
|
155 | 167 | with suppress(Exception):
|
156 |
| - node = cst.parse_expression(ex) |
157 |
| - assert isinstance(node, cst.Call), node |
| 168 | + node = cst.parse_module(ex) |
| 169 | + the_call = node.body[0].body[0].value |
| 170 | + assert isinstance(the_call, cst.Call), the_call |
158 | 171 | # Check for st.data(), which doesn't support explicit examples
|
159 | 172 | data = m.Arg(m.Call(m.Name("data"), args=[m.Arg(m.Ellipsis())]))
|
160 |
| - if m.matches(node, m.Call(args=[m.ZeroOrMore(), data, m.ZeroOrMore()])): |
| 173 | + if m.matches(the_call, m.Call(args=[m.ZeroOrMore(), data, m.ZeroOrMore()])): |
161 | 174 | return None
|
| 175 | + |
| 176 | + # Many reprs use the unqualified name of the type, e.g. np.array() |
| 177 | + # -> array([...]), so here we find undefined names and look them up |
| 178 | + # on each module which was in the test's global scope. |
| 179 | + names = {} |
| 180 | + for anode in ast.walk(ast.parse(ex, "eval")): |
| 181 | + if ( |
| 182 | + isinstance(anode, ast.Name) |
| 183 | + and isinstance(anode.ctx, ast.Load) |
| 184 | + and anode.id not in names |
| 185 | + and anode.id not in module.__dict__ |
| 186 | + ): |
| 187 | + for k, v in modules_in_test_scope: |
| 188 | + if anode.id in v.__dict__: |
| 189 | + names[anode.id] = cst.parse_expression(f"{k}.{anode.id}") |
| 190 | + break |
| 191 | + |
| 192 | + # LibCST doesn't track Load()/Store() state of names by default, so we have |
| 193 | + # to do a bit of a dance here, *and* explicitly handle keyword arguments |
| 194 | + # which are treated as Load() context - but even if that's fixed later |
| 195 | + # we'll still want to support older versions. |
| 196 | + with suppress(Exception): |
| 197 | + wrapper = cst.metadata.MetadataWrapper(node) |
| 198 | + kwarg_names = { |
| 199 | + a.keyword for a in m.findall(wrapper, m.Arg(keyword=m.Name())) |
| 200 | + } |
| 201 | + node = m.replace( |
| 202 | + wrapper, |
| 203 | + m.Name(value=m.MatchIfTrue(names.__contains__)) |
| 204 | + & m.MatchMetadata(ExpressionContextProvider, ExpressionContext.LOAD) |
| 205 | + & m.MatchIfTrue(lambda n, k=kwarg_names: n not in k), |
| 206 | + replacement=lambda node, _, ns=names: ns[node.value], |
| 207 | + ) |
| 208 | + node = node.body[0].body[0].value |
| 209 | + assert isinstance(node, cst.Call), node |
162 | 210 | call_nodes.append((node, via))
|
| 211 | + |
163 | 212 | if not call_nodes:
|
164 | 213 | return None
|
165 | 214 |
|
@@ -205,8 +254,8 @@ def make_patch(triples, *, msg="Hypothesis: add explicit examples", when=None):
|
205 | 254 | ud = difflib.unified_diff(
|
206 | 255 | source_before.splitlines(keepends=True),
|
207 | 256 | source_after.splitlines(keepends=True),
|
208 |
| - fromfile=str(fname), |
209 |
| - tofile=str(fname), |
| 257 | + fromfile=f"./{fname}", # git strips the first part of the path by default |
| 258 | + tofile=f"./{fname}", |
210 | 259 | )
|
211 | 260 | diffs.append("".join(ud))
|
212 | 261 | return "".join(diffs)
|
|
0 commit comments