Skip to content

Commit 30606c7

Browse files
authored
Merge pull request #15 from Crocmagnon/feat/context-condition
handle contexts inside conditions
2 parents f35e8a2 + 0910305 commit 30606c7

File tree

2 files changed

+217
-54
lines changed

2 files changed

+217
-54
lines changed

pkg/analyzer/analyzer.go

Lines changed: 108 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -36,64 +36,39 @@ func run(pass *analysis.Pass) (interface{}, error) {
3636
return
3737
}
3838

39-
for _, stmt := range body.List {
40-
assignStmt, ok := stmt.(*ast.AssignStmt)
41-
if !ok {
42-
continue
43-
}
44-
45-
t := pass.TypesInfo.TypeOf(assignStmt.Lhs[0])
46-
if t == nil {
47-
continue
48-
}
49-
50-
if t.String() != "context.Context" {
51-
continue
52-
}
53-
54-
if assignStmt.Tok == token.DEFINE {
55-
break
56-
}
57-
58-
// allow assignment to non-pointer children of values defined within the loop
59-
if lhs := getRootIdent(pass, assignStmt.Lhs[0]); lhs != nil {
60-
if obj := pass.TypesInfo.ObjectOf(lhs); obj != nil {
61-
if obj.Pos() >= body.Pos() && obj.Pos() < body.End() {
62-
continue // definition is within the loop
63-
}
64-
}
65-
}
39+
assignStmt := findNestedContext(pass, body, body.List)
40+
if assignStmt == nil {
41+
return
42+
}
6643

67-
suggestedStmt := ast.AssignStmt{
68-
Lhs: assignStmt.Lhs,
69-
TokPos: assignStmt.TokPos,
70-
Tok: token.DEFINE,
71-
Rhs: assignStmt.Rhs,
72-
}
73-
suggested, err := render(pass.Fset, &suggestedStmt)
74-
75-
var fixes []analysis.SuggestedFix
76-
if err == nil {
77-
fixes = append(fixes, analysis.SuggestedFix{
78-
Message: "replace `=` with `:=`",
79-
TextEdits: []analysis.TextEdit{
80-
{
81-
Pos: assignStmt.Pos(),
82-
End: assignStmt.End(),
83-
NewText: []byte(suggested),
84-
},
44+
suggestedStmt := ast.AssignStmt{
45+
Lhs: assignStmt.Lhs,
46+
TokPos: assignStmt.TokPos,
47+
Tok: token.DEFINE,
48+
Rhs: assignStmt.Rhs,
49+
}
50+
suggested, err := render(pass.Fset, &suggestedStmt)
51+
52+
var fixes []analysis.SuggestedFix
53+
if err == nil {
54+
fixes = append(fixes, analysis.SuggestedFix{
55+
Message: "replace `=` with `:=`",
56+
TextEdits: []analysis.TextEdit{
57+
{
58+
Pos: assignStmt.Pos(),
59+
End: assignStmt.End(),
60+
NewText: []byte(suggested),
8561
},
86-
})
87-
}
88-
89-
pass.Report(analysis.Diagnostic{
90-
Pos: assignStmt.Pos(),
91-
Message: "nested context in loop",
92-
SuggestedFixes: fixes,
62+
},
9363
})
94-
95-
break
9664
}
65+
66+
pass.Report(analysis.Diagnostic{
67+
Pos: assignStmt.Pos(),
68+
Message: "nested context in loop",
69+
SuggestedFixes: fixes,
70+
})
71+
9772
})
9873

9974
return nil, nil
@@ -113,6 +88,85 @@ func getBody(node ast.Node) (*ast.BlockStmt, error) {
11388
return nil, errUnknown
11489
}
11590

91+
func findNestedContext(pass *analysis.Pass, block *ast.BlockStmt, stmts []ast.Stmt) *ast.AssignStmt {
92+
for _, stmt := range stmts {
93+
// Recurse if necessary
94+
if inner, ok := stmt.(*ast.BlockStmt); ok {
95+
found := findNestedContext(pass, inner, inner.List)
96+
if found != nil {
97+
return found
98+
}
99+
}
100+
101+
if inner, ok := stmt.(*ast.IfStmt); ok {
102+
found := findNestedContext(pass, inner.Body, inner.Body.List)
103+
if found != nil {
104+
return found
105+
}
106+
}
107+
108+
if inner, ok := stmt.(*ast.SwitchStmt); ok {
109+
found := findNestedContext(pass, inner.Body, inner.Body.List)
110+
if found != nil {
111+
return found
112+
}
113+
}
114+
115+
if inner, ok := stmt.(*ast.CaseClause); ok {
116+
found := findNestedContext(pass, block, inner.Body)
117+
if found != nil {
118+
return found
119+
}
120+
}
121+
122+
if inner, ok := stmt.(*ast.SelectStmt); ok {
123+
found := findNestedContext(pass, inner.Body, inner.Body.List)
124+
if found != nil {
125+
return found
126+
}
127+
}
128+
129+
if inner, ok := stmt.(*ast.CommClause); ok {
130+
found := findNestedContext(pass, block, inner.Body)
131+
if found != nil {
132+
return found
133+
}
134+
}
135+
136+
// Actually check for nested context
137+
assignStmt, ok := stmt.(*ast.AssignStmt)
138+
if !ok {
139+
continue
140+
}
141+
142+
t := pass.TypesInfo.TypeOf(assignStmt.Lhs[0])
143+
if t == nil {
144+
continue
145+
}
146+
147+
if t.String() != "context.Context" {
148+
continue
149+
}
150+
151+
if assignStmt.Tok == token.DEFINE {
152+
break
153+
}
154+
155+
// allow assignment to non-pointer children of values defined within the loop
156+
if lhs := getRootIdent(pass, assignStmt.Lhs[0]); lhs != nil {
157+
if obj := pass.TypesInfo.ObjectOf(lhs); obj != nil {
158+
if obj.Pos() >= block.Pos() && obj.Pos() < block.End() {
159+
continue // definition is within the loop
160+
}
161+
}
162+
}
163+
164+
return assignStmt
165+
}
166+
167+
return nil
168+
}
169+
116170
func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
117171
for {
118172
switch n := node.(type) {

testdata/src/example.go

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,50 @@ func example() {
2525
ctx = wrapContext(ctx) // want "nested context in loop"
2626
break
2727
}
28+
29+
// not fooled by shadowing in nested blocks
30+
for {
31+
err := doSomething()
32+
if err != nil {
33+
ctx := wrapContext(ctx)
34+
ctx = wrapContext(ctx)
35+
}
36+
37+
switch err {
38+
case nil:
39+
ctx := wrapContext(ctx)
40+
ctx = wrapContext(ctx)
41+
default:
42+
ctx := wrapContext(ctx)
43+
ctx = wrapContext(ctx)
44+
}
45+
46+
{
47+
ctx := wrapContext(ctx)
48+
ctx = wrapContext(ctx)
49+
}
50+
51+
select {
52+
case <-ctx.Done():
53+
ctx := wrapContext(ctx)
54+
ctx = wrapContext(ctx)
55+
default:
56+
}
57+
58+
ctx = wrapContext(ctx) // want "nested context in loop"
59+
60+
break
61+
}
2862
}
2963

3064
func wrapContext(ctx context.Context) context.Context {
3165
return context.WithoutCancel(ctx)
3266
}
3367

68+
func doSomething() error {
69+
return nil
70+
}
71+
3472
// storing contexts in a struct isn't recommended, but local copies of a non-pointer struct should act like local copies of a context.
3573
func inStructs(ctx context.Context) {
3674
for i := 0; i < 10; i++ {
@@ -71,3 +109,74 @@ func inStructs(ctx context.Context) {
71109
rp[0].Ctx = context.WithValue(rp[0].Ctx, "other", "val")
72110
}
73111
}
112+
113+
func inVariousNestedBlocks(ctx context.Context) {
114+
for {
115+
err := doSomething()
116+
if err != nil {
117+
ctx = wrapContext(ctx) // want "nested context in loop"
118+
}
119+
120+
break
121+
}
122+
123+
for {
124+
err := doSomething()
125+
if err != nil {
126+
if true {
127+
ctx = wrapContext(ctx) // want "nested context in loop"
128+
}
129+
}
130+
131+
break
132+
}
133+
134+
for {
135+
err := doSomething()
136+
switch err {
137+
case nil:
138+
ctx = wrapContext(ctx) // want "nested context in loop"
139+
}
140+
141+
break
142+
}
143+
144+
for {
145+
err := doSomething()
146+
switch err {
147+
default:
148+
ctx = wrapContext(ctx) // want "nested context in loop"
149+
}
150+
151+
break
152+
}
153+
154+
for {
155+
ctx := wrapContext(ctx)
156+
157+
err := doSomething()
158+
if err != nil {
159+
ctx = wrapContext(ctx)
160+
}
161+
162+
break
163+
}
164+
165+
for {
166+
{
167+
ctx = wrapContext(ctx) // want "nested context in loop"
168+
}
169+
170+
break
171+
}
172+
173+
for {
174+
select {
175+
case <-ctx.Done():
176+
ctx = wrapContext(ctx) // want "nested context in loop"
177+
default:
178+
}
179+
180+
break
181+
}
182+
}

0 commit comments

Comments
 (0)