Skip to content

Commit bf7feda

Browse files
authored
fix: correctly identify infixed concats as potential SQL injections (#987)
1 parent 2292ed5 commit bf7feda

File tree

3 files changed

+142
-15
lines changed

3 files changed

+142
-15
lines changed

helpers.go

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,46 @@ func GetChar(n ast.Node) (byte, error) {
9696
return 0, fmt.Errorf("Unexpected AST node type: %T", n)
9797
}
9898

99+
// GetStringRecursive will recursively walk down a tree of *ast.BinaryExpr. It will then concat the results, and return.
100+
// Unlike the other getters, it does _not_ raise an error for unknown ast.Node types. At the base, the recursion will hit a non-BinaryExpr type,
101+
// either BasicLit or other, so it's not an error case. It will only error if `strconv.Unquote` errors. This matters, because there's
102+
// currently functionality that relies on error values being returned by GetString if and when it hits a non-basiclit string node type,
103+
// hence for cases where recursion is needed, we use this separate function, so that we can still be backwards compatbile.
104+
//
105+
// This was added to handle a SQL injection concatenation case where the injected value is infixed between two strings, not at the start or end. See example below
106+
//
107+
// Do note that this will omit non-string values. So for example, if you were to use this node:
108+
// ```go
109+
// q := "SELECT * FROM foo WHERE name = '" + os.Args[0] + "' AND 1=1" // will result in "SELECT * FROM foo WHERE ” AND 1=1"
110+
111+
func GetStringRecursive(n ast.Node) (string, error) {
112+
if node, ok := n.(*ast.BasicLit); ok && node.Kind == token.STRING {
113+
return strconv.Unquote(node.Value)
114+
}
115+
116+
if expr, ok := n.(*ast.BinaryExpr); ok {
117+
x, err := GetStringRecursive(expr.X)
118+
if err != nil {
119+
return "", err
120+
}
121+
122+
y, err := GetStringRecursive(expr.Y)
123+
if err != nil {
124+
return "", err
125+
}
126+
127+
return x + y, nil
128+
}
129+
130+
return "", nil
131+
}
132+
99133
// GetString will read and return a string value from an ast.BasicLit
100134
func GetString(n ast.Node) (string, error) {
101135
if node, ok := n.(*ast.BasicLit); ok && node.Kind == token.STRING {
102136
return strconv.Unquote(node.Value)
103137
}
138+
104139
return "", fmt.Errorf("Unexpected AST node type: %T", n)
105140
}
106141

@@ -201,22 +236,21 @@ func GetCallStringArgsValues(n ast.Node, _ *Context) []string {
201236
return values
202237
}
203238

204-
// GetIdentStringValues return the string values of an Ident if they can be resolved
205-
func GetIdentStringValues(ident *ast.Ident) []string {
239+
func getIdentStringValues(ident *ast.Ident, stringFinder func(ast.Node) (string, error)) []string {
206240
values := []string{}
207241
obj := ident.Obj
208242
if obj != nil {
209243
switch decl := obj.Decl.(type) {
210244
case *ast.ValueSpec:
211245
for _, v := range decl.Values {
212-
value, err := GetString(v)
246+
value, err := stringFinder(v)
213247
if err == nil {
214248
values = append(values, value)
215249
}
216250
}
217251
case *ast.AssignStmt:
218252
for _, v := range decl.Rhs {
219-
value, err := GetString(v)
253+
value, err := stringFinder(v)
220254
if err == nil {
221255
values = append(values, value)
222256
}
@@ -226,6 +260,18 @@ func GetIdentStringValues(ident *ast.Ident) []string {
226260
return values
227261
}
228262

263+
// getIdentStringRecursive returns the string of values of an Ident if they can be resolved
264+
// The difference between this and GetIdentStringValues is that it will attempt to resolve the strings recursively,
265+
// if it is passed a *ast.BinaryExpr. See GetStringRecursive for details
266+
func GetIdentStringValuesRecursive(ident *ast.Ident) []string {
267+
return getIdentStringValues(ident, GetStringRecursive)
268+
}
269+
270+
// GetIdentStringValues return the string values of an Ident if they can be resolved
271+
func GetIdentStringValues(ident *ast.Ident) []string {
272+
return getIdentStringValues(ident, GetString)
273+
}
274+
229275
// GetBinaryExprOperands returns all operands of a binary expression by traversing
230276
// the expression tree
231277
func GetBinaryExprOperands(be *ast.BinaryExpr) []ast.Node {

rules/sql.go

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,32 @@ func (s *sqlStrConcat) ID() string {
9898
return s.MetaData.ID
9999
}
100100

101+
// findInjectionInBranch walks diwb a set if expressions, and will create new issues if it finds SQL injections
102+
// This method assumes you've already verified that the branch contains SQL syntax
103+
func (s *sqlStrConcat) findInjectionInBranch(ctx *gosec.Context, branch []ast.Expr) *ast.BinaryExpr {
104+
for _, node := range branch {
105+
be, ok := node.(*ast.BinaryExpr)
106+
if !ok {
107+
continue
108+
}
109+
110+
operands := gosec.GetBinaryExprOperands(be)
111+
112+
for _, op := range operands {
113+
if _, ok := op.(*ast.BasicLit); ok {
114+
continue
115+
}
116+
117+
if ident, ok := op.(*ast.Ident); ok && s.checkObject(ident, ctx) {
118+
continue
119+
}
120+
121+
return be
122+
}
123+
}
124+
return nil
125+
}
126+
101127
// see if we can figure out what it is
102128
func (s *sqlStrConcat) checkObject(n *ast.Ident, c *gosec.Context) bool {
103129
if n.Obj != nil {
@@ -140,6 +166,28 @@ func (s *sqlStrConcat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*issu
140166
}
141167
}
142168

169+
// Handle the case where an injection occurs as an infixed string concatenation, ie "SELECT * FROM foo WHERE name = '" + os.Args[0] + "' AND 1=1"
170+
if id, ok := query.(*ast.Ident); ok {
171+
var match bool
172+
for _, str := range gosec.GetIdentStringValuesRecursive(id) {
173+
if s.MatchPatterns(str) {
174+
match = true
175+
break
176+
}
177+
}
178+
179+
if !match {
180+
return nil, nil
181+
}
182+
183+
switch decl := id.Obj.Decl.(type) {
184+
case *ast.AssignStmt:
185+
if injection := s.findInjectionInBranch(ctx, decl.Rhs); injection != nil {
186+
return ctx.NewIssue(injection, s.ID(), s.What, s.Severity, s.Confidence), nil
187+
}
188+
}
189+
}
190+
143191
return nil, nil
144192
}
145193

@@ -157,6 +205,7 @@ func (s *sqlStrConcat) Match(n ast.Node, ctx *gosec.Context) (*issue.Issue, erro
157205
return s.checkQuery(sqlQueryCall, ctx)
158206
}
159207
}
208+
160209
return nil, nil
161210
}
162211

@@ -165,7 +214,7 @@ func NewSQLStrConcat(id string, _ gosec.Config) (gosec.Rule, []ast.Node) {
165214
rule := &sqlStrConcat{
166215
sqlStatement: sqlStatement{
167216
patterns: []*regexp.Regexp{
168-
regexp.MustCompile(`(?i)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE) `),
217+
regexp.MustCompile("(?i)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE)( |\n|\r|\t)"),
169218
},
170219
MetaData: issue.MetaData{
171220
ID: id,

testutils/source.go

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1712,6 +1712,28 @@ func main() {
17121712
// SampleCodeG202 - SQL query string building via string concatenation
17131713
SampleCodeG202 = []CodeSample{
17141714
{[]string{`
1715+
// infixed concatenation
1716+
package main
1717+
1718+
import (
1719+
"database/sql"
1720+
"os"
1721+
)
1722+
1723+
func main(){
1724+
db, err := sql.Open("sqlite3", ":memory:")
1725+
if err != nil {
1726+
panic(err)
1727+
}
1728+
1729+
q := "INSERT INTO foo (name) VALUES ('" + os.Args[0] + "')"
1730+
rows, err := db.Query(q)
1731+
if err != nil {
1732+
panic(err)
1733+
}
1734+
defer rows.Close()
1735+
}`}, 1, gosec.NewConfig()},
1736+
{[]string{`
17151737
package main
17161738
17171739
import (
@@ -1729,7 +1751,8 @@ func main(){
17291751
panic(err)
17301752
}
17311753
defer rows.Close()
1732-
}`}, 1, gosec.NewConfig()}, {[]string{`
1754+
}`}, 1, gosec.NewConfig()},
1755+
{[]string{`
17331756
// case insensitive match
17341757
package main
17351758
@@ -1748,7 +1771,8 @@ func main(){
17481771
panic(err)
17491772
}
17501773
defer rows.Close()
1751-
}`}, 1, gosec.NewConfig()}, {[]string{`
1774+
}`}, 1, gosec.NewConfig()},
1775+
{[]string{`
17521776
// context match
17531777
package main
17541778
@@ -1768,7 +1792,8 @@ func main(){
17681792
panic(err)
17691793
}
17701794
defer rows.Close()
1771-
}`}, 1, gosec.NewConfig()}, {[]string{`
1795+
}`}, 1, gosec.NewConfig()},
1796+
{[]string{`
17721797
// DB transaction check
17731798
package main
17741799
@@ -1796,7 +1821,8 @@ func main(){
17961821
if err := tx.Commit(); err != nil {
17971822
panic(err)
17981823
}
1799-
}`}, 1, gosec.NewConfig()}, {[]string{`
1824+
}`}, 1, gosec.NewConfig()},
1825+
{[]string{`
18001826
// multiple string concatenation
18011827
package main
18021828
@@ -1815,7 +1841,8 @@ func main(){
18151841
panic(err)
18161842
}
18171843
defer rows.Close()
1818-
}`}, 1, gosec.NewConfig()}, {[]string{`
1844+
}`}, 1, gosec.NewConfig()},
1845+
{[]string{`
18191846
// false positive
18201847
package main
18211848
@@ -1834,7 +1861,8 @@ func main(){
18341861
panic(err)
18351862
}
18361863
defer rows.Close()
1837-
}`}, 0, gosec.NewConfig()}, {[]string{`
1864+
}`}, 0, gosec.NewConfig()},
1865+
{[]string{`
18381866
package main
18391867
18401868
import (
@@ -1856,7 +1884,8 @@ func main(){
18561884
}
18571885
defer rows.Close()
18581886
}
1859-
`}, 0, gosec.NewConfig()}, {[]string{`
1887+
`}, 0, gosec.NewConfig()},
1888+
{[]string{`
18601889
package main
18611890
18621891
const gender = "M"
@@ -1882,7 +1911,8 @@ func main(){
18821911
}
18831912
defer rows.Close()
18841913
}
1885-
`}, 0, gosec.NewConfig()}, {[]string{`
1914+
`}, 0, gosec.NewConfig()},
1915+
{[]string{`
18861916
// ExecContext match
18871917
package main
18881918
@@ -1903,7 +1933,8 @@ func main() {
19031933
panic(err)
19041934
}
19051935
fmt.Println(result)
1906-
}`}, 1, gosec.NewConfig()}, {[]string{`
1936+
}`}, 1, gosec.NewConfig()},
1937+
{[]string{`
19071938
// Exec match
19081939
package main
19091940
@@ -1923,7 +1954,8 @@ func main() {
19231954
panic(err)
19241955
}
19251956
fmt.Println(result)
1926-
}`}, 1, gosec.NewConfig()}, {[]string{`
1957+
}`}, 1, gosec.NewConfig()},
1958+
{[]string{`
19271959
package main
19281960
19291961
import (

0 commit comments

Comments
 (0)