@@ -36,64 +36,39 @@ func run(pass *analysis.Pass) (interface{}, error) {
36
36
return
37
37
}
38
38
39
- for _ , stmt := range body .List {
40
- assignStmt , ok := stmt .(* ast.AssignStmt )
41
- if ! ok {
42
- continue
43
- }
44
-
45
- t := pass .TypesInfo .TypeOf (assignStmt .Lhs [0 ])
46
- if t == nil {
47
- continue
48
- }
49
-
50
- if t .String () != "context.Context" {
51
- continue
52
- }
53
-
54
- if assignStmt .Tok == token .DEFINE {
55
- break
56
- }
57
-
58
- // allow assignment to non-pointer children of values defined within the loop
59
- if lhs := getRootIdent (pass , assignStmt .Lhs [0 ]); lhs != nil {
60
- if obj := pass .TypesInfo .ObjectOf (lhs ); obj != nil {
61
- if obj .Pos () >= body .Pos () && obj .Pos () < body .End () {
62
- continue // definition is within the loop
63
- }
64
- }
65
- }
39
+ assignStmt := findNestedContext (pass , body , body .List )
40
+ if assignStmt == nil {
41
+ return
42
+ }
66
43
67
- suggestedStmt := ast.AssignStmt {
68
- Lhs : assignStmt .Lhs ,
69
- TokPos : assignStmt .TokPos ,
70
- Tok : token .DEFINE ,
71
- Rhs : assignStmt .Rhs ,
72
- }
73
- suggested , err := render (pass .Fset , & suggestedStmt )
74
-
75
- var fixes []analysis.SuggestedFix
76
- if err == nil {
77
- fixes = append (fixes , analysis.SuggestedFix {
78
- Message : "replace `=` with `:=`" ,
79
- TextEdits : []analysis.TextEdit {
80
- {
81
- Pos : assignStmt .Pos (),
82
- End : assignStmt .End (),
83
- NewText : []byte (suggested ),
84
- },
44
+ suggestedStmt := ast.AssignStmt {
45
+ Lhs : assignStmt .Lhs ,
46
+ TokPos : assignStmt .TokPos ,
47
+ Tok : token .DEFINE ,
48
+ Rhs : assignStmt .Rhs ,
49
+ }
50
+ suggested , err := render (pass .Fset , & suggestedStmt )
51
+
52
+ var fixes []analysis.SuggestedFix
53
+ if err == nil {
54
+ fixes = append (fixes , analysis.SuggestedFix {
55
+ Message : "replace `=` with `:=`" ,
56
+ TextEdits : []analysis.TextEdit {
57
+ {
58
+ Pos : assignStmt .Pos (),
59
+ End : assignStmt .End (),
60
+ NewText : []byte (suggested ),
85
61
},
86
- })
87
- }
88
-
89
- pass .Report (analysis.Diagnostic {
90
- Pos : assignStmt .Pos (),
91
- Message : "nested context in loop" ,
92
- SuggestedFixes : fixes ,
62
+ },
93
63
})
94
-
95
- break
96
64
}
65
+
66
+ pass .Report (analysis.Diagnostic {
67
+ Pos : assignStmt .Pos (),
68
+ Message : "nested context in loop" ,
69
+ SuggestedFixes : fixes ,
70
+ })
71
+
97
72
})
98
73
99
74
return nil , nil
@@ -113,6 +88,85 @@ func getBody(node ast.Node) (*ast.BlockStmt, error) {
113
88
return nil , errUnknown
114
89
}
115
90
91
+ func findNestedContext (pass * analysis.Pass , block * ast.BlockStmt , stmts []ast.Stmt ) * ast.AssignStmt {
92
+ for _ , stmt := range stmts {
93
+ // Recurse if necessary
94
+ if inner , ok := stmt .(* ast.BlockStmt ); ok {
95
+ found := findNestedContext (pass , inner , inner .List )
96
+ if found != nil {
97
+ return found
98
+ }
99
+ }
100
+
101
+ if inner , ok := stmt .(* ast.IfStmt ); ok {
102
+ found := findNestedContext (pass , inner .Body , inner .Body .List )
103
+ if found != nil {
104
+ return found
105
+ }
106
+ }
107
+
108
+ if inner , ok := stmt .(* ast.SwitchStmt ); ok {
109
+ found := findNestedContext (pass , inner .Body , inner .Body .List )
110
+ if found != nil {
111
+ return found
112
+ }
113
+ }
114
+
115
+ if inner , ok := stmt .(* ast.CaseClause ); ok {
116
+ found := findNestedContext (pass , block , inner .Body )
117
+ if found != nil {
118
+ return found
119
+ }
120
+ }
121
+
122
+ if inner , ok := stmt .(* ast.SelectStmt ); ok {
123
+ found := findNestedContext (pass , inner .Body , inner .Body .List )
124
+ if found != nil {
125
+ return found
126
+ }
127
+ }
128
+
129
+ if inner , ok := stmt .(* ast.CommClause ); ok {
130
+ found := findNestedContext (pass , block , inner .Body )
131
+ if found != nil {
132
+ return found
133
+ }
134
+ }
135
+
136
+ // Actually check for nested context
137
+ assignStmt , ok := stmt .(* ast.AssignStmt )
138
+ if ! ok {
139
+ continue
140
+ }
141
+
142
+ t := pass .TypesInfo .TypeOf (assignStmt .Lhs [0 ])
143
+ if t == nil {
144
+ continue
145
+ }
146
+
147
+ if t .String () != "context.Context" {
148
+ continue
149
+ }
150
+
151
+ if assignStmt .Tok == token .DEFINE {
152
+ break
153
+ }
154
+
155
+ // allow assignment to non-pointer children of values defined within the loop
156
+ if lhs := getRootIdent (pass , assignStmt .Lhs [0 ]); lhs != nil {
157
+ if obj := pass .TypesInfo .ObjectOf (lhs ); obj != nil {
158
+ if obj .Pos () >= block .Pos () && obj .Pos () < block .End () {
159
+ continue // definition is within the loop
160
+ }
161
+ }
162
+ }
163
+
164
+ return assignStmt
165
+ }
166
+
167
+ return nil
168
+ }
169
+
116
170
func getRootIdent (pass * analysis.Pass , node ast.Node ) * ast.Ident {
117
171
for {
118
172
switch n := node .(type ) {
0 commit comments