6
6
import inspect
7
7
import socket
8
8
import warnings
9
+ from typing import (
10
+ Any ,
11
+ AsyncIterator ,
12
+ Awaitable ,
13
+ Callable ,
14
+ Dict ,
15
+ Iterable ,
16
+ Iterator ,
17
+ List ,
18
+ Optional ,
19
+ Set ,
20
+ TypeVar ,
21
+ Union ,
22
+ cast ,
23
+ overload ,
24
+ )
9
25
10
26
import pytest
27
+ from typing_extensions import Literal
28
+
29
+ _R = TypeVar ("_R" )
30
+
31
+ _ScopeName = Literal ["session" , "package" , "module" , "class" , "function" ]
32
+ _T = TypeVar ("_T" )
33
+
34
+ SimpleFixtureFunction = TypeVar (
35
+ "SimpleFixtureFunction" , bound = Callable [..., Awaitable [_R ]]
36
+ )
37
+ FactoryFixtureFunction = TypeVar (
38
+ "FactoryFixtureFunction" , bound = Callable [..., AsyncIterator [_R ]]
39
+ )
40
+ FixtureFunction = Union [SimpleFixtureFunction , FactoryFixtureFunction ]
41
+ FixtureFunctionMarker = Callable [[FixtureFunction ], FixtureFunction ]
42
+
43
+ Config = Any # pytest < 7.0
44
+ PytestPluginManager = Any # pytest < 7.0
45
+ FixtureDef = Any # pytest < 7.0
46
+ Parser = Any # pytest < 7.0
47
+ SubRequest = Any # pytest < 7.0
11
48
12
49
13
50
class Mode (str , enum .Enum ):
@@ -41,7 +78,7 @@ class Mode(str, enum.Enum):
41
78
"""
42
79
43
80
44
- def pytest_addoption (parser , pluginmanager ) :
81
+ def pytest_addoption (parser : Parser , pluginmanager : PytestPluginManager ) -> None :
45
82
group = parser .getgroup ("asyncio" )
46
83
group .addoption (
47
84
"--asyncio-mode" ,
@@ -57,49 +94,87 @@ def pytest_addoption(parser, pluginmanager):
57
94
)
58
95
59
96
60
- def fixture (fixture_function = None , ** kwargs ):
97
+ @overload
98
+ def fixture (
99
+ fixture_function : FixtureFunction ,
100
+ * ,
101
+ scope : "Union[_ScopeName, Callable[[str, Config], _ScopeName]]" = ...,
102
+ params : Optional [Iterable [object ]] = ...,
103
+ autouse : bool = ...,
104
+ ids : Optional [
105
+ Union [
106
+ Iterable [Union [None , str , float , int , bool ]],
107
+ Callable [[Any ], Optional [object ]],
108
+ ]
109
+ ] = ...,
110
+ name : Optional [str ] = ...,
111
+ ) -> FixtureFunction :
112
+ ...
113
+
114
+
115
+ @overload
116
+ def fixture (
117
+ fixture_function : None = ...,
118
+ * ,
119
+ scope : "Union[_ScopeName, Callable[[str, Config], _ScopeName]]" = ...,
120
+ params : Optional [Iterable [object ]] = ...,
121
+ autouse : bool = ...,
122
+ ids : Optional [
123
+ Union [
124
+ Iterable [Union [None , str , float , int , bool ]],
125
+ Callable [[Any ], Optional [object ]],
126
+ ]
127
+ ] = ...,
128
+ name : Optional [str ] = None ,
129
+ ) -> FixtureFunctionMarker :
130
+ ...
131
+
132
+
133
+ def fixture (
134
+ fixture_function : Optional [FixtureFunction ] = None , ** kwargs : Any
135
+ ) -> Union [FixtureFunction , FixtureFunctionMarker ]:
61
136
if fixture_function is not None :
62
137
_set_explicit_asyncio_mark (fixture_function )
63
138
return pytest .fixture (fixture_function , ** kwargs )
64
139
65
140
else :
66
141
67
142
@functools .wraps (fixture )
68
- def inner (fixture_function ) :
143
+ def inner (fixture_function : FixtureFunction ) -> FixtureFunction :
69
144
return fixture (fixture_function , ** kwargs )
70
145
71
146
return inner
72
147
73
148
74
- def _has_explicit_asyncio_mark (obj ) :
149
+ def _has_explicit_asyncio_mark (obj : Any ) -> bool :
75
150
obj = getattr (obj , "__func__" , obj ) # instance method maybe?
76
151
return getattr (obj , "_force_asyncio_fixture" , False )
77
152
78
153
79
- def _set_explicit_asyncio_mark (obj ) :
154
+ def _set_explicit_asyncio_mark (obj : Any ) -> None :
80
155
if hasattr (obj , "__func__" ):
81
156
# instance method, check the function object
82
157
obj = obj .__func__
83
158
obj ._force_asyncio_fixture = True
84
159
85
160
86
- def _is_coroutine (obj ) :
161
+ def _is_coroutine (obj : Any ) -> bool :
87
162
"""Check to see if an object is really an asyncio coroutine."""
88
163
return asyncio .iscoroutinefunction (obj ) or inspect .isgeneratorfunction (obj )
89
164
90
165
91
- def _is_coroutine_or_asyncgen (obj ) :
166
+ def _is_coroutine_or_asyncgen (obj : Any ) -> bool :
92
167
return _is_coroutine (obj ) or inspect .isasyncgenfunction (obj )
93
168
94
169
95
- def _get_asyncio_mode (config ) :
170
+ def _get_asyncio_mode (config : Config ) -> Mode :
96
171
val = config .getoption ("asyncio_mode" )
97
172
if val is None :
98
173
val = config .getini ("asyncio_mode" )
99
174
return Mode (val )
100
175
101
176
102
- def pytest_configure (config ) :
177
+ def pytest_configure (config : Config ) -> None :
103
178
"""Inject documentation."""
104
179
config .addinivalue_line (
105
180
"markers" ,
@@ -112,10 +187,14 @@ def pytest_configure(config):
112
187
113
188
114
189
@pytest .mark .tryfirst
115
- def pytest_pycollect_makeitem (collector , name , obj ):
190
+ def pytest_pycollect_makeitem (
191
+ collector : Union [pytest .Module , pytest .Class ], name : str , obj : object
192
+ ) -> Union [
193
+ None , pytest .Item , pytest .Collector , List [Union [pytest .Item , pytest .Collector ]]
194
+ ]:
116
195
"""A pytest hook to collect asyncio coroutines."""
117
196
if not collector .funcnamefilter (name ):
118
- return
197
+ return None
119
198
if (
120
199
_is_coroutine (obj )
121
200
or _is_hypothesis_test (obj )
@@ -130,10 +209,11 @@ def pytest_pycollect_makeitem(collector, name, obj):
130
209
ret = list (collector ._genfunctions (name , obj ))
131
210
for elem in ret :
132
211
elem .add_marker ("asyncio" )
133
- return ret
212
+ return ret # type: ignore[return-value]
213
+ return None
134
214
135
215
136
- def _hypothesis_test_wraps_coroutine (function ) :
216
+ def _hypothesis_test_wraps_coroutine (function : Any ) -> bool :
137
217
return _is_coroutine (function .hypothesis .inner_test )
138
218
139
219
@@ -143,19 +223,19 @@ class FixtureStripper:
143
223
REQUEST = "request"
144
224
EVENT_LOOP = "event_loop"
145
225
146
- def __init__ (self , fixturedef ) :
226
+ def __init__ (self , fixturedef : FixtureDef ) -> None :
147
227
self .fixturedef = fixturedef
148
- self .to_strip = set ()
228
+ self .to_strip : Set [ str ] = set ()
149
229
150
- def add (self , name ) :
230
+ def add (self , name : str ) -> None :
151
231
"""Add fixture name to fixturedef
152
232
and record in to_strip list (If not previously included)"""
153
233
if name in self .fixturedef .argnames :
154
234
return
155
235
self .fixturedef .argnames += (name ,)
156
236
self .to_strip .add (name )
157
237
158
- def get_and_strip_from (self , name , data_dict ) :
238
+ def get_and_strip_from (self , name : str , data_dict : Dict [ str , _T ]) -> _T :
159
239
"""Strip name from data, and return value"""
160
240
result = data_dict [name ]
161
241
if name in self .to_strip :
@@ -164,7 +244,7 @@ def get_and_strip_from(self, name, data_dict):
164
244
165
245
166
246
@pytest .hookimpl (trylast = True )
167
- def pytest_fixture_post_finalizer (fixturedef , request ) :
247
+ def pytest_fixture_post_finalizer (fixturedef : FixtureDef , request : SubRequest ) -> None :
168
248
"""Called after fixture teardown"""
169
249
if fixturedef .argname == "event_loop" :
170
250
policy = asyncio .get_event_loop_policy ()
@@ -181,7 +261,9 @@ def pytest_fixture_post_finalizer(fixturedef, request):
181
261
182
262
183
263
@pytest .hookimpl (hookwrapper = True )
184
- def pytest_fixture_setup (fixturedef , request ):
264
+ def pytest_fixture_setup (
265
+ fixturedef : FixtureDef , request : SubRequest
266
+ ) -> Optional [object ]:
185
267
"""Adjust the event loop policy when an event loop is produced."""
186
268
if fixturedef .argname == "event_loop" :
187
269
outcome = yield
@@ -294,39 +376,43 @@ async def setup():
294
376
295
377
296
378
@pytest .hookimpl (tryfirst = True , hookwrapper = True )
297
- def pytest_pyfunc_call (pyfuncitem ) :
379
+ def pytest_pyfunc_call (pyfuncitem : pytest . Function ) -> Optional [ object ] :
298
380
"""
299
381
Pytest hook called before a test case is run.
300
382
301
383
Wraps marked tests in a synchronous function
302
384
where the wrapped test coroutine is executed in an event loop.
303
385
"""
304
386
if "asyncio" in pyfuncitem .keywords :
387
+ funcargs : Dict [str , object ] = pyfuncitem .funcargs # type: ignore[name-defined]
388
+ loop = cast (asyncio .AbstractEventLoop , funcargs ["event_loop" ])
305
389
if _is_hypothesis_test (pyfuncitem .obj ):
306
390
pyfuncitem .obj .hypothesis .inner_test = wrap_in_sync (
307
391
pyfuncitem .obj .hypothesis .inner_test ,
308
- _loop = pyfuncitem . funcargs [ "event_loop" ] ,
392
+ _loop = loop ,
309
393
)
310
394
else :
311
395
pyfuncitem .obj = wrap_in_sync (
312
- pyfuncitem .obj , _loop = pyfuncitem .funcargs ["event_loop" ]
396
+ pyfuncitem .obj ,
397
+ _loop = loop ,
313
398
)
314
399
yield
315
400
316
401
317
- def _is_hypothesis_test (function ) -> bool :
402
+ def _is_hypothesis_test (function : Any ) -> bool :
318
403
return getattr (function , "is_hypothesis_test" , False )
319
404
320
405
321
- def wrap_in_sync (func , _loop ):
406
+ def wrap_in_sync (func : Callable [..., Awaitable [ Any ]], _loop : asyncio . AbstractEventLoop ):
322
407
"""Return a sync wrapper around an async function executing it in the
323
408
current event loop."""
324
409
325
410
# if the function is already wrapped, we rewrap using the original one
326
411
# not using __wrapped__ because the original function may already be
327
412
# a wrapped one
328
- if hasattr (func , "_raw_test_func" ):
329
- func = func ._raw_test_func
413
+ raw_func = getattr (func , "_raw_test_func" , None )
414
+ if raw_func is not None :
415
+ func = raw_func
330
416
331
417
@functools .wraps (func )
332
418
def inner (** kwargs ):
@@ -343,20 +429,22 @@ def inner(**kwargs):
343
429
task .exception ()
344
430
raise
345
431
346
- inner ._raw_test_func = func
432
+ inner ._raw_test_func = func # type: ignore[attr-defined]
347
433
return inner
348
434
349
435
350
- def pytest_runtest_setup (item ) :
436
+ def pytest_runtest_setup (item : pytest . Item ) -> None :
351
437
if "asyncio" in item .keywords :
438
+ fixturenames = item .fixturenames # type: ignore[attr-defined]
352
439
# inject an event loop fixture for all async tests
353
- if "event_loop" in item .fixturenames :
354
- item .fixturenames .remove ("event_loop" )
355
- item .fixturenames .insert (0 , "event_loop" )
440
+ if "event_loop" in fixturenames :
441
+ fixturenames .remove ("event_loop" )
442
+ fixturenames .insert (0 , "event_loop" )
443
+ obj = item .obj # type: ignore[attr-defined]
356
444
if (
357
445
item .get_closest_marker ("asyncio" ) is not None
358
- and not getattr (item . obj , "hypothesis" , False )
359
- and getattr (item . obj , "is_hypothesis_test" , False )
446
+ and not getattr (obj , "hypothesis" , False )
447
+ and getattr (obj , "is_hypothesis_test" , False )
360
448
):
361
449
pytest .fail (
362
450
"test function `%r` is using Hypothesis, but pytest-asyncio "
@@ -365,32 +453,32 @@ def pytest_runtest_setup(item):
365
453
366
454
367
455
@pytest .fixture
368
- def event_loop (request ) :
456
+ def event_loop (request : pytest . FixtureRequest ) -> Iterator [ asyncio . AbstractEventLoop ] :
369
457
"""Create an instance of the default event loop for each test case."""
370
458
loop = asyncio .get_event_loop_policy ().new_event_loop ()
371
459
yield loop
372
460
loop .close ()
373
461
374
462
375
- def _unused_port (socket_type ) :
463
+ def _unused_port (socket_type : int ) -> int :
376
464
"""Find an unused localhost port from 1024-65535 and return it."""
377
465
with contextlib .closing (socket .socket (type = socket_type )) as sock :
378
466
sock .bind (("127.0.0.1" , 0 ))
379
467
return sock .getsockname ()[1 ]
380
468
381
469
382
470
@pytest .fixture
383
- def unused_tcp_port ():
471
+ def unused_tcp_port () -> int :
384
472
return _unused_port (socket .SOCK_STREAM )
385
473
386
474
387
475
@pytest .fixture
388
- def unused_udp_port ():
476
+ def unused_udp_port () -> int :
389
477
return _unused_port (socket .SOCK_DGRAM )
390
478
391
479
392
480
@pytest .fixture (scope = "session" )
393
- def unused_tcp_port_factory ():
481
+ def unused_tcp_port_factory () -> Callable [[], int ] :
394
482
"""A factory function, producing different unused TCP ports."""
395
483
produced = set ()
396
484
@@ -409,7 +497,7 @@ def factory():
409
497
410
498
411
499
@pytest .fixture (scope = "session" )
412
- def unused_udp_port_factory ():
500
+ def unused_udp_port_factory () -> Callable [[], int ] :
413
501
"""A factory function, producing different unused UDP ports."""
414
502
produced = set ()
415
503
0 commit comments