Skip to content

Commit 6a26c23

Browse files
authored
Refactor SQL rules for better extensibility (#841)
Remove hardwired assumption and heuristics on index of arg taking a SQL string, be explicit about it instead.
1 parent 1b0873a commit 6a26c23

File tree

1 file changed

+58
-19
lines changed

1 file changed

+58
-19
lines changed

rules/sql.go

+58-19
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
package rules
1616

1717
import (
18+
"fmt"
1819
"go/ast"
1920
"regexp"
20-
"strings"
2121

2222
"github.com/securego/gosec/v2"
2323
)
@@ -30,6 +30,51 @@ type sqlStatement struct {
3030
patterns []*regexp.Regexp
3131
}
3232

33+
var sqlCallIdents = map[string]map[string]int{
34+
"*database/sql.DB": {
35+
"Exec": 0,
36+
"ExecContext": 1,
37+
"Query": 0,
38+
"QueryContext": 1,
39+
"QueryRow": 0,
40+
"QueryRowContext": 1,
41+
"Prepare": 0,
42+
"PrepareContext": 1,
43+
},
44+
"*database/sql.Tx": {
45+
"Exec": 0,
46+
"ExecContext": 1,
47+
"Query": 0,
48+
"QueryContext": 1,
49+
"QueryRow": 0,
50+
"QueryRowContext": 1,
51+
"Prepare": 0,
52+
"PrepareContext": 1,
53+
},
54+
}
55+
56+
// findQueryArg locates the argument taking raw SQL
57+
func findQueryArg(call *ast.CallExpr, ctx *gosec.Context) (ast.Expr, error) {
58+
typeName, fnName, err := gosec.GetCallInfo(call, ctx)
59+
if err != nil {
60+
return nil, err
61+
}
62+
i := -1
63+
if ni, ok := sqlCallIdents[typeName]; ok {
64+
if i, ok = ni[fnName]; !ok {
65+
i = -1
66+
}
67+
}
68+
if i == -1 {
69+
return nil, fmt.Errorf("SQL argument index not found for %s.%s", typeName, fnName)
70+
}
71+
if i >= len(call.Args) {
72+
return nil, nil
73+
}
74+
query := call.Args[i]
75+
return query, nil
76+
}
77+
3378
func (s *sqlStatement) ID() string {
3479
return s.MetaData.ID
3580
}
@@ -69,16 +114,10 @@ func (s *sqlStrConcat) checkObject(n *ast.Ident, c *gosec.Context) bool {
69114

70115
// checkQuery verifies if the query parameters is a string concatenation
71116
func (s *sqlStrConcat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*gosec.Issue, error) {
72-
_, fnName, err := gosec.GetCallInfo(call, ctx)
117+
query, err := findQueryArg(call, ctx)
73118
if err != nil {
74119
return nil, err
75120
}
76-
var query ast.Node
77-
if strings.HasSuffix(fnName, "Context") {
78-
query = call.Args[1]
79-
} else {
80-
query = call.Args[0]
81-
}
82121

83122
if be, ok := query.(*ast.BinaryExpr); ok {
84123
operands := gosec.GetBinaryExprOperands(be)
@@ -137,8 +176,11 @@ func NewSQLStrConcat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
137176
},
138177
}
139178

140-
rule.AddAll("*database/sql.DB", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext")
141-
rule.AddAll("*database/sql.Tx", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext")
179+
for s, si := range sqlCallIdents {
180+
for i := range si {
181+
rule.Add(s, i)
182+
}
183+
}
142184
return rule, []ast.Node{(*ast.AssignStmt)(nil), (*ast.ExprStmt)(nil)}
143185
}
144186

@@ -171,16 +213,10 @@ func (s *sqlStrFormat) constObject(e ast.Expr, c *gosec.Context) bool {
171213
}
172214

173215
func (s *sqlStrFormat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*gosec.Issue, error) {
174-
_, fnName, err := gosec.GetCallInfo(call, ctx)
216+
query, err := findQueryArg(call, ctx)
175217
if err != nil {
176218
return nil, err
177219
}
178-
var query ast.Node
179-
if strings.HasSuffix(fnName, "Context") {
180-
query = call.Args[1]
181-
} else {
182-
query = call.Args[0]
183-
}
184220

185221
if ident, ok := query.(*ast.Ident); ok && ident.Obj != nil {
186222
decl := ident.Obj.Decl
@@ -306,8 +342,11 @@ func NewSQLStrFormat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
306342
},
307343
},
308344
}
309-
rule.AddAll("*database/sql.DB", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext")
310-
rule.AddAll("*database/sql.Tx", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext")
345+
for s, si := range sqlCallIdents {
346+
for i := range si {
347+
rule.Add(s, i)
348+
}
349+
}
311350
rule.fmtCalls.AddAll("fmt", "Sprint", "Sprintf", "Sprintln", "Fprintf")
312351
rule.noIssue.AddAll("os", "Stdout", "Stderr")
313352
rule.noIssueQuoted.Add("github.com/lib/pq", "QuoteIdentifier")

0 commit comments

Comments
 (0)