Skip to content

Commit df28a74

Browse files
committed
Update some type inference code
1 parent 1861342 commit df28a74

File tree

4 files changed

+94
-52
lines changed

4 files changed

+94
-52
lines changed

hypothesis-python/src/hypothesis/strategies/_internal/core.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from functools import reduce
2626
from inspect import Parameter, Signature, isabstract, isclass
2727
from re import Pattern
28-
from types import FunctionType
28+
from types import FunctionType, GenericAlias
2929
from typing import (
3030
Any,
3131
AnyStr,
@@ -1326,6 +1326,13 @@ def from_type_guarded(thing):
13261326
strategy = as_strategy(types._global_type_lookup[thing], thing)
13271327
if strategy is not NotImplemented:
13281328
return strategy
1329+
elif (
1330+
isinstance(thing, GenericAlias)
1331+
and (to := get_origin(thing)) in types._global_type_lookup
1332+
):
1333+
strategy = as_strategy(types._global_type_lookup[to], thing)
1334+
if strategy is not NotImplemented:
1335+
return strategy
13291336
except TypeError: # pragma: no cover
13301337
# This was originally due to a bizarre divergence in behaviour on Python 3.9.0:
13311338
# typing.Callable[[], foo] has __args__ = (foo,) but collections.abc.Callable

hypothesis-python/src/hypothesis/strategies/_internal/types.py

+27-21
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,11 @@
2828
import uuid
2929
import warnings
3030
import zoneinfo
31+
from collections.abc import Iterator
3132
from functools import partial
3233
from pathlib import PurePath
3334
from types import FunctionType
34-
from typing import TYPE_CHECKING, Any, Iterator, Tuple, get_args, get_origin
35+
from typing import TYPE_CHECKING, Any, get_args, get_origin
3536

3637
from hypothesis import strategies as st
3738
from hypothesis.errors import HypothesisWarning, InvalidArgument, ResolutionFailed
@@ -339,7 +340,7 @@ def get_constraints_filter_map():
339340
return {} # pragma: no cover
340341

341342

342-
def _get_constraints(args: Tuple[Any, ...]) -> Iterator["at.BaseMetadata"]:
343+
def _get_constraints(args: tuple[Any, ...]) -> Iterator["at.BaseMetadata"]:
343344
at = sys.modules.get("annotated_types")
344345
for arg in args:
345346
if at and isinstance(arg, at.BaseMetadata):
@@ -619,7 +620,7 @@ def _networks(bits):
619620
# exposed for it, and NotImplemented itself is typed as Any so that it can be
620621
# returned without being listed in a function signature:
621622
# https://github.com/python/mypy/issues/6710#issuecomment-485580032
622-
_global_type_lookup: typing.Dict[
623+
_global_type_lookup: dict[
623624
type, typing.Union[st.SearchStrategy, typing.Callable[[type], st.SearchStrategy]]
624625
] = {
625626
type(None): st.none(),
@@ -726,8 +727,8 @@ def _networks(bits):
726727
_global_type_lookup[builtins.sequenceiterator] = st.builds(iter, st.tuples()) # type: ignore
727728

728729

729-
_global_type_lookup[type] = st.sampled_from(
730-
[type(None), *sorted(_global_type_lookup, key=str)]
730+
_fallback_type_strategy = st.sampled_from(
731+
sorted(_global_type_lookup, key=type_sorting_key)
731732
)
732733
# subclass of MutableMapping, and so we resolve to a union which
733734
# includes this... but we don't actually ever want to build one.
@@ -803,15 +804,15 @@ def _networks(bits):
803804
# installed. To avoid the performance hit of importing anything here, we defer
804805
# it until the method is called the first time, at which point we replace the
805806
# entry in the lookup table with the direct call.
806-
def _from_numpy_type(thing: typing.Type) -> typing.Optional[st.SearchStrategy]:
807+
def _from_numpy_type(thing: type) -> typing.Optional[st.SearchStrategy]:
807808
from hypothesis.extra.numpy import _from_type
808809

809810
_global_extra_lookup["numpy"] = _from_type
810811
return _from_type(thing)
811812

812813

813-
_global_extra_lookup: typing.Dict[
814-
str, typing.Callable[[typing.Type], typing.Optional[st.SearchStrategy]]
814+
_global_extra_lookup: dict[
815+
str, typing.Callable[[type], typing.Optional[st.SearchStrategy]]
815816
] = {
816817
"numpy": _from_numpy_type,
817818
}
@@ -839,26 +840,30 @@ def really_inner(thing):
839840
return fallback
840841
return func(thing)
841842

843+
_global_type_lookup[type_] = really_inner
842844
_global_type_lookup[get_origin(type_) or type_] = really_inner
843845
return really_inner
844846

845847
return inner
846848

847849

848-
@register(typing.Type)
850+
@register(type)
851+
@register("Type")
849852
@register("Type", module=typing_extensions)
850853
def resolve_Type(thing):
851854
if getattr(thing, "__args__", None) is None:
852855
return st.just(type)
856+
elif get_args(thing) == (): # pragma: no cover
857+
return _fallback_type_strategy
853858
args = (thing.__args__[0],)
854859
if is_a_union(args[0]):
855860
args = args[0].__args__
856861
# Duplicate check from from_type here - only paying when needed.
857862
args = list(args)
858863
for i, a in enumerate(args):
859-
if type(a) == typing.ForwardRef:
864+
if type(a) in (typing.ForwardRef, str):
860865
try:
861-
args[i] = getattr(builtins, a.__forward_arg__)
866+
args[i] = getattr(builtins, getattr(a, "__forward_arg__", a))
862867
except AttributeError:
863868
raise ResolutionFailed(
864869
f"Cannot find the type referenced by {thing} - try using "
@@ -867,12 +872,12 @@ def resolve_Type(thing):
867872
return st.sampled_from(sorted(args, key=type_sorting_key))
868873

869874

870-
@register(typing.List, st.builds(list))
875+
@register("List", st.builds(list))
871876
def resolve_List(thing):
872877
return st.lists(st.from_type(thing.__args__[0]))
873878

874879

875-
@register(typing.Tuple, st.builds(tuple))
880+
@register("Tuple", st.builds(tuple))
876881
def resolve_Tuple(thing):
877882
elem_types = getattr(thing, "__args__", None) or ()
878883
if len(elem_types) == 2 and elem_types[-1] is Ellipsis:
@@ -906,27 +911,28 @@ def _from_hashable_type(type_):
906911
return st.from_type(type_).filter(_can_hash)
907912

908913

909-
@register(typing.Set, st.builds(set))
914+
@register("Set", st.builds(set))
910915
@register(typing.MutableSet, st.builds(set))
911916
def resolve_Set(thing):
912917
return st.sets(_from_hashable_type(thing.__args__[0]))
913918

914919

915-
@register(typing.FrozenSet, st.builds(frozenset))
920+
@register("FrozenSet", st.builds(frozenset))
916921
def resolve_FrozenSet(thing):
917922
return st.frozensets(_from_hashable_type(thing.__args__[0]))
918923

919924

920-
@register(typing.Dict, st.builds(dict))
925+
@register("Dict", st.builds(dict))
921926
def resolve_Dict(thing):
922927
# If thing is a Collection instance, we need to fill in the values
923-
keys_vals = thing.__args__ * 2
928+
keys, vals, *_ = thing.__args__ * 2
924929
return st.dictionaries(
925-
_from_hashable_type(keys_vals[0]), st.from_type(keys_vals[1])
930+
_from_hashable_type(keys),
931+
st.none() if vals is None else st.from_type(vals),
926932
)
927933

928934

929-
@register(typing.DefaultDict, st.builds(collections.defaultdict))
935+
@register("DefaultDict", st.builds(collections.defaultdict))
930936
@register("DefaultDict", st.builds(collections.defaultdict), module=typing_extensions)
931937
def resolve_DefaultDict(thing):
932938
return resolve_Dict(thing).map(lambda d: collections.defaultdict(None, d))
@@ -988,9 +994,9 @@ def resolve_Pattern(thing):
988994
return st.just(re.compile(thing.__args__[0]()))
989995

990996

991-
@register( # pragma: no branch # coverage does not see lambda->exit branch
997+
@register(
992998
typing.Match,
993-
st.text().map(lambda c: re.match(".", c, flags=re.DOTALL)).filter(bool),
999+
st.text().map(partial(re.match, ".", flags=re.DOTALL)).filter(bool),
9941000
)
9951001
def resolve_Match(thing):
9961002
if thing.__args__[0] == bytes:

hypothesis-python/tests/cover/test_lookup.py

+58-27
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@
6666
),
6767
key=str,
6868
)
69+
_Type = getattr(typing, "Type", None)
70+
_List = getattr(typing, "List", None)
71+
_Dict = getattr(typing, "Dict", None)
72+
_Set = getattr(typing, "Set", None)
73+
_FrozenSet = getattr(typing, "FrozenSet", None)
74+
_Tuple = getattr(typing, "Tuple", None)
6975

7076

7177
@pytest.mark.parametrize("typ", generics, ids=repr)
@@ -104,10 +110,13 @@ def test_specialised_scalar_types(data, typ, instance_of):
104110

105111

106112
def test_typing_Type_int():
107-
assert_simple_property(from_type(typing.Type[int]), lambda x: x is int)
113+
for t in (type[int], type["int"], _Type[int], _Type["int"]):
114+
assert_simple_property(from_type(t), lambda x: x is int)
108115

109116

110-
@given(from_type(typing.Type[typing.Union[str, list]]))
117+
@given(
118+
from_type(type[typing.Union[str, list]]) | from_type(_Type[typing.Union[str, list]])
119+
)
111120
def test_typing_Type_Union(ex):
112121
assert ex in (str, list)
113122

@@ -143,15 +152,21 @@ class Elem:
143152
@pytest.mark.parametrize(
144153
"typ,coll_type",
145154
[
146-
(typing.Set[Elem], set),
147-
(typing.FrozenSet[Elem], frozenset),
148-
(typing.Dict[Elem, None], dict),
155+
(_Set[Elem], set),
156+
(_FrozenSet[Elem], frozenset),
157+
(_Dict[Elem, None], dict),
158+
(set[Elem], set),
159+
(frozenset[Elem], frozenset),
160+
# (dict[Elem, None], dict), # FIXME this should work
149161
(typing.DefaultDict[Elem, None], collections.defaultdict),
150162
(typing.KeysView[Elem], type({}.keys())),
151163
(typing.ValuesView[Elem], type({}.values())),
152-
(typing.List[Elem], list),
153-
(typing.Tuple[Elem], tuple),
154-
(typing.Tuple[Elem, ...], tuple),
164+
(_List[Elem], list),
165+
(_Tuple[Elem], tuple),
166+
(_Tuple[Elem, ...], tuple),
167+
(list[Elem], list),
168+
(tuple[Elem], tuple),
169+
(tuple[Elem, ...], tuple),
155170
(typing.Iterator[Elem], typing.Iterator),
156171
(typing.Sequence[Elem], typing.Sequence),
157172
(typing.Iterable[Elem], typing.Iterable),
@@ -226,23 +241,24 @@ def test_Optional_minimises_to_None():
226241
assert minimal(from_type(typing.Optional[int]), lambda ex: True) is None
227242

228243

229-
@pytest.mark.parametrize("n", range(10))
230-
def test_variable_length_tuples(n):
231-
type_ = typing.Tuple[int, ...]
244+
@pytest.mark.parametrize("n", [0, 1, 5])
245+
@pytest.mark.parametrize("t", [tuple, _Tuple])
246+
def test_variable_length_tuples(t, n):
247+
type_ = t[int, ...]
232248
check_can_generate_examples(from_type(type_).filter(lambda ex: len(ex) == n))
233249

234250

235251
def test_lookup_overrides_defaults():
236252
sentinel = object()
237253
with temp_registered(int, st.just(sentinel)):
238254

239-
@given(from_type(typing.List[int]))
255+
@given(from_type(list[int]))
240256
def inner_1(ex):
241257
assert all(elem is sentinel for elem in ex)
242258

243259
inner_1()
244260

245-
@given(from_type(typing.List[int]))
261+
@given(from_type(list[int]))
246262
def inner_2(ex):
247263
assert all(isinstance(elem, int) for elem in ex)
248264

@@ -253,7 +269,7 @@ def test_register_generic_typing_strats():
253269
# I don't expect anyone to do this, but good to check it works as expected
254270
with temp_registered(
255271
typing.Sequence,
256-
types._global_type_lookup[typing.get_origin(typing.Set) or typing.Set],
272+
types._global_type_lookup[set],
257273
):
258274
# We register sets for the abstract sequence type, which masks subtypes
259275
# from supertype resolution but not direct resolution
@@ -264,9 +280,7 @@ def test_register_generic_typing_strats():
264280
from_type(typing.Container[int]),
265281
lambda ex: not isinstance(ex, typing.Sequence),
266282
)
267-
assert_all_examples(
268-
from_type(typing.List[int]), lambda ex: isinstance(ex, list)
269-
)
283+
assert_all_examples(from_type(list[int]), lambda ex: isinstance(ex, list))
270284

271285

272286
def if_available(name):
@@ -587,7 +601,7 @@ def test_override_args_for_namedtuple(thing):
587601
assert thing.a is None
588602

589603

590-
@pytest.mark.parametrize("thing", [typing.Optional, typing.List, typing.Type])
604+
@pytest.mark.parametrize("thing", [typing.Optional, list, type, _List, _Type])
591605
def test_cannot_resolve_bare_forward_reference(thing):
592606
t = thing["ConcreteFoo"]
593607
with pytest.raises(InvalidArgument):
@@ -740,7 +754,7 @@ def test_resolving_recursive_type_with_registered_constraint_not_none():
740754
find_any(s, lambda s: s.next_node is not None)
741755

742756

743-
@given(from_type(typing.Tuple[()]))
757+
@given(from_type(tuple[()]) | from_type(_Tuple[()]))
744758
def test_resolves_empty_Tuple_issue_1583_regression(ex):
745759
# See e.g. https://github.com/python/mypy/commit/71332d58
746760
assert ex == ()
@@ -805,11 +819,17 @@ def test_cannot_resolve_abstract_class_with_no_concrete_subclass(instance):
805819

806820

807821
@fails_with(ResolutionFailed)
808-
@given(st.from_type(typing.Type["ConcreteFoo"]))
822+
@given(st.from_type(type["ConcreteFoo"]))
809823
def test_cannot_resolve_type_with_forwardref(instance):
810824
raise AssertionError("test body unreachable as strategy cannot resolve")
811825

812826

827+
@fails_with(ResolutionFailed)
828+
@given(st.from_type(_Type["ConcreteFoo"]))
829+
def test_cannot_resolve_type_with_forwardref_old(instance):
830+
raise AssertionError("test body unreachable as strategy cannot resolve")
831+
832+
813833
@pytest.mark.parametrize("typ", [typing.Hashable, typing.Sized])
814834
@given(data=st.data())
815835
def test_inference_on_generic_collections_abc_aliases(typ, data):
@@ -938,9 +958,12 @@ def test_timezone_lookup(type_):
938958
@pytest.mark.parametrize(
939959
"typ",
940960
[
941-
typing.Set[typing.Hashable],
942-
typing.FrozenSet[typing.Hashable],
943-
typing.Dict[typing.Hashable, int],
961+
_Set[typing.Hashable],
962+
_FrozenSet[typing.Hashable],
963+
_Dict[typing.Hashable, int],
964+
set[typing.Hashable],
965+
frozenset[typing.Hashable],
966+
dict[typing.Hashable, int],
944967
],
945968
)
946969
@settings(suppress_health_check=[HealthCheck.data_too_large])
@@ -973,7 +996,8 @@ def __init__(self, value=-1) -> None:
973996
"typ,repr_",
974997
[
975998
(int, "integers()"),
976-
(typing.List[str], "lists(text())"),
999+
(list[str], "lists(text())"),
1000+
(_List[str], "lists(text())"),
9771001
("not a type", "from_type('not a type')"),
9781002
(random.Random, "randoms()"),
9791003
(_EmptyClass, "from_type(tests.cover.test_lookup._EmptyClass)"),
@@ -1123,15 +1147,22 @@ def test_resolves_forwardrefs_to_builtin_types(t, data):
11231147

11241148
@pytest.mark.parametrize("t", BUILTIN_TYPES, ids=lambda t: t.__name__)
11251149
def test_resolves_type_of_builtin_types(t):
1126-
assert_simple_property(st.from_type(typing.Type[t.__name__]), lambda v: v is t)
1150+
assert_simple_property(st.from_type(type[t.__name__]), lambda v: v is t)
11271151

11281152

1129-
@given(st.from_type(typing.Type[typing.Union["str", "int"]]))
1153+
@given(
1154+
st.from_type(type[typing.Union["str", "int"]])
1155+
| st.from_type(_Type[typing.Union["str", "int"]])
1156+
)
11301157
def test_resolves_type_of_union_of_forwardrefs_to_builtins(x):
11311158
assert x in (str, int)
11321159

11331160

1134-
@pytest.mark.parametrize("type_", [typing.List[int], typing.Optional[int]])
1161+
@pytest.mark.parametrize(
1162+
# Old-style `List` because `list[int]() == list()`, so no need for the hint.
1163+
"type_",
1164+
[getattr(typing, "List", None)[int], typing.Optional[int]],
1165+
)
11351166
def test_builds_suggests_from_type(type_):
11361167
with pytest.raises(
11371168
InvalidArgument, match=re.escape(f"try using from_type({type_!r})")

tooling/src/hypothesistooling/__main__.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -208,9 +208,7 @@ def warn(msg):
208208

209209
codespell("--write-changes", *files_to_format, *doc_files_to_format)
210210
pip_tool("ruff", "check", "--fix-only", ".")
211-
pip_tool("shed", *files_to_format, *doc_files_to_format)
212-
# FIXME: work through the typing issues and enable py39 formatting
213-
# pip_tool("shed", "--py39-plus", *files_to_format, *doc_files_to_format)
211+
pip_tool("shed", "--py39-plus", *files_to_format, *doc_files_to_format)
214212

215213

216214
VALID_STARTS = (HEADER.split()[0], "#!/usr/bin/env python")

0 commit comments

Comments
 (0)