Skip to content
This repository was archived by the owner on Jun 2, 2023. It is now read-only.

Commit db267c5

Browse files
authored
Merge pull request timakin#7 from timakin/fix/issue-1-and-3
[WIP] fix: issue: #1
2 parents 8ad5979 + 42c0f6f commit db267c5

File tree

3 files changed

+144
-15
lines changed

3 files changed

+144
-15
lines changed

passes/bodyclose/bodyclose.go

Lines changed: 100 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ const (
2626
Doc = "bodyclose checks whether HTTP response body is closed successfully"
2727

2828
nethttpPath = "net/http"
29+
closeMethod = "Close"
2930
)
3031

3132
type runner struct {
@@ -70,7 +71,7 @@ func (r *runner) run(pass *analysis.Pass) (interface{}, error) {
7071
bodyItrf := bodyNamed.Underlying().(*types.Interface)
7172
for i := 0; i < bodyItrf.NumMethods(); i++ {
7273
bmthd := bodyItrf.Method(i)
73-
if bmthd.Id() == "Close" {
74+
if bmthd.Id() == closeMethod {
7475
r.closeMthd = bmthd
7576
}
7677
}
@@ -82,6 +83,17 @@ func (r *runner) run(pass *analysis.Pass) (interface{}, error) {
8283
continue
8384
}
8485

86+
// skip if the function is just referenced
87+
var isreffunc bool
88+
for i := 0; i < f.Signature.Results().Len(); i++ {
89+
if f.Signature.Results().At(i).Type().String() == r.resTyp.String() {
90+
isreffunc = true
91+
}
92+
}
93+
if isreffunc {
94+
continue
95+
}
96+
8597
for _, b := range f.Blocks {
8698
for i := range b.Instrs {
8799
pos := b.Instrs[i].Pos()
@@ -100,6 +112,7 @@ func (r *runner) isopen(b *ssa.BasicBlock, i int) bool {
100112
if !ok {
101113
return false
102114
}
115+
103116
if len(*call.Referrers()) == 0 {
104117
return true
105118
}
@@ -126,15 +139,11 @@ func (r *runner) isopen(b *ssa.BasicBlock, i int) bool {
126139
f := c.Fn.(*ssa.Function)
127140
if r.noImportedNetHTTP(f) {
128141
// skip this
129-
continue
142+
return false
130143
}
131144
called := r.isClosureCalled(c)
132145

133-
for _, b := range f.Blocks {
134-
for i := range b.Instrs {
135-
return r.isopen(b, i) || !called
136-
}
137-
}
146+
return r.calledInFunc(f, called)
138147
}
139148

140149
}
@@ -152,6 +161,7 @@ func (r *runner) isopen(b *ssa.BasicBlock, i int) bool {
152161
}
153162

154163
bRefs := *resRef.Referrers()
164+
155165
for _, bRef := range bRefs {
156166
bOp, ok := r.getBodyOp(bRef)
157167
if !ok {
@@ -186,14 +196,17 @@ func (r *runner) getReqCall(instr ssa.Instruction) (*ssa.Call, bool) {
186196
}
187197

188198
func (r *runner) getResVal(instr ssa.Instruction) (ssa.Value, bool) {
189-
val, ok := instr.(ssa.Value)
190-
if !ok {
191-
return nil, false
192-
}
193-
if val.Type().String() != r.resTyp.String() {
194-
return nil, false
199+
switch instr := instr.(type) {
200+
case *ssa.FieldAddr:
201+
if instr.X.Type().String() == r.resTyp.String() {
202+
return instr.X.(ssa.Value), true
203+
}
204+
case ssa.Value:
205+
if instr.Type().String() == r.resTyp.String() {
206+
return instr, true
207+
}
195208
}
196-
return val, true
209+
return nil, false
197210
}
198211

199212
func (r *runner) getBodyOp(instr ssa.Instruction) (*ssa.UnOp, bool) {
@@ -217,6 +230,26 @@ func (r *runner) isCloseCall(ccall ssa.Instruction) bool {
217230
if ccall.Call.Method.Name() == r.closeMthd.Name() {
218231
return true
219232
}
233+
case *ssa.ChangeInterface:
234+
if ccall.Type().String() == "io.Closer" {
235+
closeMtd := ccall.Type().Underlying().(*types.Interface).Method(0)
236+
crs := *ccall.Referrers()
237+
for _, cs := range crs {
238+
if cs, ok := cs.(*ssa.Defer); ok {
239+
if val, ok := cs.Common().Value.(*ssa.Function); ok {
240+
for _, b := range val.Blocks {
241+
for _, instr := range b.Instrs {
242+
if c, ok := instr.(*ssa.Call); ok {
243+
if c.Call.Method == closeMtd {
244+
return true
245+
}
246+
}
247+
}
248+
}
249+
}
250+
}
251+
}
252+
}
220253
}
221254
return false
222255
}
@@ -227,7 +260,8 @@ func (r *runner) isClosureCalled(c *ssa.MakeClosure) bool {
227260
return false
228261
}
229262
for _, ref := range refs {
230-
if _, ok := ref.(*ssa.Call); ok {
263+
switch ref.(type) {
264+
case *ssa.Call, *ssa.Defer:
231265
return true
232266
}
233267
}
@@ -265,3 +299,54 @@ func (r *runner) noImportedNetHTTP(f *ssa.Function) (ret bool) {
265299

266300
return true
267301
}
302+
303+
func (r *runner) calledInFunc(f *ssa.Function, called bool) bool {
304+
for _, b := range f.Blocks {
305+
for i, instr := range b.Instrs {
306+
switch instr := instr.(type) {
307+
case *ssa.UnOp:
308+
refs := *instr.Referrers()
309+
if len(refs) == 0 {
310+
return true
311+
}
312+
for _, r := range refs {
313+
if v, ok := r.(ssa.Value); ok {
314+
if ptr, ok := v.Type().(*types.Pointer); !ok || !isNamedType(ptr.Elem(), "io", "ReadCloser") {
315+
return true
316+
}
317+
vrefs := *v.Referrers()
318+
for _, vref := range vrefs {
319+
if vref, ok := vref.(*ssa.UnOp); ok {
320+
vrefs := *vref.Referrers()
321+
if len(vrefs) == 0 {
322+
return true
323+
}
324+
for _, vref := range vrefs {
325+
if c, ok := vref.(*ssa.Call); ok {
326+
if c.Call.Method.Name() == closeMethod {
327+
return !called
328+
}
329+
}
330+
}
331+
}
332+
}
333+
}
334+
335+
}
336+
default:
337+
return r.isopen(b, i) || !called
338+
}
339+
}
340+
}
341+
return false
342+
}
343+
344+
// isNamedType reports whether t is the named type path.name.
345+
func isNamedType(t types.Type, path, name string) bool {
346+
n, ok := t.(*types.Named)
347+
if !ok {
348+
return false
349+
}
350+
obj := n.Obj()
351+
return obj.Name() == name && obj.Pkg() != nil && obj.Pkg().Path() == path
352+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package a
2+
3+
import "net/http"
4+
5+
func get() *http.Response {
6+
resp, _ := http.Get("https://example.com")
7+
return resp
8+
}
9+
10+
func main() {
11+
resp := get()
12+
resp.Body.Close()
13+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package a
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"net/http"
7+
)
8+
9+
func closeBody(c io.Closer) {
10+
_ = c.Close()
11+
}
12+
13+
func issue3_1() {
14+
resp, _ := http.Get("https://example.com")
15+
defer closeBody(resp.Body)
16+
}
17+
18+
func issue3_2() {
19+
resp, _ := http.Get("https://example.com")
20+
defer func() {
21+
_ = resp.Body.Close()
22+
}()
23+
}
24+
25+
func issue3_3() {
26+
resp, err := http.DefaultClient.Do(nil)
27+
if err != nil {
28+
// handle err
29+
}
30+
defer func() { fmt.Println(resp.Body.Close()) }()
31+
}

0 commit comments

Comments
 (0)