diff --git a/pytest_asyncio/plugin.py b/pytest_asyncio/plugin.py index 9f4e10d7..24450bf7 100644 --- a/pytest_asyncio/plugin.py +++ b/pytest_asyncio/plugin.py @@ -51,28 +51,35 @@ def pytest_pycollect_makeitem(collector, name, obj): @pytest.hookimpl(hookwrapper=True) def pytest_fixture_setup(fixturedef, request): """Adjust the event loop policy when an event loop is produced.""" + if fixturedef.argname == "event_loop" and 'asyncio' in request.keywords: + outcome = yield + loop = outcome.get_result() + policy = asyncio.get_event_loop_policy() + try: + old_loop = policy.get_event_loop() + except RuntimeError as exc: + if 'no current event loop' not in str(exc): + raise + old_loop = None + policy.set_event_loop(loop) + fixturedef.addfinalizer(lambda: policy.set_event_loop(old_loop)) + return + if isasyncgenfunction(fixturedef.func): # This is an async generator function. Wrap it accordingly. - f = fixturedef.func + generator = fixturedef.func - strip_event_loop = False - if 'event_loop' not in fixturedef.argnames: - fixturedef.argnames += ('event_loop', ) - strip_event_loop = True strip_request = False if 'request' not in fixturedef.argnames: fixturedef.argnames += ('request', ) strip_request = True def wrapper(*args, **kwargs): - loop = kwargs['event_loop'] request = kwargs['request'] - if strip_event_loop: - del kwargs['event_loop'] if strip_request: del kwargs['request'] - gen_obj = f(*args, **kwargs) + gen_obj = generator(*args, **kwargs) async def setup(): res = await gen_obj.__anext__() @@ -89,118 +96,69 @@ async def async_finalizer(): msg = "Async generator fixture didn't stop." msg += "Yield only once." raise ValueError(msg) - - loop.run_until_complete(async_finalizer()) + asyncio.get_event_loop().run_until_complete(async_finalizer()) request.addfinalizer(finalizer) - - return loop.run_until_complete(setup()) + return asyncio.get_event_loop().run_until_complete(setup()) fixturedef.func = wrapper - elif inspect.iscoroutinefunction(fixturedef.func): - # Just a coroutine, not an async generator. - f = fixturedef.func - - strip_event_loop = False - if 'event_loop' not in fixturedef.argnames: - fixturedef.argnames += ('event_loop', ) - strip_event_loop = True + coro = fixturedef.func def wrapper(*args, **kwargs): - loop = kwargs['event_loop'] - if strip_event_loop: - del kwargs['event_loop'] - async def setup(): - res = await f(*args, **kwargs) + res = await coro(*args, **kwargs) return res - return loop.run_until_complete(setup()) + return asyncio.get_event_loop().run_until_complete(setup()) fixturedef.func = wrapper + yield - outcome = yield - if fixturedef.argname == "event_loop" and 'asyncio' in request.keywords: - loop = outcome.get_result() - for kw in _markers_2_fixtures.keys(): - if kw not in request.keywords: - continue - policy = asyncio.get_event_loop_policy() - try: - old_loop = policy.get_event_loop() - except RuntimeError as exc: - if 'no current event loop' not in str(exc): - raise - old_loop = None - policy.set_event_loop(loop) - fixturedef.addfinalizer(lambda: policy.set_event_loop(old_loop)) - - -@pytest.mark.tryfirst +@pytest.hookimpl(tryfirst=True, hookwrapper=True) def pytest_pyfunc_call(pyfuncitem): """ Run asyncio marked test functions in an event loop instead of a normal function call. """ - for marker_name, fixture_name in _markers_2_fixtures.items(): - if marker_name in pyfuncitem.keywords \ - and not getattr(pyfuncitem.obj, 'is_hypothesis_test', False): - event_loop = pyfuncitem.funcargs[fixture_name] - - funcargs = pyfuncitem.funcargs - testargs = {arg: funcargs[arg] - for arg in pyfuncitem._fixtureinfo.argnames} - - event_loop.run_until_complete( - asyncio.ensure_future( - pyfuncitem.obj(**testargs), loop=event_loop)) - return True + if 'asyncio' in pyfuncitem.keywords: + if getattr(pyfuncitem.obj, 'is_hypothesis_test', False): + pyfuncitem.obj.hypothesis.inner_test = wrap_in_sync( + pyfuncitem.obj.hypothesis.inner_test + ) + else: + pyfuncitem.obj = wrap_in_sync(pyfuncitem.obj) + yield def wrap_in_sync(func): - """Return a sync wrapper around an async function.""" + """Return a sync wrapper around an async function executing it in the + current event loop.""" @functools.wraps(func) def inner(**kwargs): - loop = asyncio.get_event_loop_policy().new_event_loop() - try: - coro = func(**kwargs) - if coro is not None: - future = asyncio.ensure_future(coro, loop=loop) - loop.run_until_complete(future) - finally: - loop.close() + coro = func(**kwargs) + if coro is not None: + future = asyncio.ensure_future(coro) + asyncio.get_event_loop().run_until_complete(future) return inner def pytest_runtest_setup(item): - for marker, fixture in _markers_2_fixtures.items(): - if marker in item.keywords and fixture not in item.fixturenames: - # inject an event loop fixture for all async tests - item.fixturenames.append(fixture) - if item.get_closest_marker("asyncio") is not None: - if hasattr(item.obj, 'hypothesis'): - # If it's a Hypothesis test, we insert the wrap_in_sync decorator - item.obj.hypothesis.inner_test = wrap_in_sync( - item.obj.hypothesis.inner_test - ) - elif getattr(item.obj, 'is_hypothesis_test', False): + if 'asyncio' in item.keywords and 'event_loop' not in item.fixturenames: + # inject an event loop fixture for all async tests + item.fixturenames.append('event_loop') + if item.get_closest_marker("asyncio") is not None \ + and not getattr(item.obj, 'hypothesis', False) \ + and getattr(item.obj, 'is_hypothesis_test', False): pytest.fail( 'test function `%r` is using Hypothesis, but pytest-asyncio ' 'only works with Hypothesis 3.64.0 or later.' % item ) -# maps marker to the name of the event loop fixture that will be available -# to marked test functions -_markers_2_fixtures = { - 'asyncio': 'event_loop', -} - - @pytest.yield_fixture def event_loop(request): """Create an instance of the default event loop for each test case.""" diff --git a/tests/test_hypothesis_integration.py b/tests/test_hypothesis_integration.py index 562f4772..63c6cc74 100644 --- a/tests/test_hypothesis_integration.py +++ b/tests/test_hypothesis_integration.py @@ -1,6 +1,7 @@ """Tests for the Hypothesis integration, which wraps async functions in a sync shim for Hypothesis. """ +import asyncio import pytest @@ -25,3 +26,11 @@ async def test_mark_outer(n): async def test_mark_and_parametrize(x, y): assert x is None assert y in (1, 2) + + +@given(st.integers()) +@pytest.mark.asyncio +async def test_can_use_fixture_provided_event_loop(event_loop, n): + semaphore = asyncio.Semaphore(value=0, loop=event_loop) + event_loop.call_soon(semaphore.release) + await semaphore.acquire()