Skip to content

Commit 6c750dd

Browse files
committed
Add default-signifies-exhasutive flag to prevent default clause from automatically passing checks
1 parent 187668c commit 6c750dd

File tree

6 files changed

+71
-21
lines changed

6 files changed

+71
-21
lines changed

README.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ mysumtype.go:18:2: exhaustiveness check failed for sum type 'MySumType': missing
8686
```
8787

8888
Adding either a `default` clause or a clause to handle `*VariantB` will cause
89-
exhaustive checks to pass.
89+
exhaustive checks to pass. To prevent `default` clauses from automatically
90+
passing checks, set the `-default-signifies-exhasutive=false` flag.
9091

9192
As a special case, if the type switch statement contains a `default` clause
9293
that always panics, then exhaustiveness checks are still performed.

check.go

+6-4
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,15 @@ func (e inexhaustiveError) Names() []string {
3939

4040
// check does exhaustiveness checking for the given sum type definitions in the
4141
// given package. Every instance of inexhaustive case analysis is returned.
42-
func check(pkg *packages.Package, defs []sumTypeDef) []error {
42+
func check(pkg *packages.Package, defs []sumTypeDef, config Config) []error {
4343
var errs []error
4444
for _, astfile := range pkg.Syntax {
4545
ast.Inspect(astfile, func(n ast.Node) bool {
4646
swtch, ok := n.(*ast.TypeSwitchStmt)
4747
if !ok {
4848
return true
4949
}
50-
if err := checkSwitch(pkg, defs, swtch); err != nil {
50+
if err := checkSwitch(pkg, defs, swtch, config); err != nil {
5151
errs = append(errs, err)
5252
}
5353
return true
@@ -67,8 +67,9 @@ func checkSwitch(
6767
pkg *packages.Package,
6868
defs []sumTypeDef,
6969
swtch *ast.TypeSwitchStmt,
70+
config Config,
7071
) error {
71-
def, missing := missingVariantsInSwitch(pkg, defs, swtch)
72+
def, missing := missingVariantsInSwitch(pkg, defs, swtch, config)
7273
if len(missing) > 0 {
7374
return inexhaustiveError{
7475
Position: pkg.Fset.Position(swtch.Pos()),
@@ -87,6 +88,7 @@ func missingVariantsInSwitch(
8788
pkg *packages.Package,
8889
defs []sumTypeDef,
8990
swtch *ast.TypeSwitchStmt,
91+
config Config,
9092
) (*sumTypeDef, []types.Object) {
9193
asserted := findTypeAssertExpr(swtch)
9294
ty := pkg.TypesInfo.TypeOf(asserted)
@@ -97,7 +99,7 @@ func missingVariantsInSwitch(
9799
return nil, nil
98100
}
99101
variantExprs, hasDefault := switchVariants(swtch)
100-
if hasDefault && !defaultClauseAlwaysPanics(swtch) {
102+
if config.DefaultSignifiesExhaustive && hasDefault && !defaultClauseAlwaysPanics(swtch) {
101103
// A catch-all case defeats all exhaustiveness checks.
102104
return def, nil
103105
}

check_test.go

+42-11
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ func main() {
2929
tmpdir, pkgs := setupPackages(t, code)
3030
defer teardownPackage(t, tmpdir)
3131

32-
errs := Run(pkgs)
32+
errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true})
3333
assert.Equal(t, 1, len(errs))
3434
assert.Equal(t, []string{"B"}, missingNames(t, errs[0]))
3535
}
3636

37-
// TestMissingTwo tests that we detect a two missing variants.
37+
// TestMissingTwo tests that we detect two missing variants.
3838
func TestMissingTwo(t *testing.T) {
3939
code := `
4040
package gochecksumtype
@@ -60,7 +60,7 @@ func main() {
6060
tmpdir, pkgs := setupPackages(t, code)
6161
defer teardownPackage(t, tmpdir)
6262

63-
errs := Run(pkgs)
63+
errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true})
6464
assert.Equal(t, 1, len(errs))
6565
assert.Equal(t, []string{"B", "C"}, missingNames(t, errs[0]))
6666
}
@@ -91,7 +91,7 @@ func main() {
9191
tmpdir, pkgs := setupPackages(t, code)
9292
defer teardownPackage(t, tmpdir)
9393

94-
errs := Run(pkgs)
94+
errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true})
9595
assert.Equal(t, 1, len(errs))
9696
assert.Equal(t, []string{"B"}, missingNames(t, errs[0]))
9797
}
@@ -122,13 +122,13 @@ func main() {
122122
tmpdir, pkgs := setupPackages(t, code)
123123
defer teardownPackage(t, tmpdir)
124124

125-
errs := Run(pkgs)
125+
errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true})
126126
assert.Equal(t, 0, len(errs))
127127
}
128128

129-
// TestNoMissingDefault tests that even if we have a missing variant, a default
130-
// case should thwart exhaustiveness checking.
131-
func TestNoMissingDefault(t *testing.T) {
129+
// TestNoMissingDefaultWithDefaultSignifiesExhaustive tests that even if we have a missing variant, a default
130+
// case should thwart exhaustiveness checking when Config.DefaultSignifiesExhaustive is true.
131+
func TestNoMissingDefaultWithDefaultSignifiesExhaustive(t *testing.T) {
132132
code := `
133133
package gochecksumtype
134134
@@ -152,10 +152,41 @@ func main() {
152152
tmpdir, pkgs := setupPackages(t, code)
153153
defer teardownPackage(t, tmpdir)
154154

155-
errs := Run(pkgs)
155+
errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true})
156156
assert.Equal(t, 0, len(errs))
157157
}
158158

159+
// TestNoMissingDefaultAndDefaultDoesNotSignifiesExhaustive tests that even if we have a missing variant, a default
160+
// case should thwart exhaustiveness checking when Config.DefaultSignifiesExhaustive is false.
161+
func TestNoMissingDefaultAndDefaultDoesNotSignifiesExhaustive(t *testing.T) {
162+
code := `
163+
package gochecksumtype
164+
165+
//sumtype:decl
166+
type T interface { sealed() }
167+
168+
type A struct {}
169+
func (a *A) sealed() {}
170+
171+
type B struct {}
172+
func (b *B) sealed() {}
173+
174+
func main() {
175+
switch T(nil).(type) {
176+
case *A:
177+
default:
178+
println("legit catch all goes here")
179+
}
180+
}
181+
`
182+
tmpdir, pkgs := setupPackages(t, code)
183+
defer teardownPackage(t, tmpdir)
184+
185+
errs := Run(pkgs, Config{DefaultSignifiesExhaustive: false})
186+
assert.Equal(t, 1, len(errs))
187+
assert.Equal(t, []string{"B"}, missingNames(t, errs[0]))
188+
}
189+
159190
// TestNotSealed tests that we report an error if one tries to declare a sum
160191
// type with an unsealed interface.
161192
func TestNotSealed(t *testing.T) {
@@ -170,7 +201,7 @@ func main() {}
170201
tmpdir, pkgs := setupPackages(t, code)
171202
defer teardownPackage(t, tmpdir)
172203

173-
errs := Run(pkgs)
204+
errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true})
174205
assert.Equal(t, 1, len(errs))
175206
assert.Equal(t, "T", errs[0].(unsealedError).Decl.TypeName)
176207
}
@@ -189,7 +220,7 @@ func main() {}
189220
tmpdir, pkgs := setupPackages(t, code)
190221
defer teardownPackage(t, tmpdir)
191222

192-
errs := Run(pkgs)
223+
errs := Run(pkgs, Config{DefaultSignifiesExhaustive: true})
193224
assert.Equal(t, 1, len(errs))
194225
assert.Equal(t, "T", errs[0].(notInterfaceError).Decl.TypeName)
195226
}

cmd/go-check-sumtype/main.go

+14-3
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,22 @@ import (
1212

1313
func main() {
1414
log.SetFlags(0)
15+
16+
defaultSignifiesExhaustive := flag.Bool(
17+
"default-signifies-exhaustive",
18+
true,
19+
"Presence of \"default\" case in switch statements satisfies exhaustiveness, if all members are not listed.",
20+
)
21+
1522
flag.Parse()
16-
if len(flag.Args()) < 1 {
23+
if flag.NArg() < 1 {
1724
log.Fatalf("Usage: sumtype <packages>\n")
1825
}
19-
args := os.Args[1:]
26+
args := os.Args[flag.NFlag()+1:]
27+
28+
config := gochecksumtype.Config{
29+
DefaultSignifiesExhaustive: *defaultSignifiesExhaustive,
30+
}
2031

2132
conf := &packages.Config{
2233
Mode: packages.NeedSyntax | packages.NeedTypesInfo | packages.NeedTypes | packages.NeedTypesSizes |
@@ -37,7 +48,7 @@ func main() {
3748
if err != nil {
3849
log.Fatal(err)
3950
}
40-
if errs := gochecksumtype.Run(pkgs); len(errs) > 0 {
51+
if errs := gochecksumtype.Run(pkgs, config); len(errs) > 0 {
4152
var list []string
4253
for _, err := range errs {
4354
list = append(list, err.Error())

config.go

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package gochecksumtype
2+
3+
type Config struct {
4+
DefaultSignifiesExhaustive bool
5+
}

run.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package gochecksumtype
33
import "golang.org/x/tools/go/packages"
44

55
// Run sumtype checking on the given packages.
6-
func Run(pkgs []*packages.Package) []error {
6+
func Run(pkgs []*packages.Package, config Config) []error {
77
var errs []error
88

99
decls, err := findSumTypeDecls(pkgs)
@@ -18,7 +18,7 @@ func Run(pkgs []*packages.Package) []error {
1818
}
1919

2020
for _, pkg := range pkgs {
21-
if pkgErrs := check(pkg, defs); pkgErrs != nil {
21+
if pkgErrs := check(pkg, defs, config); pkgErrs != nil {
2222
errs = append(errs, pkgErrs...)
2323
}
2424
}

0 commit comments

Comments
 (0)