Skip to content

Commit 208b862

Browse files
authored
Merge pull request #3354 from Cheukting/ghostwriter_imrpovement
Teach the Ghostwriter about @staticmethod and @classmethod methods
2 parents f035ab5 + 2c615d7 commit 208b862

File tree

10 files changed

+303
-26
lines changed

10 files changed

+303
-26
lines changed

hypothesis-python/RELEASE.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
RELEASE_TYPE: minor
2+
3+
The :doc:`Ghostwritter <ghostwriter>` can now write tests for
4+
:obj:`@classmethod <classmethod>` or :obj:`@staticmethod <staticmethod>`
5+
methods, in addition to the existing support for functions and other callables
6+
(:issue:`3318`). Thanks to Cheuk Ting Ho for the patch.

hypothesis-python/src/hypothesis/extra/cli.py

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@
3737

3838
import builtins
3939
import importlib
40+
import inspect
4041
import sys
42+
import types
4143
from difflib import get_close_matches
4244
from functools import partial
4345
from multiprocessing import Pool
@@ -84,27 +86,64 @@ def obj_name(s: str) -> object:
8486
return importlib.import_module(s)
8587
except ImportError:
8688
pass
89+
classname = None
8790
if "." not in s:
8891
modulename, module, funcname = "builtins", builtins, s
8992
else:
9093
modulename, funcname = s.rsplit(".", 1)
9194
try:
9295
module = importlib.import_module(modulename)
9396
except ImportError as err:
97+
try:
98+
modulename, classname = modulename.rsplit(".", 1)
99+
module = importlib.import_module(modulename)
100+
except ImportError:
101+
raise click.UsageError(
102+
f"Failed to import the {modulename} module for introspection. "
103+
"Check spelling and your Python import path, or use the Python API?"
104+
) from err
105+
106+
def describe_close_matches(
107+
module_or_class: types.ModuleType, objname: str
108+
) -> str:
109+
public_names = [
110+
name for name in vars(module_or_class) if not name.startswith("_")
111+
]
112+
matches = get_close_matches(objname, public_names)
113+
if matches:
114+
return f" Closest matches: {matches!r}"
115+
else:
116+
return ""
117+
118+
if classname is None:
119+
try:
120+
return getattr(module, funcname)
121+
except AttributeError as err:
94122
raise click.UsageError(
95-
f"Failed to import the {modulename} module for introspection. "
96-
"Check spelling and your Python import path, or use the Python API?"
123+
f"Found the {modulename!r} module, but it doesn't have a "
124+
f"{funcname!r} attribute."
125+
+ describe_close_matches(module, funcname)
126+
) from err
127+
else:
128+
try:
129+
func_class = getattr(module, classname)
130+
except AttributeError as err:
131+
raise click.UsageError(
132+
f"Found the {modulename!r} module, but it doesn't have a "
133+
f"{classname!r} class." + describe_close_matches(module, classname)
134+
) from err
135+
try:
136+
return getattr(func_class, funcname)
137+
except AttributeError as err:
138+
if inspect.isclass(func_class):
139+
func_class_is = "class"
140+
else:
141+
func_class_is = "attribute"
142+
raise click.UsageError(
143+
f"Found the {modulename!r} module and {classname!r} {func_class_is}, "
144+
f"but it doesn't have a {funcname!r} attribute."
145+
+ describe_close_matches(func_class, funcname)
97146
) from err
98-
try:
99-
return getattr(module, funcname)
100-
except AttributeError as err:
101-
public_names = [name for name in vars(module) if not name.startswith("_")]
102-
matches = get_close_matches(funcname, public_names)
103-
raise click.UsageError(
104-
f"Found the {modulename!r} module, but it doesn't have a "
105-
f"{funcname!r} attribute."
106-
+ (f" Closest matches: {matches!r}" if matches else "")
107-
) from err
108147

109148
def _refactor(func, fname):
110149
try:

hypothesis-python/src/hypothesis/extra/ghostwriter.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -891,10 +891,15 @@ def magic(
891891
for thing in modules_or_functions:
892892
if callable(thing):
893893
functions.add(thing)
894+
# class need to be added for exploration
895+
if inspect.isclass(thing):
896+
funcs: List[Optional[Any]] = [thing]
897+
else:
898+
funcs = []
894899
elif isinstance(thing, types.ModuleType):
895900
if hasattr(thing, "__all__"):
896901
funcs = [getattr(thing, name, None) for name in thing.__all__]
897-
else:
902+
elif hasattr(thing, "__package__"):
898903
pkg = thing.__package__
899904
funcs = [
900905
v
@@ -906,22 +911,35 @@ def magic(
906911
]
907912
if pkg and any(getattr(f, "__module__", pkg) == pkg for f in funcs):
908913
funcs = [f for f in funcs if getattr(f, "__module__", pkg) == pkg]
909-
for f in funcs:
910-
try:
911-
if (
912-
(not is_mock(f))
913-
and callable(f)
914-
and _get_params(f)
915-
and not isinstance(f, enum.EnumMeta)
916-
):
917-
functions.add(f)
918-
if getattr(thing, "__name__", None):
919-
KNOWN_FUNCTION_LOCATIONS[f] = thing.__name__
920-
except (TypeError, ValueError):
921-
pass
922914
else:
923915
raise InvalidArgument(f"Can't test non-module non-callable {thing!r}")
924916

917+
for f in list(funcs):
918+
if inspect.isclass(f):
919+
funcs += [
920+
v.__get__(f)
921+
for k, v in vars(f).items()
922+
if hasattr(v, "__func__")
923+
and not is_mock(v)
924+
and not k.startswith("_")
925+
]
926+
for f in funcs:
927+
try:
928+
if (
929+
(not is_mock(f))
930+
and callable(f)
931+
and _get_params(f)
932+
and not isinstance(f, enum.EnumMeta)
933+
):
934+
functions.add(f)
935+
if getattr(thing, "__name__", None):
936+
if inspect.isclass(thing):
937+
KNOWN_FUNCTION_LOCATIONS[f] = thing.__module__
938+
else:
939+
KNOWN_FUNCTION_LOCATIONS[f] = thing.__name__
940+
except (TypeError, ValueError):
941+
pass
942+
925943
imports = set()
926944
parts = []
927945

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# This test code was written by the `hypothesis.extra.ghostwriter` module
2+
# and is provided under the Creative Commons Zero public domain dedication.
3+
4+
import test_expected_output
5+
from hypothesis import given, strategies as st
6+
7+
8+
@given(arg=st.integers())
9+
def test_fuzz_A_Class_a_staticmethod(arg):
10+
test_expected_output.A_Class.a_staticmethod(arg=arg)

hypothesis-python/tests/ghostwriter/recorded/hypothesis_module_magic.txt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,21 @@ def test_fuzz_settings(
9696
)
9797

9898

99+
@given(name=st.text())
100+
def test_fuzz_settings_get_profile(name):
101+
hypothesis.settings.get_profile(name=name)
102+
103+
104+
@given(name=st.text())
105+
def test_fuzz_settings_load_profile(name):
106+
hypothesis.settings.load_profile(name=name)
107+
108+
109+
@given(name=st.text(), parent=st.one_of(st.none(), st.builds(settings)))
110+
def test_fuzz_settings_register_profile(name, parent):
111+
hypothesis.settings.register_profile(name=name, parent=parent)
112+
113+
99114
@given(observation=st.one_of(st.floats(), st.integers()), label=st.text())
100115
def test_fuzz_target(observation, label):
101116
hypothesis.target(observation=observation, label=label)

hypothesis-python/tests/ghostwriter/recorded/magic_builtins.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,16 @@ def test_fuzz_bin(number):
3131
bin(number)
3232

3333

34+
@given(frm=st.nothing(), to=st.nothing())
35+
def test_fuzz_bytearray_maketrans(frm, to):
36+
bytearray.maketrans(frm, to)
37+
38+
39+
@given(frm=st.nothing(), to=st.nothing())
40+
def test_fuzz_bytes_maketrans(frm, to):
41+
bytes.maketrans(frm, to)
42+
43+
3444
@given(obj=st.nothing())
3545
def test_fuzz_callable(obj):
3646
callable(obj)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# This test code was written by the `hypothesis.extra.ghostwriter` module
2+
# and is provided under the Creative Commons Zero public domain dedication.
3+
4+
import test_expected_output
5+
from hypothesis import given, strategies as st
6+
7+
8+
@given()
9+
def test_fuzz_A_Class():
10+
test_expected_output.A_Class()
11+
12+
13+
@given(arg=st.integers())
14+
def test_fuzz_A_Class_a_classmethod(arg):
15+
test_expected_output.A_Class.a_classmethod(arg=arg)
16+
17+
18+
@given(arg=st.integers())
19+
def test_fuzz_A_Class_a_staticmethod(arg):
20+
test_expected_output.A_Class.a_staticmethod(arg=arg)

hypothesis-python/tests/ghostwriter/test_expected_output.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ class A_Class:
6565
def a_classmethod(cls, arg: int):
6666
pass
6767

68+
@staticmethod
69+
def a_staticmethod(arg: int):
70+
pass
71+
6872

6973
def add(a: float, b: float) -> float:
7074
return a + b
@@ -86,6 +90,7 @@ def divide(a: int, b: int) -> float:
8690
("fuzz_sorted", ghostwriter.fuzz(sorted)),
8791
("fuzz_with_docstring", ghostwriter.fuzz(with_docstring)),
8892
("fuzz_classmethod", ghostwriter.fuzz(A_Class.a_classmethod)),
93+
("fuzz_staticmethod", ghostwriter.fuzz(A_Class.a_staticmethod)),
8994
("fuzz_ufunc", ghostwriter.fuzz(numpy.add)),
9095
("magic_gufunc", ghostwriter.magic(numpy.matmul)),
9196
("magic_base64_roundtrip", ghostwriter.magic(base64.b64encode)),
@@ -176,6 +181,7 @@ def divide(a: int, b: int) -> float:
176181
style="unittest",
177182
),
178183
),
184+
("magic_class", ghostwriter.magic(A_Class)),
179185
pytest.param(
180186
("magic_builtins", ghostwriter.magic(builtins)),
181187
marks=[

hypothesis-python/tests/ghostwriter/test_ghostwriter.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,34 @@ def test_invalid_func_inputs(gw, args):
257257
gw(*args)
258258

259259

260+
class A:
261+
@classmethod
262+
def to_json(cls, obj: Union[dict, list]) -> str:
263+
return json.dumps(obj)
264+
265+
@classmethod
266+
def from_json(cls, obj: str) -> Union[dict, list]:
267+
return json.loads(obj)
268+
269+
@staticmethod
270+
def static_sorter(seq: Sequence[int]) -> List[int]:
271+
return sorted(seq)
272+
273+
274+
@pytest.mark.parametrize(
275+
"gw,args",
276+
[
277+
(ghostwriter.fuzz, [A.static_sorter]),
278+
(ghostwriter.idempotent, [A.static_sorter]),
279+
(ghostwriter.roundtrip, [A.to_json, A.from_json]),
280+
(ghostwriter.equivalent, [A.to_json, json.dumps]),
281+
],
282+
)
283+
def test_class_methods_inputs(gw, args):
284+
source_code = gw(*args)
285+
get_test_function(source_code)()
286+
287+
260288
def test_run_ghostwriter_fuzz():
261289
# Our strategy-guessing code works for all the arguments to sorted,
262290
# and we handle positional-only arguments in calls correctly too.

0 commit comments

Comments
 (0)