Skip to content

Commit ed710f4

Browse files
committed
pattern: guard against trying to match mismatched types
For example, trying to match an *ast.CallExpr against an []ast.Stmt would fall into the branch handling []ast.Stmt, which mustn't assume that the lhs is an ast.Stmt.
1 parent 6072793 commit ed710f4

File tree

1 file changed

+32
-6
lines changed

1 file changed

+32
-6
lines changed

pattern/match.go

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -225,14 +225,24 @@ func match(m *Matcher, l, r interface{}) (interface{}, bool) {
225225
}
226226
}
227227

228+
// TODO(dh): the three blocks handling slices can be combined into a single block if we use reflection
229+
228230
{
229231
ln, ok1 := l.([]ast.Expr)
230232
rn, ok2 := r.([]ast.Expr)
231233
if ok1 || ok2 {
232234
if ok1 && !ok2 {
233-
rn = []ast.Expr{r.(ast.Expr)}
235+
cast, ok := r.(ast.Expr)
236+
if !ok {
237+
return nil, false
238+
}
239+
rn = []ast.Expr{cast}
234240
} else if !ok1 && ok2 {
235-
ln = []ast.Expr{l.(ast.Expr)}
241+
cast, ok := l.(ast.Expr)
242+
if !ok {
243+
return nil, false
244+
}
245+
ln = []ast.Expr{cast}
236246
}
237247

238248
if len(ln) != len(rn) {
@@ -252,9 +262,17 @@ func match(m *Matcher, l, r interface{}) (interface{}, bool) {
252262
rn, ok2 := r.([]ast.Stmt)
253263
if ok1 || ok2 {
254264
if ok1 && !ok2 {
255-
rn = []ast.Stmt{r.(ast.Stmt)}
265+
cast, ok := r.(ast.Stmt)
266+
if !ok {
267+
return nil, false
268+
}
269+
rn = []ast.Stmt{cast}
256270
} else if !ok1 && ok2 {
257-
ln = []ast.Stmt{l.(ast.Stmt)}
271+
cast, ok := l.(ast.Stmt)
272+
if !ok {
273+
return nil, false
274+
}
275+
ln = []ast.Stmt{cast}
258276
}
259277

260278
if len(ln) != len(rn) {
@@ -274,9 +292,17 @@ func match(m *Matcher, l, r interface{}) (interface{}, bool) {
274292
rn, ok2 := r.([]*ast.Field)
275293
if ok1 || ok2 {
276294
if ok1 && !ok2 {
277-
rn = []*ast.Field{r.(*ast.Field)}
295+
cast, ok := r.(*ast.Field)
296+
if !ok {
297+
return nil, false
298+
}
299+
rn = []*ast.Field{cast}
278300
} else if !ok1 && ok2 {
279-
ln = []*ast.Field{l.(*ast.Field)}
301+
cast, ok := l.(*ast.Field)
302+
if !ok {
303+
return nil, false
304+
}
305+
ln = []*ast.Field{cast}
280306
}
281307

282308
if len(ln) != len(rn) {

0 commit comments

Comments
 (0)