diff --git a/docs/source/how-to-guides/index.rst b/docs/source/how-to-guides/index.rst index 5bcb3be7..71567aaf 100644 --- a/docs/source/how-to-guides/index.rst +++ b/docs/source/how-to-guides/index.rst @@ -7,5 +7,6 @@ How-To Guides multiple_loops uvloop + test_item_is_async This section of the documentation provides code snippets and recipes to accomplish specific tasks with pytest-asyncio. diff --git a/docs/source/how-to-guides/test_item_is_async.rst b/docs/source/how-to-guides/test_item_is_async.rst new file mode 100644 index 00000000..a9ea5d40 --- /dev/null +++ b/docs/source/how-to-guides/test_item_is_async.rst @@ -0,0 +1,7 @@ +======================================= +How to tell if a test function is async +======================================= +Use ``pytest_asyncio.is_async_item`` to determine if a test item is asynchronous and managed by pytest-asyncio. + +.. include:: test_item_is_async_example.py + :code: python diff --git a/docs/source/how-to-guides/test_item_is_async_example.py b/docs/source/how-to-guides/test_item_is_async_example.py new file mode 100644 index 00000000..31b44193 --- /dev/null +++ b/docs/source/how-to-guides/test_item_is_async_example.py @@ -0,0 +1,7 @@ +from pytest_asyncio import is_async_test + + +def pytest_collection_modifyitems(items): + for item in items: + if is_async_test(item): + pass diff --git a/docs/source/reference/changelog.rst b/docs/source/reference/changelog.rst index d902ff06..504c58f7 100644 --- a/docs/source/reference/changelog.rst +++ b/docs/source/reference/changelog.rst @@ -9,6 +9,7 @@ Changes are non-breaking, unless you upgrade from v0.22. - BREAKING: The *asyncio_event_loop* mark has been removed. Event loops with class, module, package, and session scopes can be requested via the *scope* keyword argument to the _asyncio_ mark. - Introduces the *event_loop_policy* fixture which allows testing with non-default or multiple event loops `#662 `_ +- Introduces ``pytest_asyncio.is_async_test`` which returns whether a test item is managed by pytest-asyncio `#376 `_ - Removes pytest-trio from the test dependencies `#620 `_ 0.22.0 (2023-10-31) diff --git a/docs/source/reference/functions.rst b/docs/source/reference/functions.rst new file mode 100644 index 00000000..fcd531c2 --- /dev/null +++ b/docs/source/reference/functions.rst @@ -0,0 +1,9 @@ +========= +Functions +========= + +is_async_test +============= +Returns whether a specific pytest Item is an asynchronous test managed by pytest-asyncio. + +This function is intended to be used in pytest hooks or by plugins that depend on pytest-asyncio. diff --git a/docs/source/reference/index.rst b/docs/source/reference/index.rst index 5fdc2724..b24c6e9c 100644 --- a/docs/source/reference/index.rst +++ b/docs/source/reference/index.rst @@ -7,6 +7,7 @@ Reference configuration fixtures/index + functions markers/index decorators/index changelog diff --git a/pytest_asyncio/__init__.py b/pytest_asyncio/__init__.py index 1bc2811d..95046981 100644 --- a/pytest_asyncio/__init__.py +++ b/pytest_asyncio/__init__.py @@ -1,5 +1,5 @@ """The main point for importing pytest-asyncio items.""" from ._version import version as __version__ # noqa -from .plugin import fixture +from .plugin import fixture, is_async_test -__all__ = ("fixture",) +__all__ = ("fixture", "is_async_test") diff --git a/pytest_asyncio/plugin.py b/pytest_asyncio/plugin.py index 4f9ed217..892d8237 100644 --- a/pytest_asyncio/plugin.py +++ b/pytest_asyncio/plugin.py @@ -21,6 +21,7 @@ Literal, Optional, Set, + Type, TypeVar, Union, overload, @@ -365,18 +366,19 @@ class PytestAsyncioFunction(Function): """Base class for all test functions managed by pytest-asyncio.""" @classmethod - def substitute(cls, item: Function, /) -> Function: + def item_subclass_for( + cls, item: Function, / + ) -> Union[Type["PytestAsyncioFunction"], None]: """ - Returns a PytestAsyncioFunction if there is an implementation that can handle - the specified function item. + Returns a subclass of PytestAsyncioFunction if there is a specialized subclass + for the specified function item. - If no implementation of PytestAsyncioFunction can handle the specified item, - the item is returned unchanged. + Return None if no specialized subclass exists for the specified item. """ for subclass in cls.__subclasses__(): if subclass._can_substitute(item): - return subclass._from_function(item) - return item + return subclass + return None @classmethod def _from_function(cls, function: Function, /) -> Function: @@ -384,6 +386,7 @@ def _from_function(cls, function: Function, /) -> Function: Instantiates this specific PytestAsyncioFunction type from the specified Function item. """ + assert function.get_closest_marker("asyncio") subclass_instance = cls.from_parent( function.parent, name=function.name, @@ -393,6 +396,7 @@ def _from_function(cls, function: Function, /) -> Function: keywords=function.keywords, originalname=function.originalname, ) + subclass_instance.own_markers.extend(function.own_markers) subclassed_function_signature = inspect.signature(subclass_instance.obj) if "event_loop" in subclassed_function_signature.parameters: subclass_instance.warn( @@ -419,11 +423,10 @@ def _can_substitute(item: Function) -> bool: return asyncio.iscoroutinefunction(func) def runtest(self) -> None: - if self.get_closest_marker("asyncio"): - self.obj = wrap_in_sync( - # https://github.com/pytest-dev/pytest-asyncio/issues/596 - self.obj, # type: ignore[has-type] - ) + self.obj = wrap_in_sync( + # https://github.com/pytest-dev/pytest-asyncio/issues/596 + self.obj, # type: ignore[has-type] + ) super().runtest() @@ -463,11 +466,10 @@ def _can_substitute(item: Function) -> bool: ) def runtest(self) -> None: - if self.get_closest_marker("asyncio"): - self.obj = wrap_in_sync( - # https://github.com/pytest-dev/pytest-asyncio/issues/596 - self.obj, # type: ignore[has-type] - ) + self.obj = wrap_in_sync( + # https://github.com/pytest-dev/pytest-asyncio/issues/596 + self.obj, # type: ignore[has-type] + ) super().runtest() @@ -485,10 +487,9 @@ def _can_substitute(item: Function) -> bool: ) and asyncio.iscoroutinefunction(func.hypothesis.inner_test) def runtest(self) -> None: - if self.get_closest_marker("asyncio"): - self.obj.hypothesis.inner_test = wrap_in_sync( - self.obj.hypothesis.inner_test, - ) + self.obj.hypothesis.inner_test = wrap_in_sync( + self.obj.hypothesis.inner_test, + ) super().runtest() @@ -535,9 +536,15 @@ def pytest_pycollect_makeitem_convert_async_functions_to_subclass( for node in node_iterator: updated_item = node if isinstance(node, Function): - updated_item = PytestAsyncioFunction.substitute(node) + specialized_item_class = PytestAsyncioFunction.item_subclass_for(node) + if specialized_item_class: + if _get_asyncio_mode( + node.config + ) == Mode.AUTO and not node.get_closest_marker("asyncio"): + node.add_marker("asyncio") + if node.get_closest_marker("asyncio"): + updated_item = specialized_item_class._from_function(node) updated_node_collection.append(updated_item) - hook_result.force_result(updated_node_collection) @@ -644,28 +651,6 @@ def _temporary_event_loop_policy(policy: AbstractEventLoopPolicy) -> Iterator[No asyncio.set_event_loop(old_loop) -def pytest_collection_modifyitems( - session: Session, config: Config, items: List[Item] -) -> None: - """ - Marks collected async test items as `asyncio` tests. - - The mark is only applied in `AUTO` mode. It is applied to: - - - coroutines and async generators - - Hypothesis tests wrapping coroutines - - staticmethods wrapping coroutines - - """ - if _get_asyncio_mode(config) != Mode.AUTO: - return - for item in items: - if isinstance(item, PytestAsyncioFunction) and not item.get_closest_marker( - "asyncio" - ): - item.add_marker("asyncio") - - _REDEFINED_EVENT_LOOP_FIXTURE_WARNING = dedent( """\ The event_loop fixture provided by pytest-asyncio has been redefined in @@ -978,6 +963,11 @@ def event_loop_policy() -> AbstractEventLoopPolicy: return asyncio.get_event_loop_policy() +def is_async_test(item: Item) -> bool: + """Returns whether a test item is a pytest-asyncio test""" + return isinstance(item, PytestAsyncioFunction) + + def _unused_port(socket_type: int) -> int: """Find an unused localhost port from 1024-65535 and return it.""" with contextlib.closing(socket.socket(type=socket_type)) as sock: diff --git a/tests/test_is_async_test.py b/tests/test_is_async_test.py new file mode 100644 index 00000000..512243b3 --- /dev/null +++ b/tests/test_is_async_test.py @@ -0,0 +1,105 @@ +from textwrap import dedent + +import pytest +from pytest import Pytester + + +def test_returns_false_for_sync_item(pytester: Pytester): + pytester.makepyfile( + dedent( + """\ + import pytest + import pytest_asyncio + + def test_sync(): + pass + + def pytest_collection_modifyitems(items): + async_tests = [ + item + for item in items + if pytest_asyncio.is_async_test(item) + ] + assert len(async_tests) == 0 + """ + ) + ) + result = pytester.runpytest("--asyncio-mode=strict") + result.assert_outcomes(passed=1) + + +def test_returns_true_for_marked_coroutine_item_in_strict_mode(pytester: Pytester): + pytester.makepyfile( + dedent( + """\ + import pytest + import pytest_asyncio + + @pytest.mark.asyncio + async def test_coro(): + pass + + def pytest_collection_modifyitems(items): + async_tests = [ + item + for item in items + if pytest_asyncio.is_async_test(item) + ] + assert len(async_tests) == 1 + """ + ) + ) + result = pytester.runpytest("--asyncio-mode=strict") + result.assert_outcomes(passed=1) + + +def test_returns_false_for_unmarked_coroutine_item_in_strict_mode(pytester: Pytester): + pytester.makepyfile( + dedent( + """\ + import pytest + import pytest_asyncio + + async def test_coro(): + pass + + def pytest_collection_modifyitems(items): + async_tests = [ + item + for item in items + if pytest_asyncio.is_async_test(item) + ] + assert len(async_tests) == 0 + """ + ) + ) + result = pytester.runpytest("--asyncio-mode=strict") + if pytest.version_tuple < (7, 2): + # Probably related to https://github.com/pytest-dev/pytest/pull/10012 + result.assert_outcomes(failed=1) + else: + result.assert_outcomes(skipped=1) + + +def test_returns_true_for_unmarked_coroutine_item_in_auto_mode(pytester: Pytester): + pytester.makepyfile( + dedent( + """\ + import pytest + import pytest_asyncio + + async def test_coro(): + pass + + def pytest_collection_modifyitems(items): + async_tests = [ + item + for item in items + if pytest_asyncio.is_async_test(item) + ] + assert len(async_tests) == 1 + """ + ) + ) + result = pytester.runpytest("--asyncio-mode=auto") + result.assert_outcomes(passed=1)