-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Placeholder interpolation #309
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 30 commits
e35fa00
c8c9bb1
cac6129
b7c2c47
f3b82fd
058ce87
42956fa
e6bf23a
b473259
42a1efd
3c8fa90
6c8484b
04866ee
dd7b87c
9faabe5
468b9e5
0297315
0b75396
9f84dfb
8826242
916a1f2
88aeb98
43536c7
0c7ae46
c285e39
fcea447
bfbe6c5
d65f96a
e11c825
b4f0315
1fd0514
20b75cd
e517683
0f22bc2
52a5860
2a634df
90cb6c3
9437b61
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,8 @@ | ||
language: go | ||
go: | ||
- 1.1 | ||
- 1.2 | ||
- 1.3 | ||
- 1.4 | ||
- tip | ||
|
||
before_script: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -182,6 +182,18 @@ SELECT u.id FROM users as u | |
|
||
will return `u.id` instead of just `id` if `columnsWithAlias=true`. | ||
|
||
##### `interpolateParams` | ||
|
||
``` | ||
Type: bool | ||
Valid Values: true, false | ||
Default: false | ||
``` | ||
|
||
When `interpolateParams` is true, calls to `sql.Db.Query()` and `sql.Db.Exec()` with params interpolates placeholders (`?`) with given params. This reduces roundtrips to database compared with `interpolateParams=false` since it uses prepare, exec and close to support parameters. | ||
|
||
NOTE: It make SQL injection vulnerability when connection encoding is multibyte encoding except utf-8 (e.g. cp932). | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wrote this in an outdated diff and just copied it here: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It shouldn't just fall back. It should return an error! This could be checked already during the initialization phase. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, I don't get it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @arnehormann It's OK to check There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi guys, I really would like to see a test case for this as it's not 100% obvious for me how this vulnerability applies to this go mysql driver. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. here is an explanation by @xaprb: #157 (comment) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @arvenil Here is collation vs charset example:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's pretty simple, really. A collation just tells you if |
||
|
||
##### `loc` | ||
|
||
``` | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ import ( | |
"database/sql/driver" | ||
"errors" | ||
"net" | ||
"strconv" | ||
"strings" | ||
"time" | ||
) | ||
|
@@ -26,6 +27,7 @@ type mysqlConn struct { | |
maxPacketAllowed int | ||
maxWriteSize int | ||
flags clientFlag | ||
status statusFlag | ||
sequence uint8 | ||
parseTime bool | ||
strict bool | ||
|
@@ -46,6 +48,7 @@ type config struct { | |
allowOldPasswords bool | ||
clientFoundRows bool | ||
columnsWithAlias bool | ||
interpolateParams bool | ||
} | ||
|
||
// Handles parameters set in DSN after the connection is established | ||
|
@@ -162,28 +165,174 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { | |
return stmt, err | ||
} | ||
|
||
// estimateParamLength calculates upper bound of string length from types. | ||
func estimateParamLength(args []driver.Value) (int, bool) { | ||
l := 0 | ||
for _, a := range args { | ||
switch v := a.(type) { | ||
case int64, float64: | ||
// 24 (-1.7976931348623157e+308) may be upper bound. But I'm not sure. | ||
l += 25 | ||
case bool: | ||
l += 1 // 0 or 1 | ||
case time.Time: | ||
l += 30 // '1234-12-23 12:34:56.777777' | ||
case string: | ||
l += len(v)*2 + 2 | ||
case []byte: | ||
l += len(v)*2 + 2 | ||
default: | ||
return 0, false | ||
} | ||
} | ||
return l, true | ||
} | ||
|
||
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) { | ||
estimated, ok := estimateParamLength(args) | ||
if !ok { | ||
return "", driver.ErrSkip | ||
} | ||
estimated += len(query) | ||
|
||
buf := make([]byte, 0, estimated) | ||
argPos := 0 | ||
|
||
for i := 0; i < len(query); i++ { | ||
q := strings.IndexByte(query[i:], '?') | ||
if q == -1 { | ||
buf = append(buf, query[i:]...) | ||
break | ||
} | ||
buf = append(buf, query[i:i+q]...) | ||
i += q | ||
|
||
arg := args[argPos] | ||
argPos++ | ||
|
||
if arg == nil { | ||
buf = append(buf, "NULL"...) | ||
continue | ||
} | ||
|
||
switch v := arg.(type) { | ||
case int64: | ||
buf = strconv.AppendInt(buf, v, 10) | ||
case float64: | ||
buf = strconv.AppendFloat(buf, v, 'g', -1, 64) | ||
case bool: | ||
if v { | ||
buf = append(buf, '1') | ||
} else { | ||
buf = append(buf, '0') | ||
} | ||
case time.Time: | ||
if v.IsZero() { | ||
buf = append(buf, "'0000-00-00'"...) | ||
} else { | ||
v := v.In(mc.cfg.loc) | ||
v = v.Add(time.Nanosecond * 500) // To round under microsecond | ||
year := v.Year() | ||
year100 := year / 100 | ||
year1 := year % 100 | ||
month := v.Month() | ||
day := v.Day() | ||
hour := v.Hour() | ||
minute := v.Minute() | ||
second := v.Second() | ||
micro := v.Nanosecond() / 1000 | ||
|
||
buf = append(buf, []byte{ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we benchmarked and checked https://github.com/go-sql-driver/mysql/blob/master/utils.go#L535 pretty extensively. Maybe you can get inspiration from there, especially the const digits10 and digits01 part without division and modulus was faster than this approach. Maybe the code can be unified (I have doubts). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've did. |
||
'\'', | ||
digits10[year100], digits01[year100], | ||
digits10[year1], digits01[year1], | ||
'-', | ||
digits10[month], digits01[month], | ||
'-', | ||
digits10[day], digits01[day], | ||
' ', | ||
digits10[hour], digits01[hour], | ||
':', | ||
digits10[minute], digits01[minute], | ||
':', | ||
digits10[second], digits01[second], | ||
}...) | ||
|
||
if micro != 0 { | ||
micro10000 := micro / 10000 | ||
micro100 := micro / 100 % 100 | ||
micro1 := micro % 100 | ||
buf = append(buf, []byte{ | ||
'.', | ||
digits10[micro10000], digits01[micro10000], | ||
digits10[micro100], digits01[micro100], | ||
digits10[micro1], digits01[micro1], | ||
}...) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. MySQL 5.6 rounds under 1 micro second. |
||
} | ||
buf = append(buf, '\'') | ||
} | ||
case []byte: | ||
if v == nil { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not needed, handled in L191 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Required since |
||
buf = append(buf, "NULL"...) | ||
} else { | ||
buf = append(buf, '\'') | ||
if mc.status&statusNoBackslashEscapes == 0 { | ||
buf = escapeBytesBackslash(buf, v) | ||
} else { | ||
buf = escapeBytesQuotes(buf, v) | ||
} | ||
buf = append(buf, '\'') | ||
} | ||
case string: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When []byte and string are essentially the same code, can they be in one There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, I don't know that I can use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can't... |
||
buf = append(buf, '\'') | ||
if mc.status&statusNoBackslashEscapes == 0 { | ||
buf = escapeStringBackslash(buf, v) | ||
} else { | ||
buf = escapeStringQuotes(buf, v) | ||
} | ||
buf = append(buf, '\'') | ||
default: | ||
return "", driver.ErrSkip | ||
} | ||
|
||
if len(buf)+4 > mc.maxPacketAllowed { | ||
return "", driver.ErrSkip | ||
} | ||
} | ||
if argPos != len(args) { | ||
return "", driver.ErrSkip | ||
} | ||
return string(buf), nil | ||
} | ||
|
||
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { | ||
if mc.netConn == nil { | ||
errLog.Print(ErrInvalidConn) | ||
return nil, driver.ErrBadConn | ||
} | ||
if len(args) == 0 { // no args, fastpath | ||
mc.affectedRows = 0 | ||
mc.insertId = 0 | ||
|
||
err := mc.exec(query) | ||
if err == nil { | ||
return &mysqlResult{ | ||
affectedRows: int64(mc.affectedRows), | ||
insertId: int64(mc.insertId), | ||
}, err | ||
if len(args) != 0 { | ||
if !mc.cfg.interpolateParams { | ||
return nil, driver.ErrSkip | ||
} | ||
return nil, err | ||
// try client-side prepare to reduce roundtrip | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This comment is a bit confusing (same in L368). Maybe use something like this instead: |
||
prepared, err := mc.interpolateParams(query, args) | ||
if err != nil { | ||
return nil, err | ||
} | ||
query = prepared | ||
args = nil | ||
} | ||
mc.affectedRows = 0 | ||
mc.insertId = 0 | ||
|
||
// with args, must use prepared stmt | ||
return nil, driver.ErrSkip | ||
|
||
err := mc.exec(query) | ||
if err == nil { | ||
return &mysqlResult{ | ||
affectedRows: int64(mc.affectedRows), | ||
insertId: int64(mc.insertId), | ||
}, err | ||
} | ||
return nil, err | ||
} | ||
|
||
// Internal function to execute commands | ||
|
@@ -212,31 +361,38 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro | |
errLog.Print(ErrInvalidConn) | ||
return nil, driver.ErrBadConn | ||
} | ||
if len(args) == 0 { // no args, fastpath | ||
// Send command | ||
err := mc.writeCommandPacketStr(comQuery, query) | ||
if len(args) != 0 { | ||
if !mc.cfg.interpolateParams { | ||
return nil, driver.ErrSkip | ||
} | ||
// try client-side prepare to reduce roundtrip | ||
prepared, err := mc.interpolateParams(query, args) | ||
if err != nil { | ||
return nil, err | ||
} | ||
query = prepared | ||
args = nil | ||
} | ||
// Send command | ||
err := mc.writeCommandPacketStr(comQuery, query) | ||
if err == nil { | ||
// Read Result | ||
var resLen int | ||
resLen, err = mc.readResultSetHeaderPacket() | ||
if err == nil { | ||
// Read Result | ||
var resLen int | ||
resLen, err = mc.readResultSetHeaderPacket() | ||
if err == nil { | ||
rows := new(textRows) | ||
rows.mc = mc | ||
|
||
if resLen == 0 { | ||
// no columns, no more data | ||
return emptyRows{}, nil | ||
} | ||
// Columns | ||
rows.columns, err = mc.readColumns(resLen) | ||
return rows, err | ||
rows := new(textRows) | ||
rows.mc = mc | ||
|
||
if resLen == 0 { | ||
// no columns, no more data | ||
return emptyRows{}, nil | ||
} | ||
// Columns | ||
rows.columns, err = mc.readColumns(resLen) | ||
return rows, err | ||
} | ||
return nil, err | ||
} | ||
|
||
// with args, must use prepared stmt | ||
return nil, driver.ErrSkip | ||
return nil, err | ||
} | ||
|
||
// Gets the value of the given MySQL System Variable | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use just
db.Query()
/db.Exec()
instead ofsql.Db.Query()
/sql.Db.Exec()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If
interpolateParams
is true, placeholders (?
) in calls todb.Query()
anddb.Exec()
are interpolated into a single query string with the given parameters. This reduces the number of roundtrips, since the driver has to prepare a statement, execute it with the given parameters and close the statement again withinterpolateParams=false
.