Skip to content

Commit ca63215

Browse files
committed
feat: support check r.Context()
1 parent a63b6f4 commit ca63215

File tree

2 files changed

+78
-116
lines changed

2 files changed

+78
-116
lines changed

contextcheck.go

+63-110
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,14 @@ const (
3838
CtxIn int = 1 << iota // ctx in function's param
3939
CtxOut // ctx in function's results
4040
CtxInField // ctx in function's field param
41-
HttpRes // http.ResponseWriter in function's param
42-
HttpReq // *http.Request in function's param
43-
44-
HttpHandler = HttpRes | HttpReq
4541
)
4642

47-
const (
48-
EntryWithCtx int = 1 << iota // has ctx in
49-
EntryWithHttpHandler // is http handler
43+
type entryType int
5044

51-
Entry = EntryWithCtx | EntryWithHttpHandler
45+
const (
46+
EntryNone entryType = iota
47+
EntryWithCtx // has ctx in
48+
EntryWithHttpHandler // is http handler
5249
)
5350

5451
type resInfo struct {
@@ -109,7 +106,7 @@ func (r *runner) run(pass *analysis.Pass) {
109106

110107
type entryInfo struct {
111108
f *ssa.Function // entryfunc
112-
tp int // entrytype
109+
tp entryType // entrytype
113110
}
114111
var tmpFuncs []entryInfo
115112
for _, f := range funcs {
@@ -119,7 +116,7 @@ func (r *runner) run(pass *analysis.Pass) {
119116
continue
120117
}
121118

122-
if entryType := r.checkIsEntry(f); entryType&Entry == 0 {
119+
if entryType := r.checkIsEntry(f); entryType == EntryNone {
123120
// record the result of nomal function
124121
checkingMap := make(map[string]bool)
125122
checkingMap[key] = true
@@ -161,7 +158,7 @@ func (r *runner) getRequiedType(pssa *buildssa.SSA, path, name string) (obj *typ
161158
func (r *runner) collectHttpTyps(pssa *buildssa.SSA) {
162159
objRes, pobjRes, ok := r.getRequiedType(pssa, httpPkg, httpRes)
163160
if ok {
164-
r.httpResTyps = append(r.httpResTyps, objRes, pobjRes, types.NewPointer(pobjRes))
161+
r.httpResTyps = append(r.httpResTyps, objRes, pobjRes)
165162
}
166163

167164
objReq, pobjReq, ok := r.getRequiedType(pssa, httpPkg, httpReq)
@@ -201,27 +198,26 @@ func (r *runner) noImportedContextAndHttp(f *ssa.Function) (ret bool) {
201198
return true
202199
}
203200

204-
func (r *runner) checkIsEntry(f *ssa.Function) (entryType int) {
201+
func (r *runner) checkIsEntry(f *ssa.Function) entryType {
205202
if r.noImportedContextAndHttp(f) {
206-
return
203+
return EntryNone
207204
}
208205

209206
ctxIn, ctxOut := r.checkIsCtx(f)
210207
if ctxOut {
211208
// skip the function which generate ctx
212-
return
209+
return EntryNone
213210
} else if ctxIn {
214211
// has ctx in, ignore *http.Request.Context()
215-
entryType |= EntryWithCtx
216-
return
212+
return EntryWithCtx
217213
}
218214

219215
// check is `func handler(w http.ResponseWriter, r *http.Request) {}`
220216
if r.checkIsHttpHandler(f) {
221-
entryType |= EntryWithHttpHandler
217+
return EntryWithHttpHandler
222218
}
223219

224-
return
220+
return EntryNone
225221
}
226222

227223
func (r *runner) checkIsCtx(f *ssa.Function) (in, out bool) {
@@ -259,39 +255,12 @@ func (r *runner) checkIsHttpHandler(f *ssa.Function) bool {
259255
return false
260256
}
261257

262-
// must has http.ResponseWriter and *http.Request in param or freevar
263-
var tp int
264-
265-
// check params
258+
// must be `func f(w http.ResponseWriter, r *http.Request) {}`
266259
tuple := f.Signature.Params()
267-
for i := 0; i < tuple.Len(); i++ {
268-
if r.isCtxType(tuple.At(i).Type()) {
269-
return false
270-
} else if r.isHttpReqType(tuple.At(i).Type()) {
271-
tp |= HttpReq
272-
} else if r.isHttpResType(tuple.At(i).Type()) {
273-
tp |= HttpRes
274-
}
275-
if tp == HttpHandler {
276-
return true
277-
}
278-
}
279-
280-
// check freevars
281-
for _, param := range f.FreeVars {
282-
if r.isCtxType(param.Type()) {
283-
return false
284-
} else if r.isHttpReqType(param.Type()) {
285-
tp |= HttpReq
286-
} else if r.isHttpResType(param.Type()) {
287-
tp |= HttpRes
288-
}
289-
if tp == HttpHandler {
290-
return true
291-
}
260+
if tuple.Len() != 2 {
261+
return false
292262
}
293-
294-
return false
263+
return r.isHttpResType(tuple.At(0).Type()) && r.isHttpReqType(tuple.At(1).Type())
295264
}
296265

297266
func (r *runner) collectCtxRef(f *ssa.Function, isHttpHandler bool) (refMap map[ssa.Instruction]bool, ok bool) {
@@ -358,15 +327,21 @@ func (r *runner) collectCtxRef(f *ssa.Function, isHttpHandler bool) (refMap map[
358327
}
359328
}
360329

361-
for _, param := range f.Params {
362-
if r.isCtxType(param.Type()) {
363-
checkRefs(param, false)
330+
if isHttpHandler {
331+
for _, v := range r.getHttpReqCtx(f) {
332+
checkRefs(v, false)
333+
}
334+
} else {
335+
for _, param := range f.Params {
336+
if r.isCtxType(param.Type()) {
337+
checkRefs(param, false)
338+
}
364339
}
365-
}
366340

367-
for _, param := range f.FreeVars {
368-
if r.isCtxType(param.Type()) {
369-
checkRefs(param, false)
341+
for _, param := range f.FreeVars {
342+
if r.isCtxType(param.Type()) {
343+
checkRefs(param, false)
344+
}
370345
}
371346
}
372347

@@ -386,14 +361,6 @@ func (r *runner) collectCtxRef(f *ssa.Function, isHttpHandler bool) (refMap map[
386361
}
387362
}
388363

389-
if !isHttpHandler {
390-
return
391-
}
392-
393-
for _, v := range r.getHttpReqCtx(f) {
394-
checkRefs(v, false)
395-
}
396-
397364
return
398365
}
399366

@@ -421,40 +388,34 @@ func (r *runner) getHttpReqCtx(f *ssa.Function) (rets []ssa.Value) {
421388
checkInstr = func(instr ssa.Instruction, fromAddr bool) {
422389
switch i := instr.(type) {
423390
case ssa.CallInstruction:
391+
// r.Context() only has one recv
392+
if len(i.Common().Args) != 1 {
393+
break
394+
}
395+
424396
// find r.Context()
425397
if r.getCallInstrCtxType(i)&CtxOut != CtxOut {
426398
break
427399
}
428400

429-
for _, v := range i.Common().Args {
430-
if !r.isHttpReqType(v.Type()) {
431-
continue
432-
}
433-
434-
f := r.getFunction(instr)
435-
if f == nil {
436-
continue
437-
}
438-
439-
// check is r.Context
440-
if f.Signature.Recv() != nil && r.isHttpReqType(f.Signature.Recv().Type()) && f.Name() == ctxName {
441-
// collect the return of r.Context
442-
rets = append(rets, i.Value())
443-
}
401+
// check is r.Context
402+
f := r.getFunction(instr)
403+
if f == nil || f.Name() != ctxName {
404+
break
405+
}
406+
if f.Signature.Recv() != nil {
407+
// collect the return of r.Context
408+
rets = append(rets, i.Value())
444409
}
445410
case *ssa.Store:
446411
if !fromAddr {
447412
checkRefs(i.Addr, true)
448413
}
449414
case *ssa.UnOp:
450-
if r.isHttpReqType(i.Type()) {
451-
checkRefs(i, false)
452-
}
453-
case *ssa.MakeClosure:
415+
checkRefs(i, false)
454416
case *ssa.Phi:
455-
if r.isHttpReqType(i.Type()) {
456-
checkRefs(i, false)
457-
}
417+
checkRefs(i, false)
418+
case *ssa.MakeClosure:
458419
case *ssa.Extract:
459420
// http.Request can only be input
460421
}
@@ -463,20 +424,15 @@ func (r *runner) getHttpReqCtx(f *ssa.Function) (rets []ssa.Value) {
463424
for _, param := range f.Params {
464425
if r.isHttpReqType(param.Type()) {
465426
checkRefs(param, false)
466-
}
467-
}
468-
469-
for _, param := range f.FreeVars {
470-
if r.isHttpReqType(param.Type()) {
471-
checkRefs(param, false)
427+
break
472428
}
473429
}
474430

475431
return
476432
}
477433

478-
func (r *runner) checkFuncWithCtx(f *ssa.Function, tp int) {
479-
isHttpHandler := tp&EntryWithHttpHandler != 0
434+
func (r *runner) checkFuncWithCtx(f *ssa.Function, tp entryType) {
435+
isHttpHandler := tp == EntryWithHttpHandler
480436
refMap, ok := r.collectCtxRef(f, isHttpHandler)
481437
if !ok {
482438
return
@@ -496,15 +452,14 @@ func (r *runner) checkFuncWithCtx(f *ssa.Function, tp int) {
496452

497453
if tp&CtxIn != 0 {
498454
if !refMap[instr] {
499-
r.pass.Reportf(instr.Pos(), "Non-inherited new context, use function like `context.WithXXX` or `r.Context` instead")
455+
if isHttpHandler {
456+
r.pass.Reportf(instr.Pos(), "Non-inherited new context, use function like `context.WithXXX` or `r.Context` instead")
457+
} else {
458+
r.pass.Reportf(instr.Pos(), "Non-inherited new context, use function like `context.WithXXX` instead")
459+
}
500460
}
501461
}
502462

503-
// only check if the ctx used in the current function is r.Context()
504-
if isHttpHandler {
505-
continue
506-
}
507-
508463
ff := r.getFunction(instr)
509464
if ff == nil {
510465
continue
@@ -564,13 +519,13 @@ func (r *runner) checkFuncWithoutCtx(f *ssa.Function, checkingMap map[string]boo
564519
continue
565520
}
566521

567-
if entryType := r.checkIsEntry(ff); entryType&Entry == 0 {
522+
if entryType := r.checkIsEntry(ff); entryType == EntryNone {
568523
// cannot get info from fact, skip
569524
if ff.Blocks == nil {
570525
continue
571526
}
572527

573-
// handler ring call
528+
// handler cycle call
574529
if checkingMap[key] {
575530
continue
576531
}
@@ -681,23 +636,21 @@ func (r *runner) isCtxType(tp types.Type) bool {
681636
}
682637

683638
func (r *runner) isHttpResType(tp types.Type) bool {
684-
var ok bool
685639
for _, v := range r.httpResTyps {
686-
if ok = types.Identical(v, v); ok {
687-
break
640+
if ok := types.Identical(v, v); ok {
641+
return true
688642
}
689643
}
690-
return ok
644+
return false
691645
}
692646

693647
func (r *runner) isHttpReqType(tp types.Type) bool {
694-
var ok bool
695648
for _, v := range r.httpReqTyps {
696-
if ok = types.Identical(tp, v); ok {
697-
break
649+
if ok := types.Identical(tp, v); ok {
650+
return true
698651
}
699652
}
700-
return ok
653+
return false
701654
}
702655

703656
func (r *runner) getValue(key string, f *ssa.Function) (res resInfo, ok bool) {

testdata/src/a/a.go

+15-6
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func f1(ctx context.Context) {
4848
f2(ctx)
4949
}(ctx)
5050

51-
f2(context.Background()) // want "Non-inherited new context, use function like `context.WithXXX` or `r.Context` instead"
51+
f2(context.Background()) // want "Non-inherited new context, use function like `context.WithXXX` instead"
5252

5353
thunk := MyInt.F
5454
thunk(0)
@@ -66,7 +66,7 @@ func f3() {
6666
func f4(ctx context.Context) {
6767
f2(ctx)
6868
ctx = context.Background()
69-
f2(ctx) // want "Non-inherited new context, use function like `context.WithXXX` or `r.Context` instead"
69+
f2(ctx) // want "Non-inherited new context, use function like `context.WithXXX` instead"
7070
}
7171

7272
func f5(ctx context.Context) {
@@ -104,19 +104,28 @@ func f9(w http.ResponseWriter, r *http.Request) {
104104
f8(context.Background(), w, r) // want "Non-inherited new context, use function like `context.WithXXX` or `r.Context` instead"
105105
}
106106

107-
func f10() {
107+
func f10(in bool, w http.ResponseWriter, r *http.Request) {
108+
f8(r.Context(), w, r)
109+
f8(context.Background(), w, r)
110+
}
111+
112+
func f11() {
108113
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
109-
f9(w, r)
110114
f8(r.Context(), w, r)
111115
f8(context.Background(), w, r) // want "Non-inherited new context, use function like `context.WithXXX` or `r.Context` instead"
116+
117+
f9(w, r)
118+
119+
// f10 should be like `func f10(ctx context.Context, in bool, w http.ResponseWriter, r *http.Request)`
120+
f10(true, w, r) // want "Function `f10` should pass the context parameter"
112121
})
113122
}
114123

115124
/* ----------------- generics ----------------- */
116125

117126
type MySlice[T int | float32] []T
118127

119-
func (s MySlice[T]) f11(ctx context.Context) T {
128+
func (s MySlice[T]) f12(ctx context.Context) T {
120129
f3() // generics, Block is nil, wont report
121130

122131
var sum T
@@ -126,7 +135,7 @@ func (s MySlice[T]) f11(ctx context.Context) T {
126135
return sum
127136
}
128137

129-
func f12[T int | int8 | int16 | int32 | int64 | uint | uint8 | uint16 | uint32 | uint64 | float32 | float64](ctx context.Context, a, b T) T {
138+
func f13[T int | int8 | int16 | int32 | int64 | uint | uint8 | uint16 | uint32 | uint64 | float32 | float64](ctx context.Context, a, b T) T {
130139
f3() // generics, Block is nil, wont report
131140

132141
if a > b {

0 commit comments

Comments
 (0)