Skip to content

Commit d8efa64

Browse files
authored
Support parametrized event_loop fixture (#278)
1 parent dab3b51 commit d8efa64

File tree

3 files changed

+146
-115
lines changed

3 files changed

+146
-115
lines changed

README.rst

+1
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ Changelog
261261
~~~~~~~~~~~~~~~~~~~
262262

263263
- Raise a warning if @pytest.mark.asyncio is applied to non-async function. `#275 <https://github.com/pytest-dev/pytest-asyncio/issues/275>`_
264+
- Support parametrized ``event_loop`` fixture. `#278 <https://github.com/pytest-dev/pytest-asyncio/issues/278>`_
264265

265266
0.17.2 (22-01-17)
266267
~~~~~~~~~~~~~~~~~~~

pytest_asyncio/plugin.py

+114-115
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def _set_explicit_asyncio_mark(obj: Any) -> None:
165165

166166
def _is_coroutine(obj: Any) -> bool:
167167
"""Check to see if an object is really an asyncio coroutine."""
168-
return asyncio.iscoroutinefunction(obj) or inspect.isgeneratorfunction(obj)
168+
return asyncio.iscoroutinefunction(obj)
169169

170170

171171
def _is_coroutine_or_asyncgen(obj: Any) -> bool:
@@ -198,6 +198,118 @@ def pytest_report_header(config: Config) -> List[str]:
198198
return [f"asyncio: mode={mode}"]
199199

200200

201+
def _preprocess_async_fixtures(config: Config, holder: Set[FixtureDef]) -> None:
202+
asyncio_mode = _get_asyncio_mode(config)
203+
fixturemanager = config.pluginmanager.get_plugin("funcmanage")
204+
for fixtures in fixturemanager._arg2fixturedefs.values():
205+
for fixturedef in fixtures:
206+
if fixturedef is holder:
207+
continue
208+
func = fixturedef.func
209+
if not _is_coroutine_or_asyncgen(func):
210+
# Nothing to do with a regular fixture function
211+
continue
212+
if not _has_explicit_asyncio_mark(func):
213+
if asyncio_mode == Mode.AUTO:
214+
# Enforce asyncio mode if 'auto'
215+
_set_explicit_asyncio_mark(func)
216+
elif asyncio_mode == Mode.LEGACY:
217+
_set_explicit_asyncio_mark(func)
218+
try:
219+
code = func.__code__
220+
except AttributeError:
221+
code = func.__func__.__code__
222+
name = (
223+
f"<fixture {func.__qualname__}, file={code.co_filename}, "
224+
f"line={code.co_firstlineno}>"
225+
)
226+
warnings.warn(
227+
LEGACY_ASYNCIO_FIXTURE.format(name=name),
228+
DeprecationWarning,
229+
)
230+
231+
to_add = []
232+
for name in ("request", "event_loop"):
233+
if name not in fixturedef.argnames:
234+
to_add.append(name)
235+
236+
if to_add:
237+
fixturedef.argnames += tuple(to_add)
238+
239+
if inspect.isasyncgenfunction(func):
240+
fixturedef.func = _wrap_asyncgen(func)
241+
elif inspect.iscoroutinefunction(func):
242+
fixturedef.func = _wrap_async(func)
243+
244+
assert _has_explicit_asyncio_mark(fixturedef.func)
245+
holder.add(fixturedef)
246+
247+
248+
def _add_kwargs(
249+
func: Callable[..., Any],
250+
kwargs: Dict[str, Any],
251+
event_loop: asyncio.AbstractEventLoop,
252+
request: SubRequest,
253+
) -> Dict[str, Any]:
254+
sig = inspect.signature(func)
255+
ret = kwargs.copy()
256+
if "request" in sig.parameters:
257+
ret["request"] = request
258+
if "event_loop" in sig.parameters:
259+
ret["event_loop"] = event_loop
260+
return ret
261+
262+
263+
def _wrap_asyncgen(func: Callable[..., AsyncIterator[_R]]) -> Callable[..., _R]:
264+
@functools.wraps(func)
265+
def _asyncgen_fixture_wrapper(
266+
event_loop: asyncio.AbstractEventLoop, request: SubRequest, **kwargs: Any
267+
) -> _R:
268+
gen_obj = func(**_add_kwargs(func, kwargs, event_loop, request))
269+
270+
async def setup() -> _R:
271+
res = await gen_obj.__anext__()
272+
return res
273+
274+
def finalizer() -> None:
275+
"""Yield again, to finalize."""
276+
277+
async def async_finalizer() -> None:
278+
try:
279+
await gen_obj.__anext__()
280+
except StopAsyncIteration:
281+
pass
282+
else:
283+
msg = "Async generator fixture didn't stop."
284+
msg += "Yield only once."
285+
raise ValueError(msg)
286+
287+
event_loop.run_until_complete(async_finalizer())
288+
289+
result = event_loop.run_until_complete(setup())
290+
request.addfinalizer(finalizer)
291+
return result
292+
293+
return _asyncgen_fixture_wrapper
294+
295+
296+
def _wrap_async(func: Callable[..., Awaitable[_R]]) -> Callable[..., _R]:
297+
@functools.wraps(func)
298+
def _async_fixture_wrapper(
299+
event_loop: asyncio.AbstractEventLoop, request: SubRequest, **kwargs: Any
300+
) -> _R:
301+
async def setup() -> _R:
302+
res = await func(**_add_kwargs(func, kwargs, event_loop, request))
303+
return res
304+
305+
return event_loop.run_until_complete(setup())
306+
307+
return _async_fixture_wrapper
308+
309+
310+
_HOLDER: Set[FixtureDef] = set()
311+
312+
201313
@pytest.mark.tryfirst
202314
def pytest_pycollect_makeitem(
203315
collector: Union[pytest.Module, pytest.Class], name: str, obj: object
@@ -212,6 +324,7 @@ def pytest_pycollect_makeitem(
212324
or _is_hypothesis_test(obj)
213325
and _hypothesis_test_wraps_coroutine(obj)
214326
):
327+
_preprocess_async_fixtures(collector.config, _HOLDER)
215328
item = pytest.Function.from_parent(collector, name=name)
216329
marker = item.get_closest_marker("asyncio")
217330
if marker is not None:
@@ -230,31 +343,6 @@ def _hypothesis_test_wraps_coroutine(function: Any) -> bool:
230343
return _is_coroutine(function.hypothesis.inner_test)
231344

232345

233-
class FixtureStripper:
234-
"""Include additional Fixture, and then strip them"""
235-
236-
EVENT_LOOP = "event_loop"
237-
238-
def __init__(self, fixturedef: FixtureDef) -> None:
239-
self.fixturedef = fixturedef
240-
self.to_strip: Set[str] = set()
241-
242-
def add(self, name: str) -> None:
243-
"""Add fixture name to fixturedef
244-
and record in to_strip list (If not previously included)"""
245-
if name in self.fixturedef.argnames:
246-
return
247-
self.fixturedef.argnames += (name,)
248-
self.to_strip.add(name)
249-
250-
def get_and_strip_from(self, name: str, data_dict: Dict[str, _T]) -> _T:
251-
"""Strip name from data, and return value"""
252-
result = data_dict[name]
253-
if name in self.to_strip:
254-
del data_dict[name]
255-
return result
256-
257-
258346
@pytest.hookimpl(trylast=True)
259347
def pytest_fixture_post_finalizer(fixturedef: FixtureDef, request: SubRequest) -> None:
260348
"""Called after fixture teardown"""
@@ -291,95 +379,6 @@ def pytest_fixture_setup(
291379
policy.set_event_loop(loop)
292380
return
293381

294-
func = fixturedef.func
295-
if not _is_coroutine_or_asyncgen(func):
296-
# Nothing to do with a regular fixture function
297-
yield
298-
return
299-
300-
config = request.node.config
301-
asyncio_mode = _get_asyncio_mode(config)
302-
303-
if not _has_explicit_asyncio_mark(func):
304-
if asyncio_mode == Mode.AUTO:
305-
# Enforce asyncio mode if 'auto'
306-
_set_explicit_asyncio_mark(func)
307-
elif asyncio_mode == Mode.LEGACY:
308-
_set_explicit_asyncio_mark(func)
309-
try:
310-
code = func.__code__
311-
except AttributeError:
312-
code = func.__func__.__code__
313-
name = (
314-
f"<fixture {func.__qualname__}, file={code.co_filename}, "
315-
f"line={code.co_firstlineno}>"
316-
)
317-
warnings.warn(
318-
LEGACY_ASYNCIO_FIXTURE.format(name=name),
319-
DeprecationWarning,
320-
)
321-
else:
322-
# asyncio_mode is STRICT,
323-
# don't handle fixtures that are not explicitly marked
324-
yield
325-
return
326-
327-
if inspect.isasyncgenfunction(func):
328-
# This is an async generator function. Wrap it accordingly.
329-
generator = func
330-
331-
fixture_stripper = FixtureStripper(fixturedef)
332-
fixture_stripper.add(FixtureStripper.EVENT_LOOP)
333-
334-
def wrapper(*args, **kwargs):
335-
loop = fixture_stripper.get_and_strip_from(
336-
FixtureStripper.EVENT_LOOP, kwargs
337-
)
338-
339-
gen_obj = generator(*args, **kwargs)
340-
341-
async def setup():
342-
res = await gen_obj.__anext__()
343-
return res
344-
345-
def finalizer():
346-
"""Yield again, to finalize."""
347-
348-
async def async_finalizer():
349-
try:
350-
await gen_obj.__anext__()
351-
except StopAsyncIteration:
352-
pass
353-
else:
354-
msg = "Async generator fixture didn't stop."
355-
msg += "Yield only once."
356-
raise ValueError(msg)
357-
358-
loop.run_until_complete(async_finalizer())
359-
360-
result = loop.run_until_complete(setup())
361-
request.addfinalizer(finalizer)
362-
return result
363-
364-
fixturedef.func = wrapper
365-
elif inspect.iscoroutinefunction(func):
366-
coro = func
367-
368-
fixture_stripper = FixtureStripper(fixturedef)
369-
fixture_stripper.add(FixtureStripper.EVENT_LOOP)
370-
371-
def wrapper(*args, **kwargs):
372-
loop = fixture_stripper.get_and_strip_from(
373-
FixtureStripper.EVENT_LOOP, kwargs
374-
)
375-
376-
async def setup():
377-
res = await coro(*args, **kwargs)
378-
return res
379-
380-
return loop.run_until_complete(setup())
381-
382-
fixturedef.func = wrapper
383382
yield
384383

385384

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import asyncio
2+
3+
import pytest
4+
5+
TESTS_COUNT = 0
6+
7+
8+
def teardown_module():
9+
# parametrized 2 * 2 times: 2 for 'event_loop' and 2 for 'fix'
10+
assert TESTS_COUNT == 4
11+
12+
13+
@pytest.fixture(scope="module", params=[1, 2])
14+
def event_loop(request):
15+
request.param
16+
loop = asyncio.new_event_loop()
17+
yield loop
18+
loop.close()
19+
20+
21+
@pytest.fixture(params=["a", "b"])
22+
async def fix(request):
23+
await asyncio.sleep(0)
24+
return request.param
25+
26+
27+
@pytest.mark.asyncio
28+
async def test_parametrized_loop(fix):
29+
await asyncio.sleep(0)
30+
global TESTS_COUNT
31+
TESTS_COUNT += 1

0 commit comments

Comments
 (0)