Skip to content

Commit 640bb2e

Browse files
committed
pycode: Detect @overload decorators
1 parent a59f83b commit 640bb2e

File tree

3 files changed

+115
-0
lines changed

3 files changed

+115
-0
lines changed

sphinx/pycode/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import tokenize
1313
import warnings
1414
from importlib import import_module
15+
from inspect import Signature
1516
from io import StringIO
1617
from os import path
1718
from typing import Any, Dict, IO, List, Tuple, Optional
@@ -145,6 +146,7 @@ def __init__(self, source: IO, modname: str, srcname: str, decoded: bool = False
145146
self.annotations = None # type: Dict[Tuple[str, str], str]
146147
self.attr_docs = None # type: Dict[Tuple[str, str], List[str]]
147148
self.finals = None # type: List[str]
149+
self.overloads = None # type: Dict[str, List[Signature]]
148150
self.tagorder = None # type: Dict[str, int]
149151
self.tags = None # type: Dict[str, Tuple[str, int, int]]
150152

@@ -163,6 +165,7 @@ def parse(self) -> None:
163165

164166
self.annotations = parser.annotations
165167
self.finals = parser.finals
168+
self.overloads = parser.overloads
166169
self.tags = parser.definitions
167170
self.tagorder = parser.deforders
168171
except Exception as exc:

sphinx/pycode/parser.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
import re
1313
import sys
1414
import tokenize
15+
from inspect import Signature
1516
from token import NAME, NEWLINE, INDENT, DEDENT, NUMBER, OP, STRING
1617
from tokenize import COMMENT, NL
1718
from typing import Any, Dict, List, Optional, Tuple
1819

1920
from sphinx.pycode.ast import ast # for py37 or older
2021
from sphinx.pycode.ast import parse, unparse
22+
from sphinx.util.inspect import signature_from_ast
2123

2224

2325
comment_re = re.compile('^\\s*#: ?(.*)\r?\n?$')
@@ -232,8 +234,10 @@ def __init__(self, buffers: List[str], encoding: str) -> None:
232234
self.previous = None # type: ast.AST
233235
self.deforders = {} # type: Dict[str, int]
234236
self.finals = [] # type: List[str]
237+
self.overloads = {} # type: Dict[str, List[Signature]]
235238
self.typing = None # type: str
236239
self.typing_final = None # type: str
240+
self.typing_overload = None # type: str
237241
super().__init__()
238242

239243
def get_qualname_for(self, name: str) -> Optional[List[str]]:
@@ -257,6 +261,12 @@ def add_final_entry(self, name: str) -> None:
257261
if qualname:
258262
self.finals.append(".".join(qualname))
259263

264+
def add_overload_entry(self, func: ast.FunctionDef) -> None:
265+
qualname = self.get_qualname_for(func.name)
266+
if qualname:
267+
overloads = self.overloads.setdefault(".".join(qualname), [])
268+
overloads.append(signature_from_ast(func))
269+
260270
def add_variable_comment(self, name: str, comment: str) -> None:
261271
qualname = self.get_qualname_for(name)
262272
if qualname:
@@ -285,6 +295,22 @@ def is_final(self, decorators: List[ast.expr]) -> bool:
285295

286296
return False
287297

298+
def is_overload(self, decorators: List[ast.expr]) -> bool:
299+
overload = []
300+
if self.typing:
301+
overload.append('%s.overload' % self.typing)
302+
if self.typing_overload:
303+
overload.append(self.typing_overload)
304+
305+
for decorator in decorators:
306+
try:
307+
if unparse(decorator) in overload:
308+
return True
309+
except NotImplementedError:
310+
pass
311+
312+
return False
313+
288314
def get_self(self) -> ast.arg:
289315
"""Returns the name of first argument if in function."""
290316
if self.current_function and self.current_function.args.args:
@@ -310,6 +336,8 @@ def visit_Import(self, node: ast.Import) -> None:
310336
self.typing = name.asname or name.name
311337
elif name.name == 'typing.final':
312338
self.typing_final = name.asname or name.name
339+
elif name.name == 'typing.overload':
340+
self.typing_overload = name.asname or name.name
313341

314342
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
315343
"""Handles Import node and record it to definition orders."""
@@ -318,6 +346,8 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
318346

319347
if node.module == 'typing' and name.name == 'final':
320348
self.typing_final = name.asname or name.name
349+
elif node.module == 'typing' and name.name == 'overload':
350+
self.typing_overload = name.asname or name.name
321351

322352
def visit_Assign(self, node: ast.Assign) -> None:
323353
"""Handles Assign node and pick up a variable comment."""
@@ -417,6 +447,8 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
417447
self.add_entry(node.name) # should be called before setting self.current_function
418448
if self.is_final(node.decorator_list):
419449
self.add_final_entry(node.name)
450+
if self.is_overload(node.decorator_list):
451+
self.add_overload_entry(node)
420452
self.context.append(node.name)
421453
self.current_function = node
422454
for child in node.body:
@@ -518,6 +550,7 @@ def __init__(self, code: str, encoding: str = 'utf-8') -> None:
518550
self.deforders = {} # type: Dict[str, int]
519551
self.definitions = {} # type: Dict[str, Tuple[str, int, int]]
520552
self.finals = [] # type: List[str]
553+
self.overloads = {} # type: Dict[str, List[Signature]]
521554

522555
def parse(self) -> None:
523556
"""Parse the source code."""
@@ -533,6 +566,7 @@ def parse_comments(self) -> None:
533566
self.comments = picker.comments
534567
self.deforders = picker.deforders
535568
self.finals = picker.finals
569+
self.overloads = picker.overloads
536570

537571
def parse_definition(self) -> None:
538572
"""Parse the location of definitions from the code."""

tests/test_pycode_parser.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import pytest
1414

1515
from sphinx.pycode.parser import Parser
16+
from sphinx.util.inspect import signature_from_str
1617

1718

1819
def test_comment_picker_basic():
@@ -452,3 +453,80 @@ def test_typing_final_not_imported():
452453
parser = Parser(source)
453454
parser.parse()
454455
assert parser.finals == []
456+
457+
458+
def test_typing_overload():
459+
source = ('import typing\n'
460+
'\n'
461+
'@typing.overload\n'
462+
'def func(x: int, y: int) -> int: pass\n'
463+
'\n'
464+
'@typing.overload\n'
465+
'def func(x: str, y: str) -> str: pass\n'
466+
'\n'
467+
'def func(x, y): pass\n')
468+
parser = Parser(source)
469+
parser.parse()
470+
assert parser.overloads == {'func': [signature_from_str('(x: int, y: int) -> int'),
471+
signature_from_str('(x: str, y: str) -> str')]}
472+
473+
474+
def test_typing_overload_from_import():
475+
source = ('from typing import overload\n'
476+
'\n'
477+
'@overload\n'
478+
'def func(x: int, y: int) -> int: pass\n'
479+
'\n'
480+
'@overload\n'
481+
'def func(x: str, y: str) -> str: pass\n'
482+
'\n'
483+
'def func(x, y): pass\n')
484+
parser = Parser(source)
485+
parser.parse()
486+
assert parser.overloads == {'func': [signature_from_str('(x: int, y: int) -> int'),
487+
signature_from_str('(x: str, y: str) -> str')]}
488+
489+
490+
def test_typing_overload_import_as():
491+
source = ('import typing as foo\n'
492+
'\n'
493+
'@foo.overload\n'
494+
'def func(x: int, y: int) -> int: pass\n'
495+
'\n'
496+
'@foo.overload\n'
497+
'def func(x: str, y: str) -> str: pass\n'
498+
'\n'
499+
'def func(x, y): pass\n')
500+
parser = Parser(source)
501+
parser.parse()
502+
assert parser.overloads == {'func': [signature_from_str('(x: int, y: int) -> int'),
503+
signature_from_str('(x: str, y: str) -> str')]}
504+
505+
506+
def test_typing_overload_from_import_as():
507+
source = ('from typing import overload as bar\n'
508+
'\n'
509+
'@bar\n'
510+
'def func(x: int, y: int) -> int: pass\n'
511+
'\n'
512+
'@bar\n'
513+
'def func(x: str, y: str) -> str: pass\n'
514+
'\n'
515+
'def func(x, y): pass\n')
516+
parser = Parser(source)
517+
parser.parse()
518+
assert parser.overloads == {'func': [signature_from_str('(x: int, y: int) -> int'),
519+
signature_from_str('(x: str, y: str) -> str')]}
520+
521+
522+
def test_typing_overload_not_imported():
523+
source = ('@typing.final\n'
524+
'def func(x: int, y: int) -> int: pass\n'
525+
'\n'
526+
'@typing.final\n'
527+
'def func(x: str, y: str) -> str: pass\n'
528+
'\n'
529+
'def func(x, y): pass\n')
530+
parser = Parser(source)
531+
parser.parse()
532+
assert parser.overloads == {}

0 commit comments

Comments
 (0)