Skip to content

Commit c0e6cac

Browse files
committed
fix: Also apply allowed list to value switch statements
This also moves the reported AST position of switch statements from the switch keyword to the first problematic case.
1 parent e24df99 commit c0e6cac

File tree

4 files changed

+53
-26
lines changed

4 files changed

+53
-26
lines changed

errorlint/allowed.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,13 @@ func isAllowedErrAndFunc(err, fun string) bool {
127127
return false
128128
}
129129

130-
func isAllowedErrorComparison(pass *TypesInfoExt, binExpr *ast.BinaryExpr) bool {
130+
func isAllowedErrorComparison(pass *TypesInfoExt, a, b ast.Expr) bool {
131131
var errName string // `<package>.<name>`, e.g. `io.EOF`
132132
var callExprs []*ast.CallExpr
133133

134134
// Figure out which half of the expression is the returned error and which
135135
// half is the presumed error declaration.
136-
for _, expr := range []ast.Expr{binExpr.X, binExpr.Y} {
136+
for _, expr := range []ast.Expr{a, b} {
137137
switch t := expr.(type) {
138138
case *ast.SelectorExpr:
139139
// A selector which we assume refers to a staticaly declared error

errorlint/lint.go

+29-20
Original file line numberDiff line numberDiff line change
@@ -172,15 +172,15 @@ func LintErrorComparisons(info *TypesInfoExt) []analysis.Diagnostic {
172172
continue
173173
}
174174
// Comparing errors with nil is okay.
175-
if isNilComparison(binExpr) {
175+
if isNil(binExpr.X) || isNil(binExpr.Y) {
176176
continue
177177
}
178178
// Find comparisons of which one side is a of type error.
179-
if !isErrorComparison(info.TypesInfo, binExpr) {
179+
if !isErrorType(info.TypesInfo, binExpr.X) && !isErrorType(info.TypesInfo, binExpr.Y) {
180180
continue
181181
}
182182
// Some errors that are returned from some functions are exempt.
183-
if isAllowedErrorComparison(info, binExpr) {
183+
if isAllowedErrorComparison(info, binExpr.X, binExpr.Y) {
184184
continue
185185
}
186186
// Comparisons that happen in `func (type) Is(error) bool` are okay.
@@ -201,43 +201,52 @@ func LintErrorComparisons(info *TypesInfoExt) []analysis.Diagnostic {
201201
continue
202202
}
203203
// Check whether the switch operates on an error type.
204-
if switchStmt.Tag == nil {
204+
if !isErrorType(info.TypesInfo, switchStmt.Tag) {
205205
continue
206206
}
207-
tagType := info.TypesInfo.Types[switchStmt.Tag]
208-
if tagType.Type.String() != "error" {
207+
208+
var problematicCaseClause *ast.CaseClause
209+
outer:
210+
for _, stmt := range switchStmt.Body.List {
211+
caseClause := stmt.(*ast.CaseClause)
212+
for _, caseExpr := range caseClause.List {
213+
if isNil(caseExpr) {
214+
continue
215+
}
216+
// Some errors that are returned from some functions are exempt.
217+
if !isAllowedErrorComparison(info, switchStmt.Tag, caseExpr) {
218+
problematicCaseClause = caseClause
219+
break outer
220+
}
221+
}
222+
}
223+
if problematicCaseClause == nil {
209224
continue
210225
}
226+
// Comparisons that happen in `func (type) Is(error) bool` are okay.
211227
if isNodeInErrorIsFunc(info, switchStmt) {
212228
continue
213229
}
214230

215231
if switchComparesNonNil(switchStmt) {
216232
lints = append(lints, analysis.Diagnostic{
217233
Message: "switch on an error will fail on wrapped errors. Use errors.Is to check for specific errors",
218-
Pos: switchStmt.Pos(),
234+
Pos: problematicCaseClause.Pos(),
219235
})
220236
}
221-
222237
}
223238

224239
return lints
225240
}
226241

227-
func isNilComparison(binExpr *ast.BinaryExpr) bool {
228-
if ident, ok := binExpr.X.(*ast.Ident); ok && ident.Name == "nil" {
229-
return true
230-
}
231-
if ident, ok := binExpr.Y.(*ast.Ident); ok && ident.Name == "nil" {
232-
return true
233-
}
234-
return false
242+
func isNil(ex ast.Expr) bool {
243+
ident, ok := ex.(*ast.Ident)
244+
return ok && ident.Name == "nil"
235245
}
236246

237-
func isErrorComparison(info *types.Info, binExpr *ast.BinaryExpr) bool {
238-
tx := info.Types[binExpr.X]
239-
ty := info.Types[binExpr.Y]
240-
return tx.Type.String() == "error" || ty.Type.String() == "error"
247+
func isErrorType(info *types.Info, ex ast.Expr) bool {
248+
t := info.Types[ex].Type
249+
return t != nil && t.String() == "error"
241250
}
242251

243252
func isNodeInErrorIsFunc(info *TypesInfoExt, node ast.Node) bool {

errorlint/testdata/src/errorsis/errorsis.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,10 @@ func NotEqualOperatorYoda() {
7676

7777
func CompareSwitch() {
7878
err := doThing()
79-
switch err { // want `switch on an error will fail on wrapped errors. Use errors.Is to check for specific errors`
79+
switch err {
8080
case nil:
8181
fmt.Println("nil")
82-
case ErrFoo:
82+
case ErrFoo: // want `switch on an error will fail on wrapped errors. Use errors.Is to check for specific errors`
8383
fmt.Println("ErrFoo")
8484
}
8585
}
@@ -95,8 +95,8 @@ func CompareSwitchSafe() {
9595
}
9696

9797
func CompareSwitchInline() {
98-
switch doThing() { // want `switch on an error will fail on wrapped errors. Use errors.Is to check for specific errors`
99-
case ErrFoo:
98+
switch doThing() {
99+
case ErrFoo: // want `switch on an error will fail on wrapped errors. Use errors.Is to check for specific errors`
100100
fmt.Println("ErrFoo")
101101
}
102102
}
+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package issues
2+
3+
import (
4+
"fmt"
5+
6+
"golang.org/x/sys/unix"
7+
)
8+
9+
func SwitchOnUnixErrors() {
10+
err := unix.Rmdir("somepath")
11+
switch err {
12+
case unix.ENOENT:
13+
return
14+
case unix.EPERM:
15+
return
16+
}
17+
fmt.Println(err)
18+
}

0 commit comments

Comments
 (0)