diff --git a/setup.cfg b/setup.cfg index 16efa184..68b8e2f0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,7 +30,9 @@ install_requires = typing >= 3.5; python_version == "3.4" [options.extras_require] -test = pytest >= 3.1.0 +test = + pytest >= 3.1.0 + typing_extensions >= 3.5 [flake8] max-line-length = 99 diff --git a/sphinx_autodoc_typehints.py b/sphinx_autodoc_typehints.py index c2540fb3..e7f8f4a2 100644 --- a/sphinx_autodoc_typehints.py +++ b/sphinx_autodoc_typehints.py @@ -5,6 +5,11 @@ from sphinx.util import logging from sphinx.util.inspect import Signature +try: + from typing_extensions import Protocol +except ImportError: + Protocol = None + try: from inspect import unwrap except ImportError: @@ -47,7 +52,8 @@ def format_annotation(annotation): if inspect.isclass(getattr(annotation, '__origin__', None)): annotation_cls = annotation.__origin__ try: - if Generic in annotation_cls.mro(): + mro = annotation_cls.mro() + if Generic in mro or (Protocol and Protocol in mro): module = annotation_cls.__module__ except TypeError: pass # annotation_cls was either the "type" object or typing.Type @@ -116,7 +122,8 @@ def format_annotation(annotation): annotation_cls = annotation.__origin__ extra = '' - if Generic in annotation_cls.mro(): + mro = annotation_cls.mro() + if Generic in mro or (Protocol and Protocol in mro): params = (getattr(annotation, '__parameters__', None) or getattr(annotation, '__args__', None)) if params: diff --git a/tests/test_sphinx_autodoc_typehints.py b/tests/test_sphinx_autodoc_typehints.py index dce96866..cba630f8 100644 --- a/tests/test_sphinx_autodoc_typehints.py +++ b/tests/test_sphinx_autodoc_typehints.py @@ -5,6 +5,8 @@ from typing import ( Any, AnyStr, Callable, Dict, Generic, Mapping, Optional, Pattern, Tuple, TypeVar, Union, Type) +from typing_extensions import Protocol + from sphinx_autodoc_typehints import format_annotation, process_docstring T = TypeVar('T') @@ -25,6 +27,14 @@ class C(B[str]): pass +class D(Protocol): + pass + + +class E(Protocol[T]): + pass + + class Slotted: __slots__ = () @@ -76,7 +86,10 @@ class Slotted: (A, ':py:class:`~%s.A`' % __name__), (B, ':py:class:`~%s.B`\\[\\~T]' % __name__), (B[int], ':py:class:`~%s.B`\\[:py:class:`int`]' % __name__), - (C, ':py:class:`~%s.C`' % __name__) + (C, ':py:class:`~%s.C`' % __name__), + (D, ':py:class:`~%s.D`' % __name__), + (E, ':py:class:`~%s.E`\\[\\~T]' % __name__), + (E[int], ':py:class:`~%s.E`\\[:py:class:`int`]' % __name__) ]) def test_format_annotation(annotation, expected_result): result = format_annotation(annotation)