7
7
"go/ast"
8
8
"go/printer"
9
9
"go/token"
10
+ "go/types"
10
11
11
12
"golang.org/x/tools/go/analysis"
12
13
"golang.org/x/tools/go/analysis/passes/inspect"
@@ -28,6 +29,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
28
29
nodeFilter := []ast.Node {
29
30
(* ast .ForStmt )(nil ),
30
31
(* ast .RangeStmt )(nil ),
32
+ (* ast .FuncLit )(nil ),
31
33
}
32
34
33
35
inspctr .Preorder (nodeFilter , func (node ast.Node ) {
@@ -36,7 +38,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
36
38
return
37
39
}
38
40
39
- assignStmt := findNestedContext (pass , body , body .List )
41
+ assignStmt := findNestedContext (pass , node , body .List )
40
42
if assignStmt == nil {
41
43
return
42
44
}
@@ -65,7 +67,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
65
67
66
68
pass .Report (analysis.Diagnostic {
67
69
Pos : assignStmt .Pos (),
68
- Message : "nested context in loop" ,
70
+ Message : getReportMessage ( node ) ,
69
71
SuggestedFixes : fixes ,
70
72
})
71
73
@@ -74,6 +76,17 @@ func run(pass *analysis.Pass) (interface{}, error) {
74
76
return nil , nil
75
77
}
76
78
79
+ func getReportMessage (node ast.Node ) string {
80
+ switch node .(type ) {
81
+ case * ast.ForStmt , * ast.RangeStmt :
82
+ return "nested context in loop"
83
+ case * ast.FuncLit :
84
+ return "nested context in function literal"
85
+ default :
86
+ return "unsupported nested context type"
87
+ }
88
+ }
89
+
77
90
func getBody (node ast.Node ) (* ast.BlockStmt , error ) {
78
91
forStmt , ok := node .(* ast.ForStmt )
79
92
if ok {
@@ -85,49 +98,54 @@ func getBody(node ast.Node) (*ast.BlockStmt, error) {
85
98
return rangeStmt .Body , nil
86
99
}
87
100
101
+ funcLit , ok := node .(* ast.FuncLit )
102
+ if ok {
103
+ return funcLit .Body , nil
104
+ }
105
+
88
106
return nil , errUnknown
89
107
}
90
108
91
- func findNestedContext (pass * analysis.Pass , block * ast.BlockStmt , stmts []ast.Stmt ) * ast.AssignStmt {
109
+ func findNestedContext (pass * analysis.Pass , node ast.Node , stmts []ast.Stmt ) * ast.AssignStmt {
92
110
for _ , stmt := range stmts {
93
111
// Recurse if necessary
94
112
if inner , ok := stmt .(* ast.BlockStmt ); ok {
95
- found := findNestedContext (pass , inner , inner .List )
113
+ found := findNestedContext (pass , node , inner .List )
96
114
if found != nil {
97
115
return found
98
116
}
99
117
}
100
118
101
119
if inner , ok := stmt .(* ast.IfStmt ); ok {
102
- found := findNestedContext (pass , inner . Body , inner .Body .List )
120
+ found := findNestedContext (pass , node , inner .Body .List )
103
121
if found != nil {
104
122
return found
105
123
}
106
124
}
107
125
108
126
if inner , ok := stmt .(* ast.SwitchStmt ); ok {
109
- found := findNestedContext (pass , inner . Body , inner .Body .List )
127
+ found := findNestedContext (pass , node , inner .Body .List )
110
128
if found != nil {
111
129
return found
112
130
}
113
131
}
114
132
115
133
if inner , ok := stmt .(* ast.CaseClause ); ok {
116
- found := findNestedContext (pass , block , inner .Body )
134
+ found := findNestedContext (pass , node , inner .Body )
117
135
if found != nil {
118
136
return found
119
137
}
120
138
}
121
139
122
140
if inner , ok := stmt .(* ast.SelectStmt ); ok {
123
- found := findNestedContext (pass , inner . Body , inner .Body .List )
141
+ found := findNestedContext (pass , node , inner .Body .List )
124
142
if found != nil {
125
143
return found
126
144
}
127
145
}
128
146
129
147
if inner , ok := stmt .(* ast.CommClause ); ok {
130
- found := findNestedContext (pass , block , inner .Body )
148
+ found := findNestedContext (pass , node , inner .Body )
131
149
if found != nil {
132
150
return found
133
151
}
@@ -149,13 +167,13 @@ func findNestedContext(pass *analysis.Pass, block *ast.BlockStmt, stmts []ast.St
149
167
}
150
168
151
169
if assignStmt .Tok == token .DEFINE {
152
- break
170
+ continue
153
171
}
154
172
155
173
// allow assignment to non-pointer children of values defined within the loop
156
174
if lhs := getRootIdent (pass , assignStmt .Lhs [0 ]); lhs != nil {
157
175
if obj := pass .TypesInfo .ObjectOf (lhs ); obj != nil {
158
- if obj . Pos () >= block . Pos () && obj .Pos () < block . End ( ) {
176
+ if checkObjectScopeWithinNode ( obj .Parent (), node ) {
159
177
continue // definition is within the loop
160
178
}
161
179
}
@@ -167,6 +185,18 @@ func findNestedContext(pass *analysis.Pass, block *ast.BlockStmt, stmts []ast.St
167
185
return nil
168
186
}
169
187
188
+ func checkObjectScopeWithinNode (scope * types.Scope , node ast.Node ) bool {
189
+ if scope == nil {
190
+ return false
191
+ }
192
+
193
+ if scope .Pos () >= node .Pos () && scope .End () <= node .End () {
194
+ return true
195
+ }
196
+
197
+ return false
198
+ }
199
+
170
200
func getRootIdent (pass * analysis.Pass , node ast.Node ) * ast.Ident {
171
201
for {
172
202
switch n := node .(type ) {
0 commit comments