@@ -31,8 +31,22 @@ class ScopeNode:
31
31
"""
32
32
33
33
node : astroid .Module | astroid .FunctionDef | astroid .ClassDef | astroid .AssignName | astroid .AssignAttr | astroid .Attribute | astroid .Call | astroid .Import | astroid .ImportFrom | MemberAccess
34
- children : list [ScopeNode ] | None = None
35
- parent : ScopeNode | None = None
34
+ children : list [ScopeNode | ClassScopeNode ]
35
+ parent : ScopeNode | ClassScopeNode | None = None
36
+
37
+
38
+ @dataclass
39
+ class ClassScopeNode (ScopeNode ):
40
+ """Represents a ScopeNode that defines the scope of a class.
41
+
42
+ Attributes
43
+ ----------
44
+ class_variables is a list of AssignName nodes that define class variables
45
+ instance_variables is a list of AssignAttr nodes that define instance variables
46
+ """
47
+
48
+ class_variables : list [astroid .AssignName ] = field (default_factory = list )
49
+ instance_variables : list [astroid .AssignAttr ] = field (default_factory = list )
36
50
37
51
38
52
@dataclass
@@ -49,8 +63,26 @@ class ScopeFinder:
49
63
children: All found children nodes are stored in children until their scope is determined.
50
64
"""
51
65
52
- current_node_stack : list [ScopeNode ] = field (default_factory = list )
53
- children : list [ScopeNode ] = field (default_factory = list )
66
+ current_node_stack : list [ScopeNode | ClassScopeNode ] = field (default_factory = list )
67
+ children : list [ScopeNode | ClassScopeNode ] = field (default_factory = list )
68
+
69
+ def get_node_by_name (self , name : str ) -> ScopeNode | ClassScopeNode | None :
70
+ """
71
+ Get a ScopeNode by its name.
72
+
73
+ Parameters
74
+ ----------
75
+ name is the name of the node that should be found.
76
+
77
+ Returns
78
+ -------
79
+ The ScopeNode with the given name, or None if no node with the given name was found.
80
+ """
81
+ for node in self .current_node_stack :
82
+ if node .node .name == name :
83
+ return node
84
+ return None
85
+ # TODO: this is inefficient, instead use a dict to store the nodes
54
86
55
87
def detect_scope (self , node : astroid .NodeNG ) -> None :
56
88
"""
@@ -60,8 +92,8 @@ def detect_scope(self, node: astroid.NodeNG) -> None:
60
92
The scope of a node is defined by the parent node in the scope tree.
61
93
"""
62
94
current_scope = node
63
- outer_scope_children : list [ScopeNode ] = []
64
- inner_scope_children : list [ScopeNode ] = []
95
+ outer_scope_children : list [ScopeNode | ClassScopeNode ] = []
96
+ inner_scope_children : list [ScopeNode | ClassScopeNode ] = []
65
97
for child in self .children :
66
98
if (
67
99
child .parent is not None and child .parent .node != current_scope
@@ -75,6 +107,24 @@ def detect_scope(self, node: astroid.NodeNG) -> None:
75
107
self .children .append (self .current_node_stack [- 1 ]) # add the current node to the children
76
108
self .current_node_stack .pop () # remove the current node from the stack
77
109
110
+ def analyze_constructor (self , node : astroid .FunctionDef ) -> None :
111
+ """Analyze the constructor of a class.
112
+
113
+ The constructor of a class is a special function that is called when an instance of the class is created.
114
+ This function only is called when the name of the FunctionDef node is `__init__`.
115
+ """
116
+ # add instance variables to the instance_variables list of the class
117
+ for child in node .body :
118
+ class_node = self .get_node_by_name (node .parent .name )
119
+
120
+ if isinstance (class_node , ClassScopeNode ):
121
+ if isinstance (child , astroid .Assign ):
122
+ class_node .instance_variables .append (child .targets [0 ])
123
+ elif isinstance (child , astroid .AnnAssign ):
124
+ class_node .instance_variables .append (child .target )
125
+ else :
126
+ raise TypeError (f"Unexpected node type { type (child )} " )
127
+
78
128
def enter_module (self , node : astroid .Module ) -> None :
79
129
"""
80
130
Enter a module node.
@@ -83,25 +133,32 @@ def enter_module(self, node: astroid.Module) -> None:
83
133
The module node is also the first node that is visited, so the current_node_stack is empty before entering the module node.
84
134
"""
85
135
self .current_node_stack .append (
86
- ScopeNode (node = node , children = None , parent = None ),
136
+ ScopeNode (node = node , children = [] , parent = None ),
87
137
)
88
138
89
139
def leave_module (self , node : astroid .Module ) -> None :
90
140
self .detect_scope (node )
91
141
92
142
def enter_classdef (self , node : astroid .ClassDef ) -> None :
93
143
self .current_node_stack .append (
94
- ScopeNode (node = node , children = None , parent = self .current_node_stack [- 1 ]),
144
+ ClassScopeNode (
145
+ node = node ,
146
+ children = [],
147
+ parent = self .current_node_stack [- 1 ],
148
+ instance_variables = [],
149
+ class_variables = [],
150
+ ),
95
151
)
96
152
97
153
def leave_classdef (self , node : astroid .ClassDef ) -> None :
98
154
self .detect_scope (node )
99
155
100
156
def enter_functiondef (self , node : astroid .FunctionDef ) -> None :
101
157
self .current_node_stack .append (
102
- ScopeNode (node = node , children = None , parent = self .current_node_stack [- 1 ]),
158
+ ScopeNode (node = node , children = [] , parent = self .current_node_stack [- 1 ]),
103
159
)
104
- # TODO: Special treatment for __init__ function
160
+ if node .name == "__init__" :
161
+ self .analyze_constructor (node )
105
162
106
163
def leave_functiondef (self , node : astroid .FunctionDef ) -> None :
107
164
self .detect_scope (node )
@@ -120,26 +177,42 @@ def enter_assignname(self, node: astroid.AssignName) -> None:
120
177
| astroid .AnnAssign ,
121
178
):
122
179
parent = self .current_node_stack [- 1 ]
123
- scope_node = ScopeNode (node = node , children = None , parent = parent )
180
+ scope_node = ScopeNode (node = node , children = [] , parent = parent )
124
181
self .children .append (scope_node )
125
182
126
- def enter_assignattr (self , node : astroid .Attribute ) -> None :
183
+ # add class variables to the class_variables list of the class
184
+ if isinstance (node .parent .parent , astroid .ClassDef ):
185
+ class_node = self .get_node_by_name (node .parent .parent .name )
186
+ if isinstance (class_node , ClassScopeNode ):
187
+ class_node .class_variables .append (node )
188
+
189
+ def enter_assignattr (self , node : astroid .AssignAttr ) -> None :
127
190
parent = self .current_node_stack [- 1 ]
128
- scope_node = ScopeNode (node = node , children = None , parent = parent )
191
+ scope_node = ScopeNode (node = node , children = [] , parent = parent )
129
192
self .children .append (scope_node )
130
193
131
194
def enter_import (self , node : astroid .Import ) -> None :
132
195
parent = self .current_node_stack [- 1 ]
133
- scope_node = ScopeNode (node = node , children = None , parent = parent )
196
+ scope_node = ScopeNode (node = node , children = [] , parent = parent )
134
197
self .children .append (scope_node )
135
198
136
199
def enter_importfrom (self , node : astroid .ImportFrom ) -> None :
137
200
parent = self .current_node_stack [- 1 ]
138
- scope_node = ScopeNode (node = node , children = None , parent = parent )
201
+ scope_node = ScopeNode (node = node , children = [] , parent = parent )
139
202
self .children .append (scope_node )
140
203
141
204
142
- def get_scope (code : str ) -> list [ScopeNode ]:
205
+ def get_scope (code : str ) -> list [ScopeNode | ClassScopeNode ]:
206
+ """Get the scope of the given code.
207
+
208
+ In order to get the scope of the given code, the code is parsed into an AST and then walked by an ASTWalker.
209
+ The ASTWalker detects the scope of each node and builds a scope tree by using an instance of ScopeFinder.
210
+
211
+ Returns
212
+ -------
213
+ scopes: list of ScopeNode instances that represent the scope tree of the given code.
214
+ variables: list of class variables and list of instance variables for all classes in the given code.
215
+ """
143
216
scope_handler = ScopeFinder ()
144
217
walker = ASTWalker (scope_handler )
145
218
module = astroid .parse (code )
0 commit comments