Skip to content

Commit 8c6f4e4

Browse files
committed
Provide typing info
1 parent 6cc430c commit 8c6f4e4

File tree

4 files changed

+122
-39
lines changed

4 files changed

+122
-39
lines changed

Makefile

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ ifdef CI
2727
else
2828
pre-commit run --all-files
2929
endif
30+
mypy pytest_asyncio --show-error-codes
3031

3132
test:
3233
coverage run -m pytest tests

pytest_asyncio/plugin.py

+118-39
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,36 @@
66
import inspect
77
import socket
88
import warnings
9+
from typing import (
10+
Any,
11+
Awaitable,
12+
Callable,
13+
Dict,
14+
Iterable,
15+
Iterator,
16+
List,
17+
Optional,
18+
Set,
19+
TypeVar,
20+
Union,
21+
cast,
22+
overload,
23+
)
924

1025
import pytest
26+
from typing_extensions import Literal
27+
28+
_ScopeName = Literal["session", "package", "module", "class", "function"]
29+
_T = TypeVar("_T")
30+
31+
FixtureFunction = TypeVar("FixtureFunction", bound=Callable[..., object])
32+
FixtureFunctionMarker = Callable[[FixtureFunction], FixtureFunction]
33+
34+
Config = Any # pytest < 7.0
35+
PytestPluginManager = Any # pytest < 7.0
36+
FixtureDef = Any # pytest < 7.0
37+
Parser = Any # pytest < 7.0
38+
SubRequest = Any # pytest < 7.0
1139

1240

1341
class Mode(str, enum.Enum):
@@ -41,7 +69,7 @@ class Mode(str, enum.Enum):
4169
"""
4270

4371

44-
def pytest_addoption(parser, pluginmanager):
72+
def pytest_addoption(parser: Parser, pluginmanager: PytestPluginManager) -> None:
4573
group = parser.getgroup("asyncio")
4674
group.addoption(
4775
"--asyncio-mode",
@@ -58,49 +86,87 @@ def pytest_addoption(parser, pluginmanager):
5886
)
5987

6088

61-
def fixture(fixture_function=None, **kwargs):
89+
@overload
90+
def fixture(
91+
fixture_function: FixtureFunction,
92+
*,
93+
scope: "Union[_ScopeName, Callable[[str, Config], _ScopeName]]" = ...,
94+
params: Optional[Iterable[object]] = ...,
95+
autouse: bool = ...,
96+
ids: Optional[
97+
Union[
98+
Iterable[Union[None, str, float, int, bool]],
99+
Callable[[Any], Optional[object]],
100+
]
101+
] = ...,
102+
name: Optional[str] = ...,
103+
) -> FixtureFunction:
104+
...
105+
106+
107+
@overload
108+
def fixture(
109+
fixture_function: None = ...,
110+
*,
111+
scope: "Union[_ScopeName, Callable[[str, Config], _ScopeName]]" = ...,
112+
params: Optional[Iterable[object]] = ...,
113+
autouse: bool = ...,
114+
ids: Optional[
115+
Union[
116+
Iterable[Union[None, str, float, int, bool]],
117+
Callable[[Any], Optional[object]],
118+
]
119+
] = ...,
120+
name: Optional[str] = None,
121+
) -> FixtureFunctionMarker:
122+
...
123+
124+
125+
def fixture(
126+
fixture_function: Optional[FixtureFunction] = None, **kwargs: Any
127+
) -> Union[FixtureFunction, FixtureFunctionMarker]:
62128
if fixture_function is not None:
63129
_set_explicit_asyncio_mark(fixture_function)
64130
return pytest.fixture(fixture_function, **kwargs)
65131

66132
else:
67133

68134
@functools.wraps(fixture)
69-
def inner(fixture_function):
135+
def inner(fixture_function: FixtureFunction) -> FixtureFunction:
70136
return fixture(fixture_function, **kwargs)
71137

72138
return inner
73139

74140

75-
def _has_explicit_asyncio_mark(obj):
141+
def _has_explicit_asyncio_mark(obj: Any) -> bool:
76142
obj = getattr(obj, "__func__", obj) # instance method maybe?
77143
return getattr(obj, "_force_asyncio_fixture", False)
78144

79145

80-
def _set_explicit_asyncio_mark(obj):
146+
def _set_explicit_asyncio_mark(obj: Any) -> None:
81147
if hasattr(obj, "__func__"):
82148
# instance method, check the function object
83149
obj = obj.__func__
84150
obj._force_asyncio_fixture = True
85151

86152

87-
def _is_coroutine(obj):
153+
def _is_coroutine(obj: Any) -> bool:
88154
"""Check to see if an object is really an asyncio coroutine."""
89155
return asyncio.iscoroutinefunction(obj) or inspect.isgeneratorfunction(obj)
90156

91157

92-
def _is_coroutine_or_asyncgen(obj):
158+
def _is_coroutine_or_asyncgen(obj: Any) -> bool:
93159
return _is_coroutine(obj) or inspect.isasyncgenfunction(obj)
94160

95161

96-
def _get_asyncio_mode(config):
162+
def _get_asyncio_mode(config: Config) -> Mode:
97163
val = config.getoption("asyncio_mode")
98164
if val is None:
99165
val = config.getini("asyncio_mode")
100166
return Mode(val)
101167

102168

103-
def pytest_configure(config):
169+
def pytest_configure(config: Config) -> None:
104170
"""Inject documentation."""
105171
config.addinivalue_line(
106172
"markers",
@@ -113,10 +179,14 @@ def pytest_configure(config):
113179

114180

115181
@pytest.mark.tryfirst
116-
def pytest_pycollect_makeitem(collector, name, obj):
182+
def pytest_pycollect_makeitem(
183+
collector: Union[pytest.Module, pytest.Class], name: str, obj: object
184+
) -> Union[
185+
None, pytest.Item, pytest.Collector, List[Union[pytest.Item, pytest.Collector]]
186+
]:
117187
"""A pytest hook to collect asyncio coroutines."""
118188
if not collector.funcnamefilter(name):
119-
return
189+
return None
120190
if (
121191
_is_coroutine(obj)
122192
or _is_hypothesis_test(obj)
@@ -131,10 +201,11 @@ def pytest_pycollect_makeitem(collector, name, obj):
131201
ret = list(collector._genfunctions(name, obj))
132202
for elem in ret:
133203
elem.add_marker("asyncio")
134-
return ret
204+
return ret # type: ignore[return-value]
205+
return None
135206

136207

137-
def _hypothesis_test_wraps_coroutine(function):
208+
def _hypothesis_test_wraps_coroutine(function: Any) -> bool:
138209
return _is_coroutine(function.hypothesis.inner_test)
139210

140211

@@ -144,19 +215,19 @@ class FixtureStripper:
144215
REQUEST = "request"
145216
EVENT_LOOP = "event_loop"
146217

147-
def __init__(self, fixturedef):
218+
def __init__(self, fixturedef: FixtureDef) -> None:
148219
self.fixturedef = fixturedef
149-
self.to_strip = set()
220+
self.to_strip: Set[str] = set()
150221

151-
def add(self, name):
222+
def add(self, name: str) -> None:
152223
"""Add fixture name to fixturedef
153224
and record in to_strip list (If not previously included)"""
154225
if name in self.fixturedef.argnames:
155226
return
156227
self.fixturedef.argnames += (name,)
157228
self.to_strip.add(name)
158229

159-
def get_and_strip_from(self, name, data_dict):
230+
def get_and_strip_from(self, name: str, data_dict: Dict[str, _T]) -> _T:
160231
"""Strip name from data, and return value"""
161232
result = data_dict[name]
162233
if name in self.to_strip:
@@ -165,7 +236,7 @@ def get_and_strip_from(self, name, data_dict):
165236

166237

167238
@pytest.hookimpl(trylast=True)
168-
def pytest_fixture_post_finalizer(fixturedef, request):
239+
def pytest_fixture_post_finalizer(fixturedef: FixtureDef, request: SubRequest) -> None:
169240
"""Called after fixture teardown"""
170241
if fixturedef.argname == "event_loop":
171242
policy = asyncio.get_event_loop_policy()
@@ -177,7 +248,9 @@ def pytest_fixture_post_finalizer(fixturedef, request):
177248

178249

179250
@pytest.hookimpl(hookwrapper=True)
180-
def pytest_fixture_setup(fixturedef, request):
251+
def pytest_fixture_setup(
252+
fixturedef: FixtureDef, request: SubRequest
253+
) -> Optional[object]:
181254
"""Adjust the event loop policy when an event loop is produced."""
182255
if fixturedef.argname == "event_loop":
183256
outcome = yield
@@ -290,39 +363,43 @@ async def setup():
290363

291364

292365
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
293-
def pytest_pyfunc_call(pyfuncitem):
366+
def pytest_pyfunc_call(pyfuncitem: pytest.Function) -> Optional[object]:
294367
"""
295368
Pytest hook called before a test case is run.
296369
297370
Wraps marked tests in a synchronous function
298371
where the wrapped test coroutine is executed in an event loop.
299372
"""
300373
if "asyncio" in pyfuncitem.keywords:
374+
funcargs: Dict[str, object] = pyfuncitem.funcargs # type: ignore[name-defined]
375+
loop = cast(asyncio.AbstractEventLoop, funcargs["event_loop"])
301376
if _is_hypothesis_test(pyfuncitem.obj):
302377
pyfuncitem.obj.hypothesis.inner_test = wrap_in_sync(
303378
pyfuncitem.obj.hypothesis.inner_test,
304-
_loop=pyfuncitem.funcargs["event_loop"],
379+
_loop=loop,
305380
)
306381
else:
307382
pyfuncitem.obj = wrap_in_sync(
308-
pyfuncitem.obj, _loop=pyfuncitem.funcargs["event_loop"]
383+
pyfuncitem.obj,
384+
_loop=loop,
309385
)
310386
yield
311387

312388

313-
def _is_hypothesis_test(function) -> bool:
389+
def _is_hypothesis_test(function: Any) -> bool:
314390
return getattr(function, "is_hypothesis_test", False)
315391

316392

317-
def wrap_in_sync(func, _loop):
393+
def wrap_in_sync(func: Callable[..., Awaitable[Any]], _loop: asyncio.AbstractEventLoop):
318394
"""Return a sync wrapper around an async function executing it in the
319395
current event loop."""
320396

321397
# if the function is already wrapped, we rewrap using the original one
322398
# not using __wrapped__ because the original function may already be
323399
# a wrapped one
324-
if hasattr(func, "_raw_test_func"):
325-
func = func._raw_test_func
400+
raw_func = getattr(func, "_raw_test_func", None)
401+
if raw_func is not None:
402+
func = raw_func
326403

327404
@functools.wraps(func)
328405
def inner(**kwargs):
@@ -339,20 +416,22 @@ def inner(**kwargs):
339416
task.exception()
340417
raise
341418

342-
inner._raw_test_func = func
419+
inner._raw_test_func = func # type: ignore[attr-defined]
343420
return inner
344421

345422

346-
def pytest_runtest_setup(item):
423+
def pytest_runtest_setup(item: pytest.Item) -> None:
347424
if "asyncio" in item.keywords:
425+
fixturenames = item.fixturenames # type: ignore[attr-defined]
348426
# inject an event loop fixture for all async tests
349-
if "event_loop" in item.fixturenames:
350-
item.fixturenames.remove("event_loop")
351-
item.fixturenames.insert(0, "event_loop")
427+
if "event_loop" in fixturenames:
428+
fixturenames.remove("event_loop")
429+
fixturenames.insert(0, "event_loop")
430+
obj = item.obj # type: ignore[attr-defined]
352431
if (
353432
item.get_closest_marker("asyncio") is not None
354-
and not getattr(item.obj, "hypothesis", False)
355-
and getattr(item.obj, "is_hypothesis_test", False)
433+
and not getattr(obj, "hypothesis", False)
434+
and getattr(obj, "is_hypothesis_test", False)
356435
):
357436
pytest.fail(
358437
"test function `%r` is using Hypothesis, but pytest-asyncio "
@@ -361,32 +440,32 @@ def pytest_runtest_setup(item):
361440

362441

363442
@pytest.fixture
364-
def event_loop(request):
443+
def event_loop(request: pytest.FixtureRequest) -> Iterator[asyncio.AbstractEventLoop]:
365444
"""Create an instance of the default event loop for each test case."""
366445
loop = asyncio.get_event_loop_policy().new_event_loop()
367446
yield loop
368447
loop.close()
369448

370449

371-
def _unused_port(socket_type):
450+
def _unused_port(socket_type: int) -> int:
372451
"""Find an unused localhost port from 1024-65535 and return it."""
373452
with contextlib.closing(socket.socket(type=socket_type)) as sock:
374453
sock.bind(("127.0.0.1", 0))
375454
return sock.getsockname()[1]
376455

377456

378457
@pytest.fixture
379-
def unused_tcp_port():
458+
def unused_tcp_port() -> int:
380459
return _unused_port(socket.SOCK_STREAM)
381460

382461

383462
@pytest.fixture
384-
def unused_udp_port():
463+
def unused_udp_port() -> int:
385464
return _unused_port(socket.SOCK_DGRAM)
386465

387466

388467
@pytest.fixture(scope="session")
389-
def unused_tcp_port_factory():
468+
def unused_tcp_port_factory() -> Callable[[], int]:
390469
"""A factory function, producing different unused TCP ports."""
391470
produced = set()
392471

@@ -405,7 +484,7 @@ def factory():
405484

406485

407486
@pytest.fixture(scope="session")
408-
def unused_udp_port_factory():
487+
def unused_udp_port_factory() -> Callable[[], int]:
409488
"""A factory function, producing different unused UDP ports."""
410489
produced = set()
411490

pytest_asyncio/py.typed

Whitespace-only changes.

setup.cfg

+3
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ classifiers =
2727

2828
Framework :: AsyncIO
2929
Framework :: Pytest
30+
Typing :: Typed
3031

3132
[options]
3233
python_requires = >=3.7
@@ -38,12 +39,14 @@ setup_requires =
3839

3940
install_requires =
4041
pytest >= 5.4.0
42+
typing-extensions >= 4.0
4143

4244
[options.extras_require]
4345
testing =
4446
coverage==6.2
4547
hypothesis >= 5.7.1
4648
flaky >= 3.5.0
49+
mypy == 0.931
4750

4851
[options.entry_points]
4952
pytest11 =

0 commit comments

Comments
 (0)