6
6
import inspect
7
7
import socket
8
8
import warnings
9
+ from typing import (
10
+ Any ,
11
+ Awaitable ,
12
+ Callable ,
13
+ Dict ,
14
+ Iterable ,
15
+ Iterator ,
16
+ List ,
17
+ Optional ,
18
+ Set ,
19
+ TypeVar ,
20
+ Union ,
21
+ cast ,
22
+ overload ,
23
+ )
9
24
10
25
import pytest
26
+ from typing_extensions import Literal
27
+
28
+ _ScopeName = Literal ["session" , "package" , "module" , "class" , "function" ]
29
+ _T = TypeVar ("_T" )
30
+
31
+ FixtureFunction = TypeVar ("FixtureFunction" , bound = Callable [..., object ])
32
+ FixtureFunctionMarker = Callable [[FixtureFunction ], FixtureFunction ]
33
+
34
+ Config = Any # pytest < 7.0
35
+ PytestPluginManager = Any # pytest < 7.0
36
+ FixtureDef = Any # pytest < 7.0
37
+ Parser = Any # pytest < 7.0
38
+ SubRequest = Any # pytest < 7.0
11
39
12
40
13
41
class Mode (str , enum .Enum ):
@@ -41,7 +69,7 @@ class Mode(str, enum.Enum):
41
69
"""
42
70
43
71
44
- def pytest_addoption (parser , pluginmanager ) :
72
+ def pytest_addoption (parser : Parser , pluginmanager : PytestPluginManager ) -> None :
45
73
group = parser .getgroup ("asyncio" )
46
74
group .addoption (
47
75
"--asyncio-mode" ,
@@ -58,49 +86,87 @@ def pytest_addoption(parser, pluginmanager):
58
86
)
59
87
60
88
61
- def fixture (fixture_function = None , ** kwargs ):
89
+ @overload
90
+ def fixture (
91
+ fixture_function : FixtureFunction ,
92
+ * ,
93
+ scope : "Union[_ScopeName, Callable[[str, Config], _ScopeName]]" = ...,
94
+ params : Optional [Iterable [object ]] = ...,
95
+ autouse : bool = ...,
96
+ ids : Optional [
97
+ Union [
98
+ Iterable [Union [None , str , float , int , bool ]],
99
+ Callable [[Any ], Optional [object ]],
100
+ ]
101
+ ] = ...,
102
+ name : Optional [str ] = ...,
103
+ ) -> FixtureFunction :
104
+ ...
105
+
106
+
107
+ @overload
108
+ def fixture (
109
+ fixture_function : None = ...,
110
+ * ,
111
+ scope : "Union[_ScopeName, Callable[[str, Config], _ScopeName]]" = ...,
112
+ params : Optional [Iterable [object ]] = ...,
113
+ autouse : bool = ...,
114
+ ids : Optional [
115
+ Union [
116
+ Iterable [Union [None , str , float , int , bool ]],
117
+ Callable [[Any ], Optional [object ]],
118
+ ]
119
+ ] = ...,
120
+ name : Optional [str ] = None ,
121
+ ) -> FixtureFunctionMarker :
122
+ ...
123
+
124
+
125
+ def fixture (
126
+ fixture_function : Optional [FixtureFunction ] = None , ** kwargs : Any
127
+ ) -> Union [FixtureFunction , FixtureFunctionMarker ]:
62
128
if fixture_function is not None :
63
129
_set_explicit_asyncio_mark (fixture_function )
64
130
return pytest .fixture (fixture_function , ** kwargs )
65
131
66
132
else :
67
133
68
134
@functools .wraps (fixture )
69
- def inner (fixture_function ) :
135
+ def inner (fixture_function : FixtureFunction ) -> FixtureFunction :
70
136
return fixture (fixture_function , ** kwargs )
71
137
72
138
return inner
73
139
74
140
75
- def _has_explicit_asyncio_mark (obj ) :
141
+ def _has_explicit_asyncio_mark (obj : Any ) -> bool :
76
142
obj = getattr (obj , "__func__" , obj ) # instance method maybe?
77
143
return getattr (obj , "_force_asyncio_fixture" , False )
78
144
79
145
80
- def _set_explicit_asyncio_mark (obj ) :
146
+ def _set_explicit_asyncio_mark (obj : Any ) -> None :
81
147
if hasattr (obj , "__func__" ):
82
148
# instance method, check the function object
83
149
obj = obj .__func__
84
150
obj ._force_asyncio_fixture = True
85
151
86
152
87
- def _is_coroutine (obj ) :
153
+ def _is_coroutine (obj : Any ) -> bool :
88
154
"""Check to see if an object is really an asyncio coroutine."""
89
155
return asyncio .iscoroutinefunction (obj ) or inspect .isgeneratorfunction (obj )
90
156
91
157
92
- def _is_coroutine_or_asyncgen (obj ) :
158
+ def _is_coroutine_or_asyncgen (obj : Any ) -> bool :
93
159
return _is_coroutine (obj ) or inspect .isasyncgenfunction (obj )
94
160
95
161
96
- def _get_asyncio_mode (config ) :
162
+ def _get_asyncio_mode (config : Config ) -> Mode :
97
163
val = config .getoption ("asyncio_mode" )
98
164
if val is None :
99
165
val = config .getini ("asyncio_mode" )
100
166
return Mode (val )
101
167
102
168
103
- def pytest_configure (config ) :
169
+ def pytest_configure (config : Config ) -> None :
104
170
"""Inject documentation."""
105
171
config .addinivalue_line (
106
172
"markers" ,
@@ -113,10 +179,14 @@ def pytest_configure(config):
113
179
114
180
115
181
@pytest .mark .tryfirst
116
- def pytest_pycollect_makeitem (collector , name , obj ):
182
+ def pytest_pycollect_makeitem (
183
+ collector : Union [pytest .Module , pytest .Class ], name : str , obj : object
184
+ ) -> Union [
185
+ None , pytest .Item , pytest .Collector , List [Union [pytest .Item , pytest .Collector ]]
186
+ ]:
117
187
"""A pytest hook to collect asyncio coroutines."""
118
188
if not collector .funcnamefilter (name ):
119
- return
189
+ return None
120
190
if (
121
191
_is_coroutine (obj )
122
192
or _is_hypothesis_test (obj )
@@ -131,10 +201,11 @@ def pytest_pycollect_makeitem(collector, name, obj):
131
201
ret = list (collector ._genfunctions (name , obj ))
132
202
for elem in ret :
133
203
elem .add_marker ("asyncio" )
134
- return ret
204
+ return ret # type: ignore[return-value]
205
+ return None
135
206
136
207
137
- def _hypothesis_test_wraps_coroutine (function ) :
208
+ def _hypothesis_test_wraps_coroutine (function : Any ) -> bool :
138
209
return _is_coroutine (function .hypothesis .inner_test )
139
210
140
211
@@ -144,19 +215,19 @@ class FixtureStripper:
144
215
REQUEST = "request"
145
216
EVENT_LOOP = "event_loop"
146
217
147
- def __init__ (self , fixturedef ) :
218
+ def __init__ (self , fixturedef : FixtureDef ) -> None :
148
219
self .fixturedef = fixturedef
149
- self .to_strip = set ()
220
+ self .to_strip : Set [ str ] = set ()
150
221
151
- def add (self , name ) :
222
+ def add (self , name : str ) -> None :
152
223
"""Add fixture name to fixturedef
153
224
and record in to_strip list (If not previously included)"""
154
225
if name in self .fixturedef .argnames :
155
226
return
156
227
self .fixturedef .argnames += (name ,)
157
228
self .to_strip .add (name )
158
229
159
- def get_and_strip_from (self , name , data_dict ) :
230
+ def get_and_strip_from (self , name : str , data_dict : Dict [ str , _T ]) -> _T :
160
231
"""Strip name from data, and return value"""
161
232
result = data_dict [name ]
162
233
if name in self .to_strip :
@@ -165,7 +236,7 @@ def get_and_strip_from(self, name, data_dict):
165
236
166
237
167
238
@pytest .hookimpl (trylast = True )
168
- def pytest_fixture_post_finalizer (fixturedef , request ) :
239
+ def pytest_fixture_post_finalizer (fixturedef : FixtureDef , request : SubRequest ) -> None :
169
240
"""Called after fixture teardown"""
170
241
if fixturedef .argname == "event_loop" :
171
242
policy = asyncio .get_event_loop_policy ()
@@ -177,7 +248,9 @@ def pytest_fixture_post_finalizer(fixturedef, request):
177
248
178
249
179
250
@pytest .hookimpl (hookwrapper = True )
180
- def pytest_fixture_setup (fixturedef , request ):
251
+ def pytest_fixture_setup (
252
+ fixturedef : FixtureDef , request : SubRequest
253
+ ) -> Optional [object ]:
181
254
"""Adjust the event loop policy when an event loop is produced."""
182
255
if fixturedef .argname == "event_loop" :
183
256
outcome = yield
@@ -290,39 +363,43 @@ async def setup():
290
363
291
364
292
365
@pytest .hookimpl (tryfirst = True , hookwrapper = True )
293
- def pytest_pyfunc_call (pyfuncitem ) :
366
+ def pytest_pyfunc_call (pyfuncitem : pytest . Function ) -> Optional [ object ] :
294
367
"""
295
368
Pytest hook called before a test case is run.
296
369
297
370
Wraps marked tests in a synchronous function
298
371
where the wrapped test coroutine is executed in an event loop.
299
372
"""
300
373
if "asyncio" in pyfuncitem .keywords :
374
+ funcargs : Dict [str , object ] = pyfuncitem .funcargs # type: ignore[name-defined]
375
+ loop = cast (asyncio .AbstractEventLoop , funcargs ["event_loop" ])
301
376
if _is_hypothesis_test (pyfuncitem .obj ):
302
377
pyfuncitem .obj .hypothesis .inner_test = wrap_in_sync (
303
378
pyfuncitem .obj .hypothesis .inner_test ,
304
- _loop = pyfuncitem . funcargs [ "event_loop" ] ,
379
+ _loop = loop ,
305
380
)
306
381
else :
307
382
pyfuncitem .obj = wrap_in_sync (
308
- pyfuncitem .obj , _loop = pyfuncitem .funcargs ["event_loop" ]
383
+ pyfuncitem .obj ,
384
+ _loop = loop ,
309
385
)
310
386
yield
311
387
312
388
313
- def _is_hypothesis_test (function ) -> bool :
389
+ def _is_hypothesis_test (function : Any ) -> bool :
314
390
return getattr (function , "is_hypothesis_test" , False )
315
391
316
392
317
- def wrap_in_sync (func , _loop ):
393
+ def wrap_in_sync (func : Callable [..., Awaitable [ Any ]], _loop : asyncio . AbstractEventLoop ):
318
394
"""Return a sync wrapper around an async function executing it in the
319
395
current event loop."""
320
396
321
397
# if the function is already wrapped, we rewrap using the original one
322
398
# not using __wrapped__ because the original function may already be
323
399
# a wrapped one
324
- if hasattr (func , "_raw_test_func" ):
325
- func = func ._raw_test_func
400
+ raw_func = getattr (func , "_raw_test_func" , None )
401
+ if raw_func is not None :
402
+ func = raw_func
326
403
327
404
@functools .wraps (func )
328
405
def inner (** kwargs ):
@@ -339,20 +416,22 @@ def inner(**kwargs):
339
416
task .exception ()
340
417
raise
341
418
342
- inner ._raw_test_func = func
419
+ inner ._raw_test_func = func # type: ignore[attr-defined]
343
420
return inner
344
421
345
422
346
- def pytest_runtest_setup (item ) :
423
+ def pytest_runtest_setup (item : pytest . Item ) -> None :
347
424
if "asyncio" in item .keywords :
425
+ fixturenames = item .fixturenames # type: ignore[attr-defined]
348
426
# inject an event loop fixture for all async tests
349
- if "event_loop" in item .fixturenames :
350
- item .fixturenames .remove ("event_loop" )
351
- item .fixturenames .insert (0 , "event_loop" )
427
+ if "event_loop" in fixturenames :
428
+ fixturenames .remove ("event_loop" )
429
+ fixturenames .insert (0 , "event_loop" )
430
+ obj = item .obj # type: ignore[attr-defined]
352
431
if (
353
432
item .get_closest_marker ("asyncio" ) is not None
354
- and not getattr (item . obj , "hypothesis" , False )
355
- and getattr (item . obj , "is_hypothesis_test" , False )
433
+ and not getattr (obj , "hypothesis" , False )
434
+ and getattr (obj , "is_hypothesis_test" , False )
356
435
):
357
436
pytest .fail (
358
437
"test function `%r` is using Hypothesis, but pytest-asyncio "
@@ -361,32 +440,32 @@ def pytest_runtest_setup(item):
361
440
362
441
363
442
@pytest .fixture
364
- def event_loop (request ) :
443
+ def event_loop (request : pytest . FixtureRequest ) -> Iterator [ asyncio . AbstractEventLoop ] :
365
444
"""Create an instance of the default event loop for each test case."""
366
445
loop = asyncio .get_event_loop_policy ().new_event_loop ()
367
446
yield loop
368
447
loop .close ()
369
448
370
449
371
- def _unused_port (socket_type ) :
450
+ def _unused_port (socket_type : int ) -> int :
372
451
"""Find an unused localhost port from 1024-65535 and return it."""
373
452
with contextlib .closing (socket .socket (type = socket_type )) as sock :
374
453
sock .bind (("127.0.0.1" , 0 ))
375
454
return sock .getsockname ()[1 ]
376
455
377
456
378
457
@pytest .fixture
379
- def unused_tcp_port ():
458
+ def unused_tcp_port () -> int :
380
459
return _unused_port (socket .SOCK_STREAM )
381
460
382
461
383
462
@pytest .fixture
384
- def unused_udp_port ():
463
+ def unused_udp_port () -> int :
385
464
return _unused_port (socket .SOCK_DGRAM )
386
465
387
466
388
467
@pytest .fixture (scope = "session" )
389
- def unused_tcp_port_factory ():
468
+ def unused_tcp_port_factory () -> Callable [[], int ] :
390
469
"""A factory function, producing different unused TCP ports."""
391
470
produced = set ()
392
471
@@ -405,7 +484,7 @@ def factory():
405
484
406
485
407
486
@pytest .fixture (scope = "session" )
408
- def unused_udp_port_factory ():
487
+ def unused_udp_port_factory () -> Callable [[], int ] :
409
488
"""A factory function, producing different unused UDP ports."""
410
489
produced = set ()
411
490
0 commit comments