Skip to content

Commit be0aa70

Browse files
authored
Detect nested contexts in function literals (#18)
* feat: Add detection for nested contexts in function literals * feat: Improve detection of nested contexts in function literals * refactor: Update getReportMessage function to handle unsupported nested context types * use node instead of block * refactor: use multi case * added one more case * feat: also added support for multiple contexts
1 parent 0d2c401 commit be0aa70

File tree

2 files changed

+89
-11
lines changed

2 files changed

+89
-11
lines changed

pkg/analyzer/analyzer.go

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"go/ast"
88
"go/printer"
99
"go/token"
10+
"go/types"
1011

1112
"golang.org/x/tools/go/analysis"
1213
"golang.org/x/tools/go/analysis/passes/inspect"
@@ -28,6 +29,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
2829
nodeFilter := []ast.Node{
2930
(*ast.ForStmt)(nil),
3031
(*ast.RangeStmt)(nil),
32+
(*ast.FuncLit)(nil),
3133
}
3234

3335
inspctr.Preorder(nodeFilter, func(node ast.Node) {
@@ -36,7 +38,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
3638
return
3739
}
3840

39-
assignStmt := findNestedContext(pass, body, body.List)
41+
assignStmt := findNestedContext(pass, node, body.List)
4042
if assignStmt == nil {
4143
return
4244
}
@@ -65,7 +67,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
6567

6668
pass.Report(analysis.Diagnostic{
6769
Pos: assignStmt.Pos(),
68-
Message: "nested context in loop",
70+
Message: getReportMessage(node),
6971
SuggestedFixes: fixes,
7072
})
7173

@@ -74,6 +76,17 @@ func run(pass *analysis.Pass) (interface{}, error) {
7476
return nil, nil
7577
}
7678

79+
func getReportMessage(node ast.Node) string {
80+
switch node.(type) {
81+
case *ast.ForStmt, *ast.RangeStmt:
82+
return "nested context in loop"
83+
case *ast.FuncLit:
84+
return "nested context in function literal"
85+
default:
86+
return "unsupported nested context type"
87+
}
88+
}
89+
7790
func getBody(node ast.Node) (*ast.BlockStmt, error) {
7891
forStmt, ok := node.(*ast.ForStmt)
7992
if ok {
@@ -85,49 +98,54 @@ func getBody(node ast.Node) (*ast.BlockStmt, error) {
8598
return rangeStmt.Body, nil
8699
}
87100

101+
funcLit, ok := node.(*ast.FuncLit)
102+
if ok {
103+
return funcLit.Body, nil
104+
}
105+
88106
return nil, errUnknown
89107
}
90108

91-
func findNestedContext(pass *analysis.Pass, block *ast.BlockStmt, stmts []ast.Stmt) *ast.AssignStmt {
109+
func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *ast.AssignStmt {
92110
for _, stmt := range stmts {
93111
// Recurse if necessary
94112
if inner, ok := stmt.(*ast.BlockStmt); ok {
95-
found := findNestedContext(pass, inner, inner.List)
113+
found := findNestedContext(pass, node, inner.List)
96114
if found != nil {
97115
return found
98116
}
99117
}
100118

101119
if inner, ok := stmt.(*ast.IfStmt); ok {
102-
found := findNestedContext(pass, inner.Body, inner.Body.List)
120+
found := findNestedContext(pass, node, inner.Body.List)
103121
if found != nil {
104122
return found
105123
}
106124
}
107125

108126
if inner, ok := stmt.(*ast.SwitchStmt); ok {
109-
found := findNestedContext(pass, inner.Body, inner.Body.List)
127+
found := findNestedContext(pass, node, inner.Body.List)
110128
if found != nil {
111129
return found
112130
}
113131
}
114132

115133
if inner, ok := stmt.(*ast.CaseClause); ok {
116-
found := findNestedContext(pass, block, inner.Body)
134+
found := findNestedContext(pass, node, inner.Body)
117135
if found != nil {
118136
return found
119137
}
120138
}
121139

122140
if inner, ok := stmt.(*ast.SelectStmt); ok {
123-
found := findNestedContext(pass, inner.Body, inner.Body.List)
141+
found := findNestedContext(pass, node, inner.Body.List)
124142
if found != nil {
125143
return found
126144
}
127145
}
128146

129147
if inner, ok := stmt.(*ast.CommClause); ok {
130-
found := findNestedContext(pass, block, inner.Body)
148+
found := findNestedContext(pass, node, inner.Body)
131149
if found != nil {
132150
return found
133151
}
@@ -149,13 +167,13 @@ func findNestedContext(pass *analysis.Pass, block *ast.BlockStmt, stmts []ast.St
149167
}
150168

151169
if assignStmt.Tok == token.DEFINE {
152-
break
170+
continue
153171
}
154172

155173
// allow assignment to non-pointer children of values defined within the loop
156174
if lhs := getRootIdent(pass, assignStmt.Lhs[0]); lhs != nil {
157175
if obj := pass.TypesInfo.ObjectOf(lhs); obj != nil {
158-
if obj.Pos() >= block.Pos() && obj.Pos() < block.End() {
176+
if checkObjectScopeWithinNode(obj.Parent(), node) {
159177
continue // definition is within the loop
160178
}
161179
}
@@ -167,6 +185,18 @@ func findNestedContext(pass *analysis.Pass, block *ast.BlockStmt, stmts []ast.St
167185
return nil
168186
}
169187

188+
func checkObjectScopeWithinNode(scope *types.Scope, node ast.Node) bool {
189+
if scope == nil {
190+
return false
191+
}
192+
193+
if scope.Pos() >= node.Pos() && scope.End() <= node.End() {
194+
return true
195+
}
196+
197+
return false
198+
}
199+
170200
func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
171201
for {
172202
switch n := node.(type) {

testdata/src/example.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,26 @@ func example() {
5959

6060
break
6161
}
62+
63+
// detects contexts wrapped in function literals (this is risky as function literals can be called multiple times)
64+
_ = func() {
65+
ctx = wrapContext(ctx) // want "nested context in function literal"
66+
}
67+
68+
// this is fine because the context is created in the loop
69+
for {
70+
if ctx := context.Background(); doSomething() != nil {
71+
ctx = wrapContext(ctx)
72+
}
73+
}
74+
75+
for {
76+
ctx2 := context.Background()
77+
ctx = wrapContext(ctx) // want "nested context in loop"
78+
if doSomething() != nil {
79+
ctx2 = wrapContext(ctx2)
80+
}
81+
}
6282
}
6383

6484
func wrapContext(ctx context.Context) context.Context {
@@ -180,3 +200,31 @@ func inVariousNestedBlocks(ctx context.Context) {
180200
break
181201
}
182202
}
203+
204+
// this middleware could run on every request, bloating the request parameter level context and causing a memory leak
205+
func badMiddleware(ctx context.Context) func() error {
206+
return func() error {
207+
ctx = wrapContext(ctx) // want "nested context in function literal"
208+
return doSomethingWithCtx(ctx)
209+
}
210+
}
211+
212+
// this middleware is fine, as it doesn't modify the context of parent function
213+
func okMiddleware(ctx context.Context) func() error {
214+
return func() error {
215+
ctx := wrapContext(ctx)
216+
return doSomethingWithCtx(ctx)
217+
}
218+
}
219+
220+
// this middleware is fine, as it only modifies the context passed to it
221+
func okMiddleware2(ctx context.Context) func(ctx context.Context) error {
222+
return func(ctx context.Context) error {
223+
ctx = wrapContext(ctx)
224+
return doSomethingWithCtx(ctx)
225+
}
226+
}
227+
228+
func doSomethingWithCtx(ctx context.Context) error {
229+
return nil
230+
}

0 commit comments

Comments
 (0)