6
6
import inspect
7
7
import socket
8
8
import warnings
9
+ from typing import (
10
+ Any ,
11
+ AsyncIterator ,
12
+ Awaitable ,
13
+ Callable ,
14
+ Dict ,
15
+ Iterable ,
16
+ Iterator ,
17
+ List ,
18
+ Optional ,
19
+ Set ,
20
+ TypeVar ,
21
+ Union ,
22
+ cast ,
23
+ overload ,
24
+ )
9
25
10
26
import pytest
27
+ from typing_extensions import Literal
28
+
29
+ _R = TypeVar ("_R" )
30
+
31
+ _ScopeName = Literal ["session" , "package" , "module" , "class" , "function" ]
32
+ _T = TypeVar ("_T" )
33
+
34
+ SimpleFixtureFunction = TypeVar (
35
+ "SimpleFixtureFunction" , bound = Callable [..., Awaitable [_R ]]
36
+ )
37
+ FactoryFixtureFunction = TypeVar (
38
+ "FactoryFixtureFunction" , bound = Callable [..., AsyncIterator [_R ]]
39
+ )
40
+ FixtureFunction = Union [SimpleFixtureFunction , FactoryFixtureFunction ]
41
+ FixtureFunctionMarker = Callable [[FixtureFunction ], FixtureFunction ]
42
+
43
+ Config = Any # pytest < 7.0
44
+ PytestPluginManager = Any # pytest < 7.0
45
+ FixtureDef = Any # pytest < 7.0
46
+ Parser = Any # pytest < 7.0
47
+ SubRequest = Any # pytest < 7.0
11
48
12
49
13
50
class Mode (str , enum .Enum ):
@@ -41,7 +78,7 @@ class Mode(str, enum.Enum):
41
78
"""
42
79
43
80
44
- def pytest_addoption (parser , pluginmanager ) :
81
+ def pytest_addoption (parser : Parser , pluginmanager : PytestPluginManager ) -> None :
45
82
group = parser .getgroup ("asyncio" )
46
83
group .addoption (
47
84
"--asyncio-mode" ,
@@ -58,49 +95,87 @@ def pytest_addoption(parser, pluginmanager):
58
95
)
59
96
60
97
61
- def fixture (fixture_function = None , ** kwargs ):
98
+ @overload
99
+ def fixture (
100
+ fixture_function : FixtureFunction ,
101
+ * ,
102
+ scope : "Union[_ScopeName, Callable[[str, Config], _ScopeName]]" = ...,
103
+ params : Optional [Iterable [object ]] = ...,
104
+ autouse : bool = ...,
105
+ ids : Optional [
106
+ Union [
107
+ Iterable [Union [None , str , float , int , bool ]],
108
+ Callable [[Any ], Optional [object ]],
109
+ ]
110
+ ] = ...,
111
+ name : Optional [str ] = ...,
112
+ ) -> FixtureFunction :
113
+ ...
114
+
115
+
116
+ @overload
117
+ def fixture (
118
+ fixture_function : None = ...,
119
+ * ,
120
+ scope : "Union[_ScopeName, Callable[[str, Config], _ScopeName]]" = ...,
121
+ params : Optional [Iterable [object ]] = ...,
122
+ autouse : bool = ...,
123
+ ids : Optional [
124
+ Union [
125
+ Iterable [Union [None , str , float , int , bool ]],
126
+ Callable [[Any ], Optional [object ]],
127
+ ]
128
+ ] = ...,
129
+ name : Optional [str ] = None ,
130
+ ) -> FixtureFunctionMarker :
131
+ ...
132
+
133
+
134
+ def fixture (
135
+ fixture_function : Optional [FixtureFunction ] = None , ** kwargs : Any
136
+ ) -> Union [FixtureFunction , FixtureFunctionMarker ]:
62
137
if fixture_function is not None :
63
138
_set_explicit_asyncio_mark (fixture_function )
64
139
return pytest .fixture (fixture_function , ** kwargs )
65
140
66
141
else :
67
142
68
143
@functools .wraps (fixture )
69
- def inner (fixture_function ) :
144
+ def inner (fixture_function : FixtureFunction ) -> FixtureFunction :
70
145
return fixture (fixture_function , ** kwargs )
71
146
72
147
return inner
73
148
74
149
75
- def _has_explicit_asyncio_mark (obj ) :
150
+ def _has_explicit_asyncio_mark (obj : Any ) -> bool :
76
151
obj = getattr (obj , "__func__" , obj ) # instance method maybe?
77
152
return getattr (obj , "_force_asyncio_fixture" , False )
78
153
79
154
80
- def _set_explicit_asyncio_mark (obj ) :
155
+ def _set_explicit_asyncio_mark (obj : Any ) -> None :
81
156
if hasattr (obj , "__func__" ):
82
157
# instance method, check the function object
83
158
obj = obj .__func__
84
159
obj ._force_asyncio_fixture = True
85
160
86
161
87
- def _is_coroutine (obj ) :
162
+ def _is_coroutine (obj : Any ) -> bool :
88
163
"""Check to see if an object is really an asyncio coroutine."""
89
164
return asyncio .iscoroutinefunction (obj ) or inspect .isgeneratorfunction (obj )
90
165
91
166
92
- def _is_coroutine_or_asyncgen (obj ) :
167
+ def _is_coroutine_or_asyncgen (obj : Any ) -> bool :
93
168
return _is_coroutine (obj ) or inspect .isasyncgenfunction (obj )
94
169
95
170
96
- def _get_asyncio_mode (config ) :
171
+ def _get_asyncio_mode (config : Config ) -> Mode :
97
172
val = config .getoption ("asyncio_mode" )
98
173
if val is None :
99
174
val = config .getini ("asyncio_mode" )
100
175
return Mode (val )
101
176
102
177
103
- def pytest_configure (config ) :
178
+ def pytest_configure (config : Config ) -> None :
104
179
"""Inject documentation."""
105
180
config .addinivalue_line (
106
181
"markers" ,
@@ -113,10 +188,14 @@ def pytest_configure(config):
113
188
114
189
115
190
@pytest .mark .tryfirst
116
- def pytest_pycollect_makeitem (collector , name , obj ):
191
+ def pytest_pycollect_makeitem (
192
+ collector : Union [pytest .Module , pytest .Class ], name : str , obj : object
193
+ ) -> Union [
194
+ None , pytest .Item , pytest .Collector , List [Union [pytest .Item , pytest .Collector ]]
195
+ ]:
117
196
"""A pytest hook to collect asyncio coroutines."""
118
197
if not collector .funcnamefilter (name ):
119
- return
198
+ return None
120
199
if (
121
200
_is_coroutine (obj )
122
201
or _is_hypothesis_test (obj )
@@ -131,10 +210,11 @@ def pytest_pycollect_makeitem(collector, name, obj):
131
210
ret = list (collector ._genfunctions (name , obj ))
132
211
for elem in ret :
133
212
elem .add_marker ("asyncio" )
134
- return ret
213
+ return ret # type: ignore[return-value]
214
+ return None
135
215
136
216
137
- def _hypothesis_test_wraps_coroutine (function ) :
217
+ def _hypothesis_test_wraps_coroutine (function : Any ) -> bool :
138
218
return _is_coroutine (function .hypothesis .inner_test )
139
219
140
220
@@ -144,19 +224,19 @@ class FixtureStripper:
144
224
REQUEST = "request"
145
225
EVENT_LOOP = "event_loop"
146
226
147
- def __init__ (self , fixturedef ) :
227
+ def __init__ (self , fixturedef : FixtureDef ) -> None :
148
228
self .fixturedef = fixturedef
149
- self .to_strip = set ()
229
+ self .to_strip : Set [ str ] = set ()
150
230
151
- def add (self , name ) :
231
+ def add (self , name : str ) -> None :
152
232
"""Add fixture name to fixturedef
153
233
and record in to_strip list (If not previously included)"""
154
234
if name in self .fixturedef .argnames :
155
235
return
156
236
self .fixturedef .argnames += (name ,)
157
237
self .to_strip .add (name )
158
238
159
- def get_and_strip_from (self , name , data_dict ) :
239
+ def get_and_strip_from (self , name : str , data_dict : Dict [ str , _T ]) -> _T :
160
240
"""Strip name from data, and return value"""
161
241
result = data_dict [name ]
162
242
if name in self .to_strip :
@@ -165,7 +245,7 @@ def get_and_strip_from(self, name, data_dict):
165
245
166
246
167
247
@pytest .hookimpl (trylast = True )
168
- def pytest_fixture_post_finalizer (fixturedef , request ) :
248
+ def pytest_fixture_post_finalizer (fixturedef : FixtureDef , request : SubRequest ) -> None :
169
249
"""Called after fixture teardown"""
170
250
if fixturedef .argname == "event_loop" :
171
251
policy = asyncio .get_event_loop_policy ()
@@ -182,7 +262,9 @@ def pytest_fixture_post_finalizer(fixturedef, request):
182
262
183
263
184
264
@pytest .hookimpl (hookwrapper = True )
185
- def pytest_fixture_setup (fixturedef , request ):
265
+ def pytest_fixture_setup (
266
+ fixturedef : FixtureDef , request : SubRequest
267
+ ) -> Optional [object ]:
186
268
"""Adjust the event loop policy when an event loop is produced."""
187
269
if fixturedef .argname == "event_loop" :
188
270
outcome = yield
@@ -295,39 +377,43 @@ async def setup():
295
377
296
378
297
379
@pytest .hookimpl (tryfirst = True , hookwrapper = True )
298
- def pytest_pyfunc_call (pyfuncitem ) :
380
+ def pytest_pyfunc_call (pyfuncitem : pytest . Function ) -> Optional [ object ] :
299
381
"""
300
382
Pytest hook called before a test case is run.
301
383
302
384
Wraps marked tests in a synchronous function
303
385
where the wrapped test coroutine is executed in an event loop.
304
386
"""
305
387
if "asyncio" in pyfuncitem .keywords :
388
+ funcargs : Dict [str , object ] = pyfuncitem .funcargs # type: ignore[name-defined]
389
+ loop = cast (asyncio .AbstractEventLoop , funcargs ["event_loop" ])
306
390
if _is_hypothesis_test (pyfuncitem .obj ):
307
391
pyfuncitem .obj .hypothesis .inner_test = wrap_in_sync (
308
392
pyfuncitem .obj .hypothesis .inner_test ,
309
- _loop = pyfuncitem . funcargs [ "event_loop" ] ,
393
+ _loop = loop ,
310
394
)
311
395
else :
312
396
pyfuncitem .obj = wrap_in_sync (
313
- pyfuncitem .obj , _loop = pyfuncitem .funcargs ["event_loop" ]
397
+ pyfuncitem .obj ,
398
+ _loop = loop ,
314
399
)
315
400
yield
316
401
317
402
318
- def _is_hypothesis_test (function ) -> bool :
403
+ def _is_hypothesis_test (function : Any ) -> bool :
319
404
return getattr (function , "is_hypothesis_test" , False )
320
405
321
406
322
- def wrap_in_sync (func , _loop ):
407
+ def wrap_in_sync (func : Callable [..., Awaitable [ Any ]], _loop : asyncio . AbstractEventLoop ):
323
408
"""Return a sync wrapper around an async function executing it in the
324
409
current event loop."""
325
410
326
411
# if the function is already wrapped, we rewrap using the original one
327
412
# not using __wrapped__ because the original function may already be
328
413
# a wrapped one
329
- if hasattr (func , "_raw_test_func" ):
330
- func = func ._raw_test_func
414
+ raw_func = getattr (func , "_raw_test_func" , None )
415
+ if raw_func is not None :
416
+ func = raw_func
331
417
332
418
@functools .wraps (func )
333
419
def inner (** kwargs ):
@@ -344,20 +430,22 @@ def inner(**kwargs):
344
430
task .exception ()
345
431
raise
346
432
347
- inner ._raw_test_func = func
433
+ inner ._raw_test_func = func # type: ignore[attr-defined]
348
434
return inner
349
435
350
436
351
- def pytest_runtest_setup (item ) :
437
+ def pytest_runtest_setup (item : pytest . Item ) -> None :
352
438
if "asyncio" in item .keywords :
439
+ fixturenames = item .fixturenames # type: ignore[attr-defined]
353
440
# inject an event loop fixture for all async tests
354
- if "event_loop" in item .fixturenames :
355
- item .fixturenames .remove ("event_loop" )
356
- item .fixturenames .insert (0 , "event_loop" )
441
+ if "event_loop" in fixturenames :
442
+ fixturenames .remove ("event_loop" )
443
+ fixturenames .insert (0 , "event_loop" )
444
+ obj = item .obj # type: ignore[attr-defined]
357
445
if (
358
446
item .get_closest_marker ("asyncio" ) is not None
359
- and not getattr (item . obj , "hypothesis" , False )
360
- and getattr (item . obj , "is_hypothesis_test" , False )
447
+ and not getattr (obj , "hypothesis" , False )
448
+ and getattr (obj , "is_hypothesis_test" , False )
361
449
):
362
450
pytest .fail (
363
451
"test function `%r` is using Hypothesis, but pytest-asyncio "
@@ -366,32 +454,32 @@ def pytest_runtest_setup(item):
366
454
367
455
368
456
@pytest .fixture
369
- def event_loop (request ) :
457
+ def event_loop (request : pytest . FixtureRequest ) -> Iterator [ asyncio . AbstractEventLoop ] :
370
458
"""Create an instance of the default event loop for each test case."""
371
459
loop = asyncio .get_event_loop_policy ().new_event_loop ()
372
460
yield loop
373
461
loop .close ()
374
462
375
463
376
- def _unused_port (socket_type ) :
464
+ def _unused_port (socket_type : int ) -> int :
377
465
"""Find an unused localhost port from 1024-65535 and return it."""
378
466
with contextlib .closing (socket .socket (type = socket_type )) as sock :
379
467
sock .bind (("127.0.0.1" , 0 ))
380
468
return sock .getsockname ()[1 ]
381
469
382
470
383
471
@pytest .fixture
384
- def unused_tcp_port ():
472
+ def unused_tcp_port () -> int :
385
473
return _unused_port (socket .SOCK_STREAM )
386
474
387
475
388
476
@pytest .fixture
389
- def unused_udp_port ():
477
+ def unused_udp_port () -> int :
390
478
return _unused_port (socket .SOCK_DGRAM )
391
479
392
480
393
481
@pytest .fixture (scope = "session" )
394
- def unused_tcp_port_factory ():
482
+ def unused_tcp_port_factory () -> Callable [[], int ] :
395
483
"""A factory function, producing different unused TCP ports."""
396
484
produced = set ()
397
485
@@ -410,7 +498,7 @@ def factory():
410
498
411
499
412
500
@pytest .fixture (scope = "session" )
413
- def unused_udp_port_factory ():
501
+ def unused_udp_port_factory () -> Callable [[], int ] :
414
502
"""A factory function, producing different unused UDP ports."""
415
503
produced = set ()
416
504
0 commit comments