Skip to content

Commit 44eb1f5

Browse files
authored
Merge pull request #6311 from bluetech/type-annotations-10
Some type annotation & check_untyped_defs fixes
2 parents 4fb9cc3 + 3392be3 commit 44eb1f5

File tree

8 files changed

+324
-214
lines changed

8 files changed

+324
-214
lines changed

src/_pytest/_code/code.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __ne__(self, other):
6767
return not self == other
6868

6969
@property
70-
def path(self):
70+
def path(self) -> Union[py.path.local, str]:
7171
""" return a path object pointing to source code (note that it
7272
might not point to an actually existing file). """
7373
try:
@@ -335,7 +335,7 @@ def cut(
335335
(path is None or codepath == path)
336336
and (
337337
excludepath is None
338-
or not hasattr(codepath, "relto")
338+
or not isinstance(codepath, py.path.local)
339339
or not codepath.relto(excludepath)
340340
)
341341
and (lineno is None or x.lineno == lineno)

src/_pytest/_code/source.py

+70-12
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import textwrap
66
import tokenize
77
import warnings
8-
from ast import PyCF_ONLY_AST as _AST_FLAG
98
from bisect import bisect_right
9+
from types import CodeType
1010
from types import FrameType
1111
from typing import Iterator
1212
from typing import List
@@ -18,6 +18,10 @@
1818
import py
1919

2020
from _pytest.compat import overload
21+
from _pytest.compat import TYPE_CHECKING
22+
23+
if TYPE_CHECKING:
24+
from typing_extensions import Literal
2125

2226

2327
class Source:
@@ -121,7 +125,7 @@ def getstatement(self, lineno: int) -> "Source":
121125
start, end = self.getstatementrange(lineno)
122126
return self[start:end]
123127

124-
def getstatementrange(self, lineno: int):
128+
def getstatementrange(self, lineno: int) -> Tuple[int, int]:
125129
""" return (start, end) tuple which spans the minimal
126130
statement region which containing the given lineno.
127131
"""
@@ -159,14 +163,36 @@ def isparseable(self, deindent: bool = True) -> bool:
159163
def __str__(self) -> str:
160164
return "\n".join(self.lines)
161165

166+
@overload
162167
def compile(
163168
self,
164-
filename=None,
165-
mode="exec",
169+
filename: Optional[str] = ...,
170+
mode: str = ...,
171+
flag: "Literal[0]" = ...,
172+
dont_inherit: int = ...,
173+
_genframe: Optional[FrameType] = ...,
174+
) -> CodeType:
175+
raise NotImplementedError()
176+
177+
@overload # noqa: F811
178+
def compile( # noqa: F811
179+
self,
180+
filename: Optional[str] = ...,
181+
mode: str = ...,
182+
flag: int = ...,
183+
dont_inherit: int = ...,
184+
_genframe: Optional[FrameType] = ...,
185+
) -> Union[CodeType, ast.AST]:
186+
raise NotImplementedError()
187+
188+
def compile( # noqa: F811
189+
self,
190+
filename: Optional[str] = None,
191+
mode: str = "exec",
166192
flag: int = 0,
167193
dont_inherit: int = 0,
168194
_genframe: Optional[FrameType] = None,
169-
):
195+
) -> Union[CodeType, ast.AST]:
170196
""" return compiled code object. if filename is None
171197
invent an artificial filename which displays
172198
the source/line position of the caller frame.
@@ -196,8 +222,10 @@ def compile(
196222
newex.text = ex.text
197223
raise newex
198224
else:
199-
if flag & _AST_FLAG:
225+
if flag & ast.PyCF_ONLY_AST:
226+
assert isinstance(co, ast.AST)
200227
return co
228+
assert isinstance(co, CodeType)
201229
lines = [(x + "\n") for x in self.lines]
202230
# Type ignored because linecache.cache is private.
203231
linecache.cache[filename] = (1, None, lines, filename) # type: ignore
@@ -209,22 +237,52 @@ def compile(
209237
#
210238

211239

212-
def compile_(source, filename=None, mode="exec", flags: int = 0, dont_inherit: int = 0):
240+
@overload
241+
def compile_(
242+
source: Union[str, bytes, ast.mod, ast.AST],
243+
filename: Optional[str] = ...,
244+
mode: str = ...,
245+
flags: "Literal[0]" = ...,
246+
dont_inherit: int = ...,
247+
) -> CodeType:
248+
raise NotImplementedError()
249+
250+
251+
@overload # noqa: F811
252+
def compile_( # noqa: F811
253+
source: Union[str, bytes, ast.mod, ast.AST],
254+
filename: Optional[str] = ...,
255+
mode: str = ...,
256+
flags: int = ...,
257+
dont_inherit: int = ...,
258+
) -> Union[CodeType, ast.AST]:
259+
raise NotImplementedError()
260+
261+
262+
def compile_( # noqa: F811
263+
source: Union[str, bytes, ast.mod, ast.AST],
264+
filename: Optional[str] = None,
265+
mode: str = "exec",
266+
flags: int = 0,
267+
dont_inherit: int = 0,
268+
) -> Union[CodeType, ast.AST]:
213269
""" compile the given source to a raw code object,
214270
and maintain an internal cache which allows later
215271
retrieval of the source code for the code object
216272
and any recursively created code objects.
217273
"""
218274
if isinstance(source, ast.AST):
219275
# XXX should Source support having AST?
220-
return compile(source, filename, mode, flags, dont_inherit)
276+
assert filename is not None
277+
co = compile(source, filename, mode, flags, dont_inherit)
278+
assert isinstance(co, (CodeType, ast.AST))
279+
return co
221280
_genframe = sys._getframe(1) # the caller
222281
s = Source(source)
223-
co = s.compile(filename, mode, flags, _genframe=_genframe)
224-
return co
282+
return s.compile(filename, mode, flags, _genframe=_genframe)
225283

226284

227-
def getfslineno(obj):
285+
def getfslineno(obj) -> Tuple[Union[str, py.path.local], int]:
228286
""" Return source location (path, lineno) for the given object.
229287
If the source cannot be determined return ("", -1).
230288
@@ -321,7 +379,7 @@ def getstatementrange_ast(
321379
# don't produce duplicate warnings when compiling source to find ast
322380
with warnings.catch_warnings():
323381
warnings.simplefilter("ignore")
324-
astnode = compile(content, "source", "exec", _AST_FLAG)
382+
astnode = ast.parse(content, "source", "exec")
325383

326384
start, end = get_statement_startend2(lineno, astnode)
327385
# we need to correct the end:

src/_pytest/recwarn.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def deprecated_call(func=None, *args, **kwargs):
5757

5858
@overload
5959
def warns(
60-
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
60+
expected_warning: Optional[Union["Type[Warning]", Tuple["Type[Warning]", ...]]],
6161
*,
6262
match: "Optional[Union[str, Pattern]]" = ...
6363
) -> "WarningsChecker":
@@ -66,7 +66,7 @@ def warns(
6666

6767
@overload # noqa: F811
6868
def warns( # noqa: F811
69-
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
69+
expected_warning: Optional[Union["Type[Warning]", Tuple["Type[Warning]", ...]]],
7070
func: Callable,
7171
*args: Any,
7272
match: Optional[Union[str, "Pattern"]] = ...,
@@ -76,7 +76,7 @@ def warns( # noqa: F811
7676

7777

7878
def warns( # noqa: F811
79-
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
79+
expected_warning: Optional[Union["Type[Warning]", Tuple["Type[Warning]", ...]]],
8080
*args: Any,
8181
match: Optional[Union[str, "Pattern"]] = None,
8282
**kwargs: Any

src/_pytest/reports.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from io import StringIO
22
from pprint import pprint
3+
from typing import Any
34
from typing import List
45
from typing import Optional
56
from typing import Tuple
@@ -17,6 +18,7 @@
1718
from _pytest._code.code import ReprLocals
1819
from _pytest._code.code import ReprTraceback
1920
from _pytest._code.code import TerminalRepr
21+
from _pytest.compat import TYPE_CHECKING
2022
from _pytest.nodes import Node
2123
from _pytest.outcomes import skip
2224
from _pytest.pathlib import Path
@@ -41,9 +43,14 @@ class BaseReport:
4143
sections = [] # type: List[Tuple[str, str]]
4244
nodeid = None # type: str
4345

44-
def __init__(self, **kw):
46+
def __init__(self, **kw: Any) -> None:
4547
self.__dict__.update(kw)
4648

49+
if TYPE_CHECKING:
50+
# Can have arbitrary fields given to __init__().
51+
def __getattr__(self, key: str) -> Any:
52+
raise NotImplementedError()
53+
4754
def toterminal(self, out) -> None:
4855
if hasattr(self, "node"):
4956
out.line(getslaveinfoline(self.node)) # type: ignore

testing/code/test_source.py

+16
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
import ast
55
import inspect
66
import sys
7+
from types import CodeType
78
from typing import Any
89
from typing import Dict
910
from typing import Optional
1011

12+
import py
13+
1114
import _pytest._code
1215
import pytest
1316
from _pytest._code import Source
@@ -147,6 +150,10 @@ def test_getrange(self) -> None:
147150
assert len(x.lines) == 2
148151
assert str(x) == "def f(x):\n pass"
149152

153+
def test_getrange_step_not_supported(self) -> None:
154+
with pytest.raises(IndexError, match=r"step"):
155+
self.source[::2]
156+
150157
def test_getline(self) -> None:
151158
x = self.source[0]
152159
assert x == "def f(x):"
@@ -449,6 +456,14 @@ def test_idem_compile_and_getsource() -> None:
449456
assert src == expected
450457

451458

459+
def test_compile_ast() -> None:
460+
# We don't necessarily want to support this.
461+
# This test was added just for coverage.
462+
stmt = ast.parse("def x(): pass")
463+
co = _pytest._code.compile(stmt, filename="foo.py")
464+
assert isinstance(co, CodeType)
465+
466+
452467
def test_findsource_fallback() -> None:
453468
from _pytest._code.source import findsource
454469

@@ -488,6 +503,7 @@ def f(x) -> None:
488503

489504
fspath, lineno = getfslineno(f)
490505

506+
assert isinstance(fspath, py.path.local)
491507
assert fspath.basename == "test_source.py"
492508
assert lineno == f.__code__.co_firstlineno - 1 # see findsource
493509

0 commit comments

Comments
 (0)