@@ -21,7 +21,7 @@ class Context:
21
21
lines : set [int ]
22
22
23
23
24
- class RegionFinder ( ast . NodeVisitor ) :
24
+ class RegionFinder :
25
25
"""An ast visitor that will find and track regions of code.
26
26
27
27
Functions and classes are tracked by name. Results are in the .regions
@@ -34,13 +34,27 @@ def __init__(self) -> None:
34
34
35
35
def parse_source (self , source : str ) -> None :
36
36
"""Parse `source` and walk the ast to populate the .regions attribute."""
37
- self .visit (ast .parse (source ))
37
+ self .handle_node (ast .parse (source ))
38
38
39
39
def fq_node_name (self ) -> str :
40
40
"""Get the current fully qualified name we're processing."""
41
41
return "." .join (c .name for c in self .context )
42
42
43
- def visit_FunctionDef (self , node : ast .FunctionDef ) -> None :
43
+ def handle_node (self , node : ast .AST ) -> None :
44
+ """Recursively handle any node."""
45
+ if isinstance (node , (ast .FunctionDef , ast .AsyncFunctionDef )):
46
+ self .handle_FunctionDef (node )
47
+ elif isinstance (node , ast .ClassDef ):
48
+ self .handle_ClassDef (node )
49
+ else :
50
+ self .handle_node_body (node )
51
+
52
+ def handle_node_body (self , node : ast .AST ) -> None :
53
+ """Recursively handle the nodes in this node's body, if any."""
54
+ for body_node in getattr (node , "body" , ()):
55
+ self .handle_node (body_node )
56
+
57
+ def handle_FunctionDef (self , node : ast .FunctionDef | ast .AsyncFunctionDef ) -> None :
44
58
"""Called for `def` or `async def`."""
45
59
lines = set (range (node .body [0 ].lineno , cast (int , node .body [- 1 ].end_lineno ) + 1 ))
46
60
if self .context and self .context [- 1 ].kind == "class" :
@@ -60,12 +74,10 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
60
74
lines = lines ,
61
75
)
62
76
)
63
- self .generic_visit (node )
77
+ self .handle_node_body (node )
64
78
self .context .pop ()
65
79
66
- visit_AsyncFunctionDef = visit_FunctionDef # type: ignore[assignment]
67
-
68
- def visit_ClassDef (self , node : ast .ClassDef ) -> None :
80
+ def handle_ClassDef (self , node : ast .ClassDef ) -> None :
69
81
"""Called for `class`."""
70
82
# The lines for a class are the lines in the methods of the class.
71
83
# We start empty, and count on visit_FunctionDef to add the lines it
@@ -80,7 +92,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
80
92
lines = lines ,
81
93
)
82
94
)
83
- self .generic_visit (node )
95
+ self .handle_node_body (node )
84
96
self .context .pop ()
85
97
# Class bodies should be excluded from the enclosing classes.
86
98
for ancestor in reversed (self .context ):
0 commit comments