5
5
import textwrap
6
6
import tokenize
7
7
import warnings
8
- from ast import PyCF_ONLY_AST as _AST_FLAG
9
8
from bisect import bisect_right
9
+ from types import CodeType
10
10
from types import FrameType
11
11
from typing import Iterator
12
12
from typing import List
18
18
import py
19
19
20
20
from _pytest .compat import overload
21
+ from _pytest .compat import TYPE_CHECKING
22
+
23
+ if TYPE_CHECKING :
24
+ from typing_extensions import Literal
21
25
22
26
23
27
class Source :
@@ -121,7 +125,7 @@ def getstatement(self, lineno: int) -> "Source":
121
125
start , end = self .getstatementrange (lineno )
122
126
return self [start :end ]
123
127
124
- def getstatementrange (self , lineno : int ):
128
+ def getstatementrange (self , lineno : int ) -> Tuple [ int , int ] :
125
129
""" return (start, end) tuple which spans the minimal
126
130
statement region which containing the given lineno.
127
131
"""
@@ -159,14 +163,36 @@ def isparseable(self, deindent: bool = True) -> bool:
159
163
def __str__ (self ) -> str :
160
164
return "\n " .join (self .lines )
161
165
166
+ @overload
162
167
def compile (
163
168
self ,
164
- filename = None ,
165
- mode = "exec" ,
169
+ filename : Optional [str ] = ...,
170
+ mode : str = ...,
171
+ flag : "Literal[0]" = ...,
172
+ dont_inherit : int = ...,
173
+ _genframe : Optional [FrameType ] = ...,
174
+ ) -> CodeType :
175
+ raise NotImplementedError ()
176
+
177
+ @overload # noqa: F811
178
+ def compile ( # noqa: F811
179
+ self ,
180
+ filename : Optional [str ] = ...,
181
+ mode : str = ...,
182
+ flag : int = ...,
183
+ dont_inherit : int = ...,
184
+ _genframe : Optional [FrameType ] = ...,
185
+ ) -> Union [CodeType , ast .AST ]:
186
+ raise NotImplementedError ()
187
+
188
+ def compile ( # noqa: F811
189
+ self ,
190
+ filename : Optional [str ] = None ,
191
+ mode : str = "exec" ,
166
192
flag : int = 0 ,
167
193
dont_inherit : int = 0 ,
168
194
_genframe : Optional [FrameType ] = None ,
169
- ):
195
+ ) -> Union [ CodeType , ast . AST ] :
170
196
""" return compiled code object. if filename is None
171
197
invent an artificial filename which displays
172
198
the source/line position of the caller frame.
@@ -196,8 +222,10 @@ def compile(
196
222
newex .text = ex .text
197
223
raise newex
198
224
else :
199
- if flag & _AST_FLAG :
225
+ if flag & ast .PyCF_ONLY_AST :
226
+ assert isinstance (co , ast .AST )
200
227
return co
228
+ assert isinstance (co , CodeType )
201
229
lines = [(x + "\n " ) for x in self .lines ]
202
230
# Type ignored because linecache.cache is private.
203
231
linecache .cache [filename ] = (1 , None , lines , filename ) # type: ignore
@@ -209,22 +237,52 @@ def compile(
209
237
#
210
238
211
239
212
- def compile_ (source , filename = None , mode = "exec" , flags : int = 0 , dont_inherit : int = 0 ):
240
+ @overload
241
+ def compile_ (
242
+ source : Union [str , bytes , ast .mod , ast .AST ],
243
+ filename : Optional [str ] = ...,
244
+ mode : str = ...,
245
+ flags : "Literal[0]" = ...,
246
+ dont_inherit : int = ...,
247
+ ) -> CodeType :
248
+ raise NotImplementedError ()
249
+
250
+
251
+ @overload # noqa: F811
252
+ def compile_ ( # noqa: F811
253
+ source : Union [str , bytes , ast .mod , ast .AST ],
254
+ filename : Optional [str ] = ...,
255
+ mode : str = ...,
256
+ flags : int = ...,
257
+ dont_inherit : int = ...,
258
+ ) -> Union [CodeType , ast .AST ]:
259
+ raise NotImplementedError ()
260
+
261
+
262
+ def compile_ ( # noqa: F811
263
+ source : Union [str , bytes , ast .mod , ast .AST ],
264
+ filename : Optional [str ] = None ,
265
+ mode : str = "exec" ,
266
+ flags : int = 0 ,
267
+ dont_inherit : int = 0 ,
268
+ ) -> Union [CodeType , ast .AST ]:
213
269
""" compile the given source to a raw code object,
214
270
and maintain an internal cache which allows later
215
271
retrieval of the source code for the code object
216
272
and any recursively created code objects.
217
273
"""
218
274
if isinstance (source , ast .AST ):
219
275
# XXX should Source support having AST?
220
- return compile (source , filename , mode , flags , dont_inherit )
276
+ assert filename is not None
277
+ co = compile (source , filename , mode , flags , dont_inherit )
278
+ assert isinstance (co , (CodeType , ast .AST ))
279
+ return co
221
280
_genframe = sys ._getframe (1 ) # the caller
222
281
s = Source (source )
223
- co = s .compile (filename , mode , flags , _genframe = _genframe )
224
- return co
282
+ return s .compile (filename , mode , flags , _genframe = _genframe )
225
283
226
284
227
- def getfslineno (obj ):
285
+ def getfslineno (obj ) -> Tuple [ Union [ str , py . path . local ], int ] :
228
286
""" Return source location (path, lineno) for the given object.
229
287
If the source cannot be determined return ("", -1).
230
288
@@ -321,7 +379,7 @@ def getstatementrange_ast(
321
379
# don't produce duplicate warnings when compiling source to find ast
322
380
with warnings .catch_warnings ():
323
381
warnings .simplefilter ("ignore" )
324
- astnode = compile (content , "source" , "exec" , _AST_FLAG )
382
+ astnode = ast . parse (content , "source" , "exec" )
325
383
326
384
start , end = get_statement_startend2 (lineno , astnode )
327
385
# we need to correct the end:
0 commit comments