Skip to content

Commit 5f38877

Browse files
feat(analyzer): recognize custom error values in return (#102)
--------- Co-authored-by: xobotyi <[email protected]>
1 parent 7ae7de4 commit 5f38877

File tree

5 files changed

+127
-25
lines changed

5 files changed

+127
-25
lines changed

README.md

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,40 @@ var b Shape = a.Shape{
9999
Length: 5,
100100
}
101101
```
102+
103+
### Errors handling
104+
105+
In order to avoid unnecessary noise, when dealing with non-pointer types returned along with errors - `exhaustruct` will
106+
ignore non-error types, checking only structures satisfying `error` interface.
107+
108+
```go
109+
package main
110+
111+
import "errors"
112+
113+
type Shape struct {
114+
Length int
115+
Width int
116+
}
117+
118+
func NewShape() (Shape, error) {
119+
return Shape{}, errors.New("error") // will not raise an error
120+
}
121+
122+
type MyError struct {
123+
Err error
124+
}
125+
126+
func (e MyError) Error() string {
127+
return e.Err.Error()
128+
}
129+
130+
func NewSquare() (Shape, error) {
131+
return Shape{}, MyError{Err: errors.New("error")} // will not raise an error
132+
}
133+
134+
func NewCircle() (Shape, error) {
135+
return Shape{}, MyError{} // will raise "main.MyError is missing field Err"
136+
}
137+
138+
```

analyzer/analyzer.go

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,9 @@ func (a *analyzer) newVisitor(pass *analysis.Pass) func(n ast.Node, push bool, s
100100

101101
if len(lit.Elts) == 0 {
102102
if ret, ok := stackParentIsReturn(stack); ok {
103-
if returnContainsNonNilError(pass, ret) {
103+
if returnContainsNonNilError(pass, ret, n) {
104104
// it is okay to return uninitialized structure in case struct's direct parent is
105105
// a return statement containing non-nil error
106-
//
107-
// we're unable to check if returned error is custom, but at least we're able to
108-
// cover str [error] type.
109106
return true
110107
}
111108
}
@@ -184,17 +181,47 @@ func getStructType(pass *analysis.Pass, lit *ast.CompositeLit) (*types.Struct, *
184181

185182
func stackParentIsReturn(stack []ast.Node) (*ast.ReturnStmt, bool) {
186183
// it is safe to skip boundary check, since stack always has at least one element
187-
// - whole file.
188-
ret, ok := stack[len(stack)-2].(*ast.ReturnStmt)
184+
// we also have no reason to check the first element, since it is always a file
185+
for i := len(stack) - 2; i > 0; i-- {
186+
switch st := stack[i].(type) {
187+
case *ast.ReturnStmt:
188+
return st, true
189189

190-
return ret, ok
190+
case *ast.UnaryExpr:
191+
// in case we're dealing with pointers - it is still viable to check pointer's
192+
// parent for return statement
193+
continue
194+
195+
default:
196+
return nil, false
197+
}
198+
}
199+
200+
return nil, false
191201
}
192202

193-
func returnContainsNonNilError(pass *analysis.Pass, ret *ast.ReturnStmt) bool {
203+
// errorIface is a type that represents [error] interface and all types will be
204+
// compared against.
205+
var errorIface = types.Universe.Lookup("error").Type().Underlying().(*types.Interface)
206+
207+
func returnContainsNonNilError(pass *analysis.Pass, ret *ast.ReturnStmt, except ast.Node) bool {
194208
// errors are mostly located at the end of return statement, so we're starting
195209
// from the end.
196210
for i := len(ret.Results) - 1; i >= 0; i-- {
197-
if pass.TypesInfo.TypeOf(ret.Results[i]).String() == "error" {
211+
ri := ret.Results[i]
212+
213+
// skip current node
214+
if ri == except {
215+
continue
216+
}
217+
218+
if un, ok := ri.(*ast.UnaryExpr); ok {
219+
if un.X == except {
220+
continue
221+
}
222+
}
223+
224+
if types.Implements(pass.TypesInfo.TypeOf(ri), errorIface) {
198225
return true
199226
}
200227
}

analyzer/analyzer_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ func TestAnalyzer(t *testing.T) {
3333
assert.Error(t, err)
3434

3535
a, err = analyzer.NewAnalyzer(
36-
[]string{`.*[Tt]est.*`, `.*External`, `.*Embedded`, `.*\.<anonymous>`},
36+
[]string{`.*[Tt]est.*`, `.*External`, `.*Embedded`, `.*\.<anonymous>`, `j\..*Error`},
3737
[]string{`.*Excluded$`, `e\.<anonymous>`},
3838
)
3939
require.NoError(t, err)
4040

41-
analysistest.Run(t, testdataPath, a, "i", "e")
41+
analysistest.Run(t, testdataPath, a, "i", "e", "j")
4242
}

analyzer/testdata/src/i/i.go

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
package i
33

44
import (
5-
"errors"
6-
75
"e"
86
)
97

@@ -64,18 +62,6 @@ func shouldFailRequiredOmitted() {
6462
}
6563
}
6664

67-
func shouldPassEmptyStructWithNonNilErr() (Test, error) {
68-
return Test{}, errors.New("some error")
69-
}
70-
71-
func shouldFailEmptyStructWithNilErr() (Test, error) {
72-
return Test{}, nil // want "i.Test is missing fields A, B, C, D"
73-
}
74-
75-
func shouldFailEmptyNestedStructWithNonNilErr() ([]Test, error) {
76-
return []Test{{}}, nil // want "i.Test is missing fields A, B, C, D"
77-
}
78-
7965
func shouldPassUnnamed() {
8066
_ = []Test{{"", 0, 0.0, false, ""}}
8167
}

analyzer/testdata/src/j/j.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package j
2+
3+
import (
4+
"fmt"
5+
"os"
6+
)
7+
8+
type Test struct {
9+
A string
10+
}
11+
12+
type AError struct{}
13+
14+
func (AError) Error() string { return "error message" }
15+
16+
type BError struct{ msg string }
17+
18+
func (e BError) Error() string { return e.msg }
19+
20+
func shouldPassEmptyStructWithConcreteAError() (Test, *AError) {
21+
return Test{}, &AError{}
22+
}
23+
24+
func shouldFailEmptyStructWithEmptyBError() (Test, error) {
25+
return Test{}, &BError{} // want "j.BError is missing field msg"
26+
}
27+
28+
func shouldFailEmptyStructWithNilConcreteError() (Test, *BError) {
29+
return Test{}, nil // want "j.Test is missing field A"
30+
}
31+
32+
func shouldPassEmptyStructWithFmtError() (Test, error) {
33+
return Test{}, fmt.Errorf("error message")
34+
}
35+
36+
func shouldPassStaticError() (Test, error) {
37+
return Test{}, os.ErrNotExist
38+
}
39+
40+
func shouldPassAnonymousFunctionReturningError() (Test, error) {
41+
return Test{}, func() error { return nil }()
42+
}
43+
44+
func shouldFailAnonymousFunctionReturningEmptyError() (Test, error) {
45+
fn := func() error { return &BError{} } // want "j.BError is missing field msg"
46+
47+
return Test{}, fn()
48+
}
49+
50+
func shouldFailEmptyNestedStructWithNonNilErr() ([]Test, error) {
51+
return []Test{{}}, os.ErrNotExist // want "j.Test is missing field A"
52+
}

0 commit comments

Comments
 (0)