From e35fa001b5162eccbcbd2d0f4a722399fa454397 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Thu, 1 Jan 2015 05:27:47 +0900 Subject: [PATCH 01/37] Implement placeholder substitution. --- connection.go | 134 ++++++++++++++++++++++++++++++++++++++++++++------ const.go | 22 +++++++++ packets.go | 1 + 3 files changed, 143 insertions(+), 14 deletions(-) diff --git a/connection.go b/connection.go index 67d3dbee8..44f3ddde5 100644 --- a/connection.go +++ b/connection.go @@ -13,6 +13,7 @@ import ( "database/sql/driver" "errors" "net" + "strconv" "strings" "time" ) @@ -26,6 +27,7 @@ type mysqlConn struct { maxPacketAllowed int maxWriteSize int flags clientFlag + status statusFlag sequence uint8 parseTime bool strict bool @@ -161,28 +163,132 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { return stmt, err } +func (mc *mysqlConn) escapeBytes(v []byte) string { + buf := make([]byte, len(v)*2) + buf[0] = '\'' + pos := 1 + if mc.status&statusNoBackslashEscapes == 0 { + for _, c := range v { + switch c { + case '\x00': + buf[pos] = '\\' + buf[pos+1] = '0' + pos += 2 + case '\n': + buf[pos] = '\\' + buf[pos+1] = 'n' + pos += 2 + case '\r': + buf[pos] = '\\' + buf[pos+1] = 'r' + pos += 2 + case '\x1a': + buf[pos] = '\\' + buf[pos+1] = 'Z' + pos += 2 + case '\'': + buf[pos] = '\\' + buf[pos+1] = '\'' + pos += 2 + case '"': + buf[pos] = '\\' + buf[pos+1] = '"' + pos += 2 + case '\\': + buf[pos] = '\\' + buf[pos+1] = '\\' + pos += 2 + default: + buf[pos] = c + pos += 1 + } + } + } else { + for _, c := range v { + if c == '\'' { + buf[pos] = '\'' + buf[pos+1] = '\'' + pos += 2 + } else { + buf[pos] = c + pos++ + } + } + } + buf[pos] = '\'' + return string(buf[:pos+1]) +} + +func (mc *mysqlConn) buildQuery(query string, args []driver.Value) (string, error) { + chunks := strings.Split(query, "?") + if len(chunks) != len(args)+1 { + return "", driver.ErrSkip + } + + parts := make([]string, len(chunks)+len(args)) + parts[0] = chunks[0] + + for i, arg := range args { + pos := i*2 + 1 + parts[pos+1] = chunks[i+1] + if arg == nil { + parts[pos] = "NULL" + continue + } + switch v := arg.(type) { + case int64: + parts[pos] = strconv.FormatInt(v, 10) + case float64: + parts[pos] = strconv.FormatFloat(v, 'f', -1, 64) + case bool: + if v { + parts[pos] = "1" + } else { + parts[pos] = "0" + } + case time.Time: + if v.IsZero() { + parts[pos] = "'0000-00-00'" + } else { + fmt := "'2006-01-02 15:04:05.999999'" + parts[pos] = v.In(mc.cfg.loc).Format(fmt) + } + case []byte: + parts[pos] = mc.escapeBytes(v) + case string: + parts[pos] = mc.escapeBytes([]byte(v)) + default: + return "", driver.ErrSkip + } + } + return strings.Join(parts, ""), nil +} + func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { if mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } - if len(args) == 0 { // no args, fastpath - mc.affectedRows = 0 - mc.insertId = 0 - - err := mc.exec(query) - if err == nil { - return &mysqlResult{ - affectedRows: int64(mc.affectedRows), - insertId: int64(mc.insertId), - }, err + if len(args) != 0 { + // try client-side prepare to reduce roundtrip + prepared, err := mc.buildQuery(query, args) + if err != nil { + return nil, err } - return nil, err + query = prepared + args = nil } + mc.affectedRows = 0 + mc.insertId = 0 - // with args, must use prepared stmt - return nil, driver.ErrSkip - + err := mc.exec(query) + if err == nil { + return &mysqlResult{ + affectedRows: int64(mc.affectedRows), + insertId: int64(mc.insertId), + }, err + } + return nil, err } // Internal function to execute commands diff --git a/const.go b/const.go index 5fcc3e98b..3aeaf1b1b 100644 --- a/const.go +++ b/const.go @@ -130,3 +130,25 @@ const ( flagUnknown3 flagUnknown4 ) + +// http://dev.mysql.com/doc/internals/en/status-flags.html + +type statusFlag uint16 + +const ( + statusInTrans statusFlag = 1 << iota + statusInAutocommit + statusUnknown1 + statusMoreResultsExists + statusNoGoodIndexUsed + statusNoIndexUsed + statusCursorExists + statusLastRowSent + statusDbDropped + statusNoBackslashEscapes + statusMetadataChanged + statusQueryWasSlow + statusPsOutParams + statusInTransReadonly + statusSessionStateChanged +) diff --git a/packets.go b/packets.go index f2e385bf8..49c6f6966 100644 --- a/packets.go +++ b/packets.go @@ -484,6 +484,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { mc.insertId, _, m = readLengthEncodedInteger(data[1+n:]) // server_status [2 bytes] + mc.status = statusFlag(data[1+n+m]) | statusFlag(data[1+n+m+1])<<8 // warning count [2 bytes] if !mc.strict { From c8c9bb1ec8ac8f25a4c047ae996541b5c0968f5a Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Thu, 1 Jan 2015 06:03:50 +0900 Subject: [PATCH 02/37] Query() uses client-side placeholder substitution. --- connection.go | 54 +++++++++++++++++++++++++++++---------------------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/connection.go b/connection.go index 44f3ddde5..d8a2fa369 100644 --- a/connection.go +++ b/connection.go @@ -164,7 +164,7 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { } func (mc *mysqlConn) escapeBytes(v []byte) string { - buf := make([]byte, len(v)*2) + buf := make([]byte, len(v)*2+2) buf[0] = '\'' pos := 1 if mc.status&statusNoBackslashEscapes == 0 { @@ -254,7 +254,11 @@ func (mc *mysqlConn) buildQuery(query string, args []driver.Value) (string, erro parts[pos] = v.In(mc.cfg.loc).Format(fmt) } case []byte: - parts[pos] = mc.escapeBytes(v) + if v == nil { + parts[pos] = "NULL" + } else { + parts[pos] = mc.escapeBytes(v) + } case string: parts[pos] = mc.escapeBytes([]byte(v)) default: @@ -317,31 +321,35 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } - if len(args) == 0 { // no args, fastpath - // Send command - err := mc.writeCommandPacketStr(comQuery, query) + if len(args) != 0 { + // try client-side prepare to reduce roundtrip + prepared, err := mc.buildQuery(query, args) + if err != nil { + return nil, err + } + query = prepared + args = nil + } + // Send command + err := mc.writeCommandPacketStr(comQuery, query) + if err == nil { + // Read Result + var resLen int + resLen, err = mc.readResultSetHeaderPacket() if err == nil { - // Read Result - var resLen int - resLen, err = mc.readResultSetHeaderPacket() - if err == nil { - rows := new(textRows) - rows.mc = mc - - if resLen == 0 { - // no columns, no more data - return emptyRows{}, nil - } - // Columns - rows.columns, err = mc.readColumns(resLen) - return rows, err + rows := new(textRows) + rows.mc = mc + + if resLen == 0 { + // no columns, no more data + return emptyRows{}, nil } + // Columns + rows.columns, err = mc.readColumns(resLen) + return rows, err } - return nil, err } - - // with args, must use prepared stmt - return nil, driver.ErrSkip + return nil, err } // Gets the value of the given MySQL System Variable From cac6129f8a5c4cc81fa9c52c262501c93092c61d Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Tue, 20 Jan 2015 12:17:21 +0900 Subject: [PATCH 03/37] Don't send text query larger than maxPacketAllowed --- connection.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/connection.go b/connection.go index d8a2fa369..f3bf1aac7 100644 --- a/connection.go +++ b/connection.go @@ -265,6 +265,13 @@ func (mc *mysqlConn) buildQuery(query string, args []driver.Value) (string, erro return "", driver.ErrSkip } } + pktSize := len(query) + 4 // 4 bytes for header. + for _, p := range parts { + pktSize += len(p) + } + if pktSize > mc.maxPacketAllowed { + return "", driver.ErrSkip + } return strings.Join(parts, ""), nil } From b7c2c47a361f9c038e61f98d178a09e4bf5d36d9 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Tue, 20 Jan 2015 12:56:48 +0900 Subject: [PATCH 04/37] Add substitutePlaceholder option to DSN --- connection.go | 33 ++++++++++++++++++++------------- driver_test.go | 9 +++++++++ utils.go | 8 ++++++++ utils_test.go | 24 ++++++++++++------------ 4 files changed, 49 insertions(+), 25 deletions(-) diff --git a/connection.go b/connection.go index f3bf1aac7..b51616cce 100644 --- a/connection.go +++ b/connection.go @@ -34,19 +34,20 @@ type mysqlConn struct { } type config struct { - user string - passwd string - net string - addr string - dbname string - params map[string]string - loc *time.Location - tls *tls.Config - timeout time.Duration - collation uint8 - allowAllFiles bool - allowOldPasswords bool - clientFoundRows bool + user string + passwd string + net string + addr string + dbname string + params map[string]string + loc *time.Location + tls *tls.Config + timeout time.Duration + collation uint8 + allowAllFiles bool + allowOldPasswords bool + clientFoundRows bool + substitutePlaceholder bool } // Handles parameters set in DSN after the connection is established @@ -281,6 +282,9 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err return nil, driver.ErrBadConn } if len(args) != 0 { + if !mc.cfg.substitutePlaceholder { + return nil, driver.ErrSkip + } // try client-side prepare to reduce roundtrip prepared, err := mc.buildQuery(query, args) if err != nil { @@ -329,6 +333,9 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro return nil, driver.ErrBadConn } if len(args) != 0 { + if !mc.cfg.substitutePlaceholder { + return nil, driver.ErrSkip + } // try client-side prepare to reduce roundtrip prepared, err := mc.buildQuery(query, args) if err != nil { diff --git a/driver_test.go b/driver_test.go index a52cc5cd0..f0777d79d 100644 --- a/driver_test.go +++ b/driver_test.go @@ -87,10 +87,19 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { db.Exec("DROP TABLE IF EXISTS test") + dbp, err := sql.Open("mysql", dsn+"&substitutePlaceholder=true") + if err != nil { + t.Fatalf("Error connecting: %s", err.Error()) + } + defer dbp.Close() + dbt := &DBTest{t, db} + dbtp := &DBTest{t, dbp} for _, test := range tests { test(dbt) dbt.db.Exec("DROP TABLE IF EXISTS test") + test(dbtp) + dbtp.db.Exec("DROP TABLE IF EXISTS test") } } diff --git a/utils.go b/utils.go index 98dfc6f5e..008e9abc1 100644 --- a/utils.go +++ b/utils.go @@ -180,6 +180,14 @@ func parseDSNParams(cfg *config, params string) (err error) { // cfg params switch value := param[1]; param[0] { + // Enable client side placeholder substitution + case "substitutePlaceholder": + var isBool bool + cfg.substitutePlaceholder, isBool = readBool(value) + if !isBool { + return fmt.Errorf("Invalid Bool value: %s", value) + } + // Disable INFILE whitelist / enable all files case "allowAllFiles": var isBool bool diff --git a/utils_test.go b/utils_test.go index 0855374b7..4f666cb96 100644 --- a/utils_test.go +++ b/utils_test.go @@ -22,18 +22,18 @@ var testDSNs = []struct { out string loc *time.Location }{ - {"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls: timeout:30000000000 collation:224 allowAllFiles:true allowOldPasswords:true clientFoundRows:true}", time.UTC}, - {"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.Local}, - {"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC}, + {"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC}, + {"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC}, + {"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC}, + {"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls: timeout:30000000000 collation:224 allowAllFiles:true allowOldPasswords:true clientFoundRows:true substitutePlaceholder:false}", time.UTC}, + {"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.Local}, + {"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC}, + {"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC}, + {"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC}, + {"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC}, + {"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC}, + {"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC}, } func TestDSNParser(t *testing.T) { From 058ce87948f074a3601619c69c4df3cfaf0f3bfe Mon Sep 17 00:00:00 2001 From: arvenil Date: Sun, 1 Feb 2015 20:38:07 +0100 Subject: [PATCH 05/37] Move escape funcs to utils.go, export them, add references to mysql surce code --- connection.go | 55 ++++----------------------------------- utils.go | 71 ++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 75 insertions(+), 51 deletions(-) diff --git a/connection.go b/connection.go index 0ddfed9bc..e51182eed 100644 --- a/connection.go +++ b/connection.go @@ -165,60 +165,15 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { return stmt, err } +// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/libmysql/libmysql.c#L1150-L1156 func (mc *mysqlConn) escapeBytes(v []byte) string { - buf := make([]byte, len(v)*2+2) - buf[0] = '\'' - pos := 1 + var escape func([]byte) []byte if mc.status&statusNoBackslashEscapes == 0 { - for _, c := range v { - switch c { - case '\x00': - buf[pos] = '\\' - buf[pos+1] = '0' - pos += 2 - case '\n': - buf[pos] = '\\' - buf[pos+1] = 'n' - pos += 2 - case '\r': - buf[pos] = '\\' - buf[pos+1] = 'r' - pos += 2 - case '\x1a': - buf[pos] = '\\' - buf[pos+1] = 'Z' - pos += 2 - case '\'': - buf[pos] = '\\' - buf[pos+1] = '\'' - pos += 2 - case '"': - buf[pos] = '\\' - buf[pos+1] = '"' - pos += 2 - case '\\': - buf[pos] = '\\' - buf[pos+1] = '\\' - pos += 2 - default: - buf[pos] = c - pos += 1 - } - } + escape = EscapeString } else { - for _, c := range v { - if c == '\'' { - buf[pos] = '\'' - buf[pos+1] = '\'' - pos += 2 - } else { - buf[pos] = c - pos++ - } - } + escape = EscapeQuotes } - buf[pos] = '\'' - return string(buf[:pos+1]) + return "'" + string(escape(v)) + "'" } func (mc *mysqlConn) buildQuery(query string, args []driver.Value) (string, error) { diff --git a/utils.go b/utils.go index f83dc80c9..b2da8af35 100644 --- a/utils.go +++ b/utils.go @@ -224,7 +224,7 @@ func parseDSNParams(cfg *config, params string) (err error) { } cfg.collation = collation break - + case "columnsWithAlias": var isBool bool cfg.columnsWithAlias, isBool = readBool(value) @@ -806,3 +806,72 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte { return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24), byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56)) } + +// Escape string with backslashes (\) +// This escapes the contents of a string (provided as []byte) by adding backslashes before special +// characters, and turning others into specific escape sequences, such as +// turning newlines into \n and null bytes into \0. +// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932 +func EscapeString(v []byte) []byte { + buf := make([]byte, len(v)*2) + pos := 0 + for _, c := range v { + switch c { + case '\x00': + buf[pos] = '\\' + buf[pos+1] = '0' + pos += 2 + case '\n': + buf[pos] = '\\' + buf[pos+1] = 'n' + pos += 2 + case '\r': + buf[pos] = '\\' + buf[pos+1] = 'r' + pos += 2 + case '\x1a': + buf[pos] = '\\' + buf[pos+1] = 'Z' + pos += 2 + case '\'': + buf[pos] = '\\' + buf[pos+1] = '\'' + pos += 2 + case '"': + buf[pos] = '\\' + buf[pos+1] = '"' + pos += 2 + case '\\': + buf[pos] = '\\' + buf[pos+1] = '\\' + pos += 2 + default: + buf[pos] = c + pos += 1 + } + } + + return buf[:pos] +} + +// Escape apostrophes by doubling them up +// This escapes the contents of a string by doubling up any apostrophes that +// it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in +// effect on the server. +// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038 +func EscapeQuotes(v []byte) []byte { + buf := make([]byte, len(v)*2) + pos := 0 + for _, c := range v { + if c == '\'' { + buf[pos] = '\'' + buf[pos+1] = '\'' + pos += 2 + } else { + buf[pos] = c + pos++ + } + } + + return buf[:pos] +} From 42956fa9ec8c80afcac7ad0624973953ea5dd94c Mon Sep 17 00:00:00 2001 From: arvenil Date: Sun, 1 Feb 2015 22:13:02 +0100 Subject: [PATCH 06/37] Add tests for escaping functions --- utils_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/utils_test.go b/utils_test.go index 55bd9ba47..04ddacf8f 100644 --- a/utils_test.go +++ b/utils_test.go @@ -252,3 +252,42 @@ func TestFormatBinaryDateTime(t *testing.T) { expect("1978-12-30 15:46:23", 7, 19) expect("1978-12-30 15:46:23.987654", 11, 26) } + +func TestEscapeString(t *testing.T) { + expect := func(expected, value string) { + actual := string(EscapeString([]byte(value))) + if actual != expected { + t.Errorf( + "expected %s, got %s", + expected, actual, + ) + } + } + + expect("foo\\0bar", "foo\x00bar") + expect("foo\\nbar", "foo\nbar") + expect("foo\\rbar", "foo\rbar") + expect("foo\\Zbar", "foo\x1abar") + expect("foo\\\"bar", "foo\"bar") + expect("foo\\\\bar", "foo\\bar") + expect("foo\\'bar", "foo'bar") +} + +func TestEscapeQuotes(t *testing.T) { + expect := func(expected, value string) { + actual := string(EscapeQuotes([]byte(value))) + if actual != expected { + t.Errorf( + "expected %s, got %s", + expected, actual, + ) + } + } + + expect("foo\x00bar", "foo\x00bar") // not affected + expect("foo\nbar", "foo\nbar") // not affected + expect("foo\rbar", "foo\rbar") // not affected + expect("foo\x1abar", "foo\x1abar") // not affected + expect("foo''bar", "foo'bar") // affected + expect("foo\"bar", "foo\"bar") // not affected +} From e6bf23ab5050085f8be7852211dde710add2f82e Mon Sep 17 00:00:00 2001 From: arvenil Date: Sun, 1 Feb 2015 23:33:14 +0100 Subject: [PATCH 07/37] Add basic SQL injection tests, including NO_BACKSLASH_ESCAPES sql_mode --- driver_test.go | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/driver_test.go b/driver_test.go index f0777d79d..0b0080e23 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1547,3 +1547,33 @@ func TestCustomDial(t *testing.T) { t.Fatalf("Connection failed: %s", err.Error()) } } + +func TestSqlInjection(t *testing.T) { + createTest := func(arg string) func(dbt *DBTest) { + return func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + dbt.mustExec("INSERT INTO test VALUES (?)", 1) + + var v int + // NULL can't be equal to anything, the idea here is to inject query so it returns row + // This test verifies that EscapeQuotes and EscapeStrings are working properly + err := dbt.db.QueryRow("SELECT v FROM test WHERE NULL = ?", arg).Scan(&v) + if err == sql.ErrNoRows { + return // success, sql injection failed + } else if err == nil { + dbt.Errorf("Sql injection successful with arg: %s", arg) + } else { + dbt.Errorf("Error running query with arg: %s; err: %s", err.Error()) + } + } + } + + dsns := []string{ + dsn, + dsn + "&sql_mode=NO_BACKSLASH_ESCAPES", + } + for _, testdsn := range dsns { + runTests(t, testdsn, createTest("1 OR 1=1")) + runTests(t, testdsn, createTest("' OR '1'='1")) + } +} From b4732595f010ffc0f234f94af2f01c24ea1c75cc Mon Sep 17 00:00:00 2001 From: arvenil Date: Sat, 7 Feb 2015 16:31:26 +0100 Subject: [PATCH 08/37] Test if inserted data is correctly retrieved after being escaped --- driver_test.go | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/driver_test.go b/driver_test.go index 0b0080e23..bf44ff9f2 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1563,7 +1563,7 @@ func TestSqlInjection(t *testing.T) { } else if err == nil { dbt.Errorf("Sql injection successful with arg: %s", arg) } else { - dbt.Errorf("Error running query with arg: %s; err: %s", err.Error()) + dbt.Errorf("Error running query with arg: %s; err: %s", arg, err.Error()) } } } @@ -1577,3 +1577,32 @@ func TestSqlInjection(t *testing.T) { runTests(t, testdsn, createTest("' OR '1'='1")) } } + +// Test if inserted data is correctly retrieved after being escaped +func TestInsertRetrieveEscapedData(t *testing.T) { + testData := func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v VARCHAR(255))") + + // All sequences that are escaped by EscapeQuotes and EscapeString + v := "foo \x00\n\r\x1a\"'\\" + dbt.mustExec("INSERT INTO test VALUES (?)", v) + + var out string + err := dbt.db.QueryRow("SELECT v FROM test").Scan(&out) + if err != nil { + dbt.Fatalf("%s", err.Error()) + } + + if out != v { + dbt.Errorf("%q != %q", out, v) + } + } + + dsns := []string{ + dsn, + dsn + "&sql_mode=NO_BACKSLASH_ESCAPES", + } + for _, testdsn := range dsns { + runTests(t, testdsn, testData) + } +} From 42a1efd12acf579f94495c296c30a6ef9d62706d Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Sun, 8 Feb 2015 17:18:00 +0900 Subject: [PATCH 09/37] Don't stop test on MySQLWarnings "DROP TABLE IF EXISTS ..." query fails on fresh database. --- benchmark_test.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/benchmark_test.go b/benchmark_test.go index d72a4183f..a62efd0c4 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -45,7 +45,11 @@ func initDB(b *testing.B, queries ...string) *sql.DB { db := tb.checkDB(sql.Open("mysql", dsn)) for _, query := range queries { if _, err := db.Exec(query); err != nil { - b.Fatalf("Error on %q: %v", query, err) + if w, ok := err.(MySQLWarnings); ok { + b.Logf("Warning on %q: %v", query, w) + } else { + b.Fatalf("Error on %q: %v", query, err) + } } } return db From 3c8fa904c29c1648a2ff6420213aa677cb362e93 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Sun, 8 Feb 2015 19:51:20 +0900 Subject: [PATCH 10/37] substitutePlaceholder -> interpolateParams --- connection.go | 34 +++++++++++++++++----------------- driver_test.go | 2 +- utils.go | 4 ++-- utils_test.go | 26 +++++++++++++------------- 4 files changed, 33 insertions(+), 33 deletions(-) diff --git a/connection.go b/connection.go index e51182eed..0bbdcebe8 100644 --- a/connection.go +++ b/connection.go @@ -34,21 +34,21 @@ type mysqlConn struct { } type config struct { - user string - passwd string - net string - addr string - dbname string - params map[string]string - loc *time.Location - tls *tls.Config - timeout time.Duration - collation uint8 - allowAllFiles bool - allowOldPasswords bool - clientFoundRows bool - columnsWithAlias bool - substitutePlaceholder bool + user string + passwd string + net string + addr string + dbname string + params map[string]string + loc *time.Location + tls *tls.Config + timeout time.Duration + collation uint8 + allowAllFiles bool + allowOldPasswords bool + clientFoundRows bool + columnsWithAlias bool + interpolateParams bool } // Handles parameters set in DSN after the connection is established @@ -238,7 +238,7 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err return nil, driver.ErrBadConn } if len(args) != 0 { - if !mc.cfg.substitutePlaceholder { + if !mc.cfg.interpolateParams { return nil, driver.ErrSkip } // try client-side prepare to reduce roundtrip @@ -289,7 +289,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro return nil, driver.ErrBadConn } if len(args) != 0 { - if !mc.cfg.substitutePlaceholder { + if !mc.cfg.interpolateParams { return nil, driver.ErrSkip } // try client-side prepare to reduce roundtrip diff --git a/driver_test.go b/driver_test.go index bf44ff9f2..c9779f880 100644 --- a/driver_test.go +++ b/driver_test.go @@ -87,7 +87,7 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { db.Exec("DROP TABLE IF EXISTS test") - dbp, err := sql.Open("mysql", dsn+"&substitutePlaceholder=true") + dbp, err := sql.Open("mysql", dsn+"&interpolateParams=true") if err != nil { t.Fatalf("Error connecting: %s", err.Error()) } diff --git a/utils.go b/utils.go index b2da8af35..a2136271e 100644 --- a/utils.go +++ b/utils.go @@ -181,9 +181,9 @@ func parseDSNParams(cfg *config, params string) (err error) { switch value := param[1]; param[0] { // Enable client side placeholder substitution - case "substitutePlaceholder": + case "interpolateParams": var isBool bool - cfg.substitutePlaceholder, isBool = readBool(value) + cfg.interpolateParams, isBool = readBool(value) if !isBool { return fmt.Errorf("Invalid Bool value: %s", value) } diff --git a/utils_test.go b/utils_test.go index 04ddacf8f..7fa039e4c 100644 --- a/utils_test.go +++ b/utils_test.go @@ -22,19 +22,19 @@ var testDSNs = []struct { out string loc *time.Location }{ - {"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC}, - {"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:true substitutePlaceholder:false}", time.UTC}, - {"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC}, - {"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC}, - {"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC}, - {"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls: timeout:30000000000 collation:224 allowAllFiles:true allowOldPasswords:true clientFoundRows:true columnsWithAlias:false substitutePlaceholder:false}", time.UTC}, - {"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.Local}, - {"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC}, - {"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC}, - {"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC}, - {"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC}, - {"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC}, - {"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC}, + {"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, + {"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:true interpolateParams:false}", time.UTC}, + {"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, + {"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, + {"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, + {"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls: timeout:30000000000 collation:224 allowAllFiles:true allowOldPasswords:true clientFoundRows:true columnsWithAlias:false interpolateParams:false}", time.UTC}, + {"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.Local}, + {"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, + {"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, + {"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, + {"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, + {"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, + {"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, } func TestDSNParser(t *testing.T) { From 6c8484b12c0fffbd3ea0cd72881a0263f316b1ef Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Sun, 8 Feb 2015 20:08:03 +0900 Subject: [PATCH 11/37] Add interpolateParams document to README --- README.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/README.md b/README.md index a42d3eb39..1d008cf7a 100644 --- a/README.md +++ b/README.md @@ -182,6 +182,18 @@ SELECT u.id FROM users as u will return `u.id` instead of just `id` if `columnsWithAlias=true`. +##### `interpolateParams` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + +When `interpolateParams` is true, calls to `sql.Db.Query()` and `sql.Db.Exec()` with params interpolates placeholders (`?`) with given params. This reduces roundtrips to database compared with `interpolateParams=false` since it uses prapre, exec and close to support parameters. + +NOTE: It make SQL injection vulnerability when connection encoding is not utf8. + ##### `loc` ``` From 04866ee036f56ec3cfd73d4bd68452b550dafb50 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Mon, 9 Feb 2015 00:50:30 +0900 Subject: [PATCH 12/37] Fix nits pointed in pull request. --- connection.go | 8 ++++---- const.go | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/connection.go b/connection.go index 0bbdcebe8..1e259424e 100644 --- a/connection.go +++ b/connection.go @@ -176,7 +176,7 @@ func (mc *mysqlConn) escapeBytes(v []byte) string { return "'" + string(escape(v)) + "'" } -func (mc *mysqlConn) buildQuery(query string, args []driver.Value) (string, error) { +func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) { chunks := strings.Split(query, "?") if len(chunks) != len(args)+1 { return "", driver.ErrSkip @@ -196,7 +196,7 @@ func (mc *mysqlConn) buildQuery(query string, args []driver.Value) (string, erro case int64: parts[pos] = strconv.FormatInt(v, 10) case float64: - parts[pos] = strconv.FormatFloat(v, 'f', -1, 64) + parts[pos] = strconv.FormatFloat(v, 'g', -1, 64) case bool: if v { parts[pos] = "1" @@ -242,7 +242,7 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err return nil, driver.ErrSkip } // try client-side prepare to reduce roundtrip - prepared, err := mc.buildQuery(query, args) + prepared, err := mc.interpolateParams(query, args) if err != nil { return nil, err } @@ -293,7 +293,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro return nil, driver.ErrSkip } // try client-side prepare to reduce roundtrip - prepared, err := mc.buildQuery(query, args) + prepared, err := mc.interpolateParams(query, args) if err != nil { return nil, err } diff --git a/const.go b/const.go index 3aeaf1b1b..7bf5cea3d 100644 --- a/const.go +++ b/const.go @@ -138,7 +138,7 @@ type statusFlag uint16 const ( statusInTrans statusFlag = 1 << iota statusInAutocommit - statusUnknown1 + statusReserved // Not in documentation statusMoreResultsExists statusNoGoodIndexUsed statusNoIndexUsed From dd7b87c50b971b330172d3905658157189f7b172 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Mon, 9 Feb 2015 01:16:57 +0900 Subject: [PATCH 13/37] Add benchmark for interpolateParams() --- benchmark_test.go | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/benchmark_test.go b/benchmark_test.go index a62efd0c4..94e44b0ea 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -11,10 +11,13 @@ package mysql import ( "bytes" "database/sql" + "database/sql/driver" + "math" "strings" "sync" "sync/atomic" "testing" + "time" ) type TB testing.B @@ -210,3 +213,27 @@ func BenchmarkRoundtripBin(b *testing.B) { rows.Close() } } + +func BenchmarkInterpolation(b *testing.B) { + mc := &mysqlConn{ + cfg: &config{interpolateParams: true}, + maxPacketAllowed: maxPacketSize, + maxWriteSize: maxPacketSize - 1, + } + + args := []driver.Value{ + 42424242, + math.Pi, + false, + time.Unix(1423411542, 807015000), + []byte("bytes containing special chars ' \" \a \x00"), + "string containing special chars ' \" \a \x00", + } + q := "SELECT ?, ?, ?, ?, ?, ?" + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + mc.interpolateParams(q, args) + } +} From 9faabe593dfff7f347a18f9adef79cba24e51119 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Mon, 9 Feb 2015 01:17:20 +0900 Subject: [PATCH 14/37] Don't write microseconds when Time.Nanosecond() == 0 --- connection.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/connection.go b/connection.go index 1e259424e..3ac55cb87 100644 --- a/connection.go +++ b/connection.go @@ -208,6 +208,9 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin parts[pos] = "'0000-00-00'" } else { fmt := "'2006-01-02 15:04:05.999999'" + if v.Nanosecond() == 0 { + fmt = "'2006-01-02 15:04:05'" + } parts[pos] = v.In(mc.cfg.loc).Format(fmt) } case []byte: From 468b9e5379a395a72a5f3464e5c913d3e3c4bbf6 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Mon, 9 Feb 2015 01:32:21 +0900 Subject: [PATCH 15/37] Fix benchmark --- benchmark_test.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index 94e44b0ea..623ba28fa 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -216,14 +216,17 @@ func BenchmarkRoundtripBin(b *testing.B) { func BenchmarkInterpolation(b *testing.B) { mc := &mysqlConn{ - cfg: &config{interpolateParams: true}, + cfg: &config{ + interpolateParams: true, + loc: time.UTC, + }, maxPacketAllowed: maxPacketSize, maxWriteSize: maxPacketSize - 1, } args := []driver.Value{ - 42424242, - math.Pi, + int64(42424242), + float64(math.Pi), false, time.Unix(1423411542, 807015000), []byte("bytes containing special chars ' \" \a \x00"), @@ -234,6 +237,9 @@ func BenchmarkInterpolation(b *testing.B) { b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - mc.interpolateParams(q, args) + _, err := mc.interpolateParams(q, args) + if err != nil { + b.Fatal(err) + } } } From 029731571eb8356fcad551516f98a8dc9d6d3edb Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Mon, 9 Feb 2015 02:16:46 +0900 Subject: [PATCH 16/37] Reduce allocs in interpolateParams. benchmark old ns/op new ns/op delta BenchmarkInterpolation 4065 2533 -37.69% benchmark old allocs new allocs delta BenchmarkInterpolation 15 6 -60.00% benchmark old bytes new bytes delta BenchmarkInterpolation 1144 560 -51.05% --- connection.go | 90 +++++++++++++++++++++++++++++++++++---------------- utils.go | 26 +++++++++++---- utils_test.go | 4 +-- 3 files changed, 84 insertions(+), 36 deletions(-) diff --git a/connection.go b/connection.go index 3ac55cb87..cb9106690 100644 --- a/connection.go +++ b/connection.go @@ -166,73 +166,107 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { } // https://github.com/mysql/mysql-server/blob/mysql-5.7.5/libmysql/libmysql.c#L1150-L1156 -func (mc *mysqlConn) escapeBytes(v []byte) string { - var escape func([]byte) []byte +func (mc *mysqlConn) escapeBytes(buf, v []byte) []byte { + var escape func([]byte, []byte) []byte if mc.status&statusNoBackslashEscapes == 0 { - escape = EscapeString + escape = escapeString } else { - escape = EscapeQuotes + escape = escapeQuotes } - return "'" + string(escape(v)) + "'" + buf = append(buf, '\'') + buf = escape(buf, v) + buf = append(buf, '\'') + return buf +} + +func estimateParamLength(args []driver.Value) (int, bool) { + l := 0 + for _, a := range args { + switch v := a.(type) { + case int64, float64: + l += 20 + case bool: + l += 5 + case time.Time: + l += 30 + case string: + l += len(v)*2 + 2 + case []byte: + l += len(v)*2 + 2 + default: + return 0, false + } + } + return l, true } func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) { - chunks := strings.Split(query, "?") - if len(chunks) != len(args)+1 { + estimated, ok := estimateParamLength(args) + if !ok { return "", driver.ErrSkip } + estimated += len(query) - parts := make([]string, len(chunks)+len(args)) - parts[0] = chunks[0] + buf := make([]byte, 0, estimated) + argPos := 0 + + // Go 1.5 will optimize range([]byte(string)) to skip allocation. + for _, c := range []byte(query) { + if c != '?' { + buf = append(buf, c) + continue + } + + arg := args[argPos] + argPos++ - for i, arg := range args { - pos := i*2 + 1 - parts[pos+1] = chunks[i+1] if arg == nil { - parts[pos] = "NULL" + buf = append(buf, []byte("NULL")...) continue } + switch v := arg.(type) { case int64: - parts[pos] = strconv.FormatInt(v, 10) + buf = strconv.AppendInt(buf, v, 10) case float64: - parts[pos] = strconv.FormatFloat(v, 'g', -1, 64) + buf = strconv.AppendFloat(buf, v, 'g', -1, 64) case bool: if v { - parts[pos] = "1" + buf = append(buf, '1') } else { - parts[pos] = "0" + buf = append(buf, '0') } case time.Time: if v.IsZero() { - parts[pos] = "'0000-00-00'" + buf = append(buf, []byte("'0000-00-00'")...) } else { fmt := "'2006-01-02 15:04:05.999999'" if v.Nanosecond() == 0 { fmt = "'2006-01-02 15:04:05'" } - parts[pos] = v.In(mc.cfg.loc).Format(fmt) + s := v.In(mc.cfg.loc).Format(fmt) + buf = append(buf, []byte(s)...) } case []byte: if v == nil { - parts[pos] = "NULL" + buf = append(buf, []byte("NULL")...) } else { - parts[pos] = mc.escapeBytes(v) + buf = mc.escapeBytes(buf, v) } case string: - parts[pos] = mc.escapeBytes([]byte(v)) + buf = mc.escapeBytes(buf, []byte(v)) default: return "", driver.ErrSkip } + + if len(buf)+4 > mc.maxPacketAllowed { + return "", driver.ErrSkip + } } - pktSize := len(query) + 4 // 4 bytes for header. - for _, p := range parts { - pktSize += len(p) - } - if pktSize > mc.maxPacketAllowed { + if argPos != len(args) { return "", driver.ErrSkip } - return strings.Join(parts, ""), nil + return string(buf), nil } func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { diff --git a/utils.go b/utils.go index a2136271e..e7b15f1ff 100644 --- a/utils.go +++ b/utils.go @@ -812,9 +812,16 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte { // characters, and turning others into specific escape sequences, such as // turning newlines into \n and null bytes into \0. // https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932 -func EscapeString(v []byte) []byte { - buf := make([]byte, len(v)*2) - pos := 0 +func escapeString(buf, v []byte) []byte { + pos := len(buf) + end := pos + len(v)*2 + if cap(buf) < end { + n := make([]byte, pos+end) + copy(n, buf) + buf = n + } + buf = buf[0:end] + for _, c := range v { switch c { case '\x00': @@ -859,9 +866,16 @@ func EscapeString(v []byte) []byte { // it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in // effect on the server. // https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038 -func EscapeQuotes(v []byte) []byte { - buf := make([]byte, len(v)*2) - pos := 0 +func escapeQuotes(buf, v []byte) []byte { + pos := len(buf) + end := pos + len(v)*2 + if cap(buf) < end { + n := make([]byte, pos+end) + copy(n, buf) + buf = n + } + buf = buf[0:end] + for _, c := range v { if c == '\'' { buf[pos] = '\'' diff --git a/utils_test.go b/utils_test.go index 7fa039e4c..95b91964e 100644 --- a/utils_test.go +++ b/utils_test.go @@ -255,7 +255,7 @@ func TestFormatBinaryDateTime(t *testing.T) { func TestEscapeString(t *testing.T) { expect := func(expected, value string) { - actual := string(EscapeString([]byte(value))) + actual := string(escapeString([]byte{}, []byte(value))) if actual != expected { t.Errorf( "expected %s, got %s", @@ -275,7 +275,7 @@ func TestEscapeString(t *testing.T) { func TestEscapeQuotes(t *testing.T) { expect := func(expected, value string) { - actual := string(EscapeQuotes([]byte(value))) + actual := string(escapeQuotes([]byte{}, []byte(value))) if actual != expected { t.Errorf( "expected %s, got %s", From 0b753962323303ce4a9625b66fdcf47b0395202b Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Mon, 9 Feb 2015 03:56:26 +0900 Subject: [PATCH 17/37] Inline datetime formatting benchmark old ns/op new ns/op delta BenchmarkInterpolation 2536 2209 -12.89% benchmark old allocs new allocs delta BenchmarkInterpolation 6 4 -33.33% benchmark old bytes new bytes delta BenchmarkInterpolation 560 496 -11.43% --- connection.go | 48 +++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 5 deletions(-) diff --git a/connection.go b/connection.go index cb9106690..c51864221 100644 --- a/connection.go +++ b/connection.go @@ -240,12 +240,50 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin if v.IsZero() { buf = append(buf, []byte("'0000-00-00'")...) } else { - fmt := "'2006-01-02 15:04:05.999999'" - if v.Nanosecond() == 0 { - fmt = "'2006-01-02 15:04:05'" + v := v.In(mc.cfg.loc) + year := v.Year() + month := v.Month() + day := v.Day() + hour := v.Hour() + minute := v.Minute() + second := v.Second() + micro := v.Nanosecond() / 1000 + + buf = append(buf, []byte{ + byte('\''), + byte('0' + year/1000), + byte('0' + year/100%10), + byte('0' + year/10%10), + byte('0' + year%10), + byte('-'), + byte('0' + month/10), + byte('0' + month%10), + byte('-'), + byte('0' + day/10), + byte('0' + day%10), + byte(' '), + byte('0' + hour/10), + byte('0' + hour%10), + byte(':'), + byte('0' + minute/10), + byte('0' + minute%10), + byte(':'), + byte('0' + second/10), + byte('0' + second%10), + }...) + + if micro != 0 { + buf = append(buf, []byte{ + byte('.'), + byte('0' + micro/100000), + byte('0' + micro/10000%10), + byte('0' + micro/1000%10), + byte('0' + micro/100%10), + byte('0' + micro/10%10), + byte('0' + micro%10), + }...) } - s := v.In(mc.cfg.loc).Format(fmt) - buf = append(buf, []byte(s)...) + buf = append(buf, '\'') } case []byte: if v == nil { From 9f84dfbb88104a301bbd7a04dd0299bd35759c3a Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Mon, 9 Feb 2015 03:58:07 +0900 Subject: [PATCH 18/37] Remove one more allocation benchmark old ns/op new ns/op delta BenchmarkInterpolation 2209 2116 -4.21% benchmark old allocs new allocs delta BenchmarkInterpolation 4 3 -25.00% benchmark old bytes new bytes delta BenchmarkInterpolation 496 464 -6.45% --- connection.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/connection.go b/connection.go index c51864221..07b20e086 100644 --- a/connection.go +++ b/connection.go @@ -210,8 +210,8 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin buf := make([]byte, 0, estimated) argPos := 0 - // Go 1.5 will optimize range([]byte(string)) to skip allocation. - for _, c := range []byte(query) { + for i := 0; i < len(query); i++ { + c := query[i] if c != '?' { buf = append(buf, c) continue From 8826242dabcd4120cfafdb9aab4bf66a7aa17150 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Mon, 9 Feb 2015 04:11:45 +0900 Subject: [PATCH 19/37] More acculate estimation of upper bound --- connection.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/connection.go b/connection.go index 07b20e086..2b259faec 100644 --- a/connection.go +++ b/connection.go @@ -179,16 +179,18 @@ func (mc *mysqlConn) escapeBytes(buf, v []byte) []byte { return buf } +// estimateParamLength calculates upper bound of string length from types. func estimateParamLength(args []driver.Value) (int, bool) { l := 0 for _, a := range args { switch v := a.(type) { case int64, float64: - l += 20 + // 24 (-1.7976931348623157e+308) may be upper bound. But I'm not sure. + l += 25 case bool: - l += 5 + l += 1 // 0 or 1 case time.Time: - l += 30 + l += 30 // '1234-12-23 12:34:56.777777' case string: l += len(v)*2 + 2 case []byte: From 916a1f24337bffa9de03843ff5e29f1da23d57ca Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Mon, 9 Feb 2015 11:38:52 +0900 Subject: [PATCH 20/37] escapeString -> escapeBackslash --- connection.go | 11 ++++------- driver_test.go | 4 ++-- utils.go | 2 +- utils_test.go | 4 ++-- 4 files changed, 9 insertions(+), 12 deletions(-) diff --git a/connection.go b/connection.go index 2b259faec..55fbf930e 100644 --- a/connection.go +++ b/connection.go @@ -167,16 +167,13 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { // https://github.com/mysql/mysql-server/blob/mysql-5.7.5/libmysql/libmysql.c#L1150-L1156 func (mc *mysqlConn) escapeBytes(buf, v []byte) []byte { - var escape func([]byte, []byte) []byte + buf = append(buf, '\'') if mc.status&statusNoBackslashEscapes == 0 { - escape = escapeString + buf = escapeBackslash(buf, v) } else { - escape = escapeQuotes + buf = escapeQuotes(buf, v) } - buf = append(buf, '\'') - buf = escape(buf, v) - buf = append(buf, '\'') - return buf + return append(buf, '\'') } // estimateParamLength calculates upper bound of string length from types. diff --git a/driver_test.go b/driver_test.go index c9779f880..bda07d63a 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1556,7 +1556,7 @@ func TestSqlInjection(t *testing.T) { var v int // NULL can't be equal to anything, the idea here is to inject query so it returns row - // This test verifies that EscapeQuotes and EscapeStrings are working properly + // This test verifies that escapeQuotes and escapeBackslash are working properly err := dbt.db.QueryRow("SELECT v FROM test WHERE NULL = ?", arg).Scan(&v) if err == sql.ErrNoRows { return // success, sql injection failed @@ -1583,7 +1583,7 @@ func TestInsertRetrieveEscapedData(t *testing.T) { testData := func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (v VARCHAR(255))") - // All sequences that are escaped by EscapeQuotes and EscapeString + // All sequences that are escaped by escapeQuotes and escapeBackslash v := "foo \x00\n\r\x1a\"'\\" dbt.mustExec("INSERT INTO test VALUES (?)", v) diff --git a/utils.go b/utils.go index e7b15f1ff..267e97ca0 100644 --- a/utils.go +++ b/utils.go @@ -812,7 +812,7 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte { // characters, and turning others into specific escape sequences, such as // turning newlines into \n and null bytes into \0. // https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932 -func escapeString(buf, v []byte) []byte { +func escapeBackslash(buf, v []byte) []byte { pos := len(buf) end := pos + len(v)*2 if cap(buf) < end { diff --git a/utils_test.go b/utils_test.go index 95b91964e..7583efeea 100644 --- a/utils_test.go +++ b/utils_test.go @@ -253,9 +253,9 @@ func TestFormatBinaryDateTime(t *testing.T) { expect("1978-12-30 15:46:23.987654", 11, 26) } -func TestEscapeString(t *testing.T) { +func TestEscapeBackslash(t *testing.T) { expect := func(expected, value string) { - actual := string(escapeString([]byte{}, []byte(value))) + actual := string(escapeBackslash([]byte{}, []byte(value))) if actual != expected { t.Errorf( "expected %s, got %s", From 88aeb98098739c011a27f086b079e3fea2ff3e55 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Tue, 10 Feb 2015 15:42:11 +0900 Subject: [PATCH 21/37] append string... to []byte without cast. --- connection.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/connection.go b/connection.go index 55fbf930e..d574eae41 100644 --- a/connection.go +++ b/connection.go @@ -210,17 +210,19 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin argPos := 0 for i := 0; i < len(query); i++ { - c := query[i] - if c != '?' { - buf = append(buf, c) - continue + q := strings.IndexByte(query[i:], '?') + if q == -1 { + buf = append(buf, query[i:]...) + break } + buf = append(buf, query[i:i+q]...) + i += q arg := args[argPos] argPos++ if arg == nil { - buf = append(buf, []byte("NULL")...) + buf = append(buf, "NULL"...) continue } @@ -237,7 +239,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin } case time.Time: if v.IsZero() { - buf = append(buf, []byte("'0000-00-00'")...) + buf = append(buf, "'0000-00-00'"...) } else { v := v.In(mc.cfg.loc) year := v.Year() @@ -286,7 +288,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin } case []byte: if v == nil { - buf = append(buf, []byte("NULL")...) + buf = append(buf, "NULL"...) } else { buf = mc.escapeBytes(buf, v) } From 43536c7d6d53670dcd46bd192625ff5603248c5a Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Tue, 10 Feb 2015 15:52:39 +0900 Subject: [PATCH 22/37] Specialize escape functions for string benchmark old ns/op new ns/op delta BenchmarkInterpolation 2463 2118 -14.01% benchmark old allocs new allocs delta BenchmarkInterpolation 3 2 -33.33% benchmark old bytes new bytes delta BenchmarkInterpolation 496 448 -9.68% --- connection.go | 16 ++++++++-- utils.go | 85 ++++++++++++++++++++++++++++++++++++++++++++++++--- utils_test.go | 4 +-- 3 files changed, 96 insertions(+), 9 deletions(-) diff --git a/connection.go b/connection.go index d574eae41..fe6391c95 100644 --- a/connection.go +++ b/connection.go @@ -169,9 +169,19 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { func (mc *mysqlConn) escapeBytes(buf, v []byte) []byte { buf = append(buf, '\'') if mc.status&statusNoBackslashEscapes == 0 { - buf = escapeBackslash(buf, v) + buf = escapeBytesBackslash(buf, v) } else { - buf = escapeQuotes(buf, v) + buf = escapeBytesQuotes(buf, v) + } + return append(buf, '\'') +} + +func (mc *mysqlConn) escapeString(buf []byte, v string) []byte { + buf = append(buf, '\'') + if mc.status&statusNoBackslashEscapes == 0 { + buf = escapeStringBackslash(buf, v) + } else { + buf = escapeStringQuotes(buf, v) } return append(buf, '\'') } @@ -293,7 +303,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin buf = mc.escapeBytes(buf, v) } case string: - buf = mc.escapeBytes(buf, []byte(v)) + buf = mc.escapeString(buf, v) default: return "", driver.ErrSkip } diff --git a/utils.go b/utils.go index 267e97ca0..2b7d7f3dd 100644 --- a/utils.go +++ b/utils.go @@ -807,12 +807,12 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte { byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56)) } -// Escape string with backslashes (\) +// escapeBytesBackslash escapes []byte with backslashes (\) // This escapes the contents of a string (provided as []byte) by adding backslashes before special // characters, and turning others into specific escape sequences, such as // turning newlines into \n and null bytes into \0. // https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932 -func escapeBackslash(buf, v []byte) []byte { +func escapeBytesBackslash(buf, v []byte) []byte { pos := len(buf) end := pos + len(v)*2 if cap(buf) < end { @@ -861,12 +861,63 @@ func escapeBackslash(buf, v []byte) []byte { return buf[:pos] } -// Escape apostrophes by doubling them up +// escapeStringBackslash is similar to escapeBytesBackslash but for string. +func escapeStringBackslash(buf []byte, v string) []byte { + pos := len(buf) + end := pos + len(v)*2 + if cap(buf) < end { + n := make([]byte, pos+end) + copy(n, buf) + buf = n + } + buf = buf[0:end] + + for i := 0; i < len(v); i++ { + c := v[i] + switch c { + case '\x00': + buf[pos] = '\\' + buf[pos+1] = '0' + pos += 2 + case '\n': + buf[pos] = '\\' + buf[pos+1] = 'n' + pos += 2 + case '\r': + buf[pos] = '\\' + buf[pos+1] = 'r' + pos += 2 + case '\x1a': + buf[pos] = '\\' + buf[pos+1] = 'Z' + pos += 2 + case '\'': + buf[pos] = '\\' + buf[pos+1] = '\'' + pos += 2 + case '"': + buf[pos] = '\\' + buf[pos+1] = '"' + pos += 2 + case '\\': + buf[pos] = '\\' + buf[pos+1] = '\\' + pos += 2 + default: + buf[pos] = c + pos += 1 + } + } + + return buf[:pos] +} + +// escapeBytesQuotes escapes apostrophes in []byte by doubling them up. // This escapes the contents of a string by doubling up any apostrophes that // it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in // effect on the server. // https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038 -func escapeQuotes(buf, v []byte) []byte { +func escapeBytesQuotes(buf, v []byte) []byte { pos := len(buf) end := pos + len(v)*2 if cap(buf) < end { @@ -889,3 +940,29 @@ func escapeQuotes(buf, v []byte) []byte { return buf[:pos] } + +// escapeStringQuotes is similar to escapeBytesQuotes but for string. +func escapeStringQuotes(buf []byte, v string) []byte { + pos := len(buf) + end := pos + len(v)*2 + if cap(buf) < end { + n := make([]byte, pos+end) + copy(n, buf) + buf = n + } + buf = buf[0:end] + + for i := 0; i < len(v); i++ { + c := v[i] + if c == '\'' { + buf[pos] = '\'' + buf[pos+1] = '\'' + pos += 2 + } else { + buf[pos] = c + pos++ + } + } + + return buf[:pos] +} diff --git a/utils_test.go b/utils_test.go index 7583efeea..80056cb08 100644 --- a/utils_test.go +++ b/utils_test.go @@ -255,7 +255,7 @@ func TestFormatBinaryDateTime(t *testing.T) { func TestEscapeBackslash(t *testing.T) { expect := func(expected, value string) { - actual := string(escapeBackslash([]byte{}, []byte(value))) + actual := string(escapeBytesBackslash([]byte{}, []byte(value))) if actual != expected { t.Errorf( "expected %s, got %s", @@ -275,7 +275,7 @@ func TestEscapeBackslash(t *testing.T) { func TestEscapeQuotes(t *testing.T) { expect := func(expected, value string) { - actual := string(escapeQuotes([]byte{}, []byte(value))) + actual := string(escapeBytesQuotes([]byte{}, []byte(value))) if actual != expected { t.Errorf( "expected %s, got %s", From 0c7ae4638c79b6f4662150a715a8a1b80a818111 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Tue, 10 Feb 2015 15:55:29 +0900 Subject: [PATCH 23/37] test for escapeString* --- utils_test.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/utils_test.go b/utils_test.go index 80056cb08..34b3cdf7b 100644 --- a/utils_test.go +++ b/utils_test.go @@ -262,6 +262,14 @@ func TestEscapeBackslash(t *testing.T) { expected, actual, ) } + + actual = string(escapeStringBackslash([]byte{}, value)) + if actual != expected { + t.Errorf( + "expected %s, got %s", + expected, actual, + ) + } } expect("foo\\0bar", "foo\x00bar") @@ -282,6 +290,14 @@ func TestEscapeQuotes(t *testing.T) { expected, actual, ) } + + actual = string(escapeStringQuotes([]byte{}, value)) + if actual != expected { + t.Errorf( + "expected %s, got %s", + expected, actual, + ) + } } expect("foo\x00bar", "foo\x00bar") // not affected From c285e39201651851cda7bce17910e8d9dad36082 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Tue, 10 Feb 2015 16:36:14 +0900 Subject: [PATCH 24/37] Use digits10 and digits01 to format datetime. --- connection.go | 49 ++++++++++++++++++++++--------------------------- utils.go | 5 +++-- 2 files changed, 25 insertions(+), 29 deletions(-) diff --git a/connection.go b/connection.go index fe6391c95..1f004e13b 100644 --- a/connection.go +++ b/connection.go @@ -253,6 +253,8 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin } else { v := v.In(mc.cfg.loc) year := v.Year() + year100 := year / 100 + year1 := year % 100 month := v.Month() day := v.Day() hour := v.Hour() @@ -261,37 +263,30 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin micro := v.Nanosecond() / 1000 buf = append(buf, []byte{ - byte('\''), - byte('0' + year/1000), - byte('0' + year/100%10), - byte('0' + year/10%10), - byte('0' + year%10), - byte('-'), - byte('0' + month/10), - byte('0' + month%10), - byte('-'), - byte('0' + day/10), - byte('0' + day%10), - byte(' '), - byte('0' + hour/10), - byte('0' + hour%10), - byte(':'), - byte('0' + minute/10), - byte('0' + minute%10), - byte(':'), - byte('0' + second/10), - byte('0' + second%10), + '\'', + digits10[year100], digits01[year100], + digits10[year1], digits01[year1], + '-', + digits10[month], digits01[month], + '-', + digits10[day], digits01[day], + ' ', + digits10[hour], digits01[hour], + ':', + digits10[minute], digits01[minute], + ':', + digits10[second], digits01[second], }...) if micro != 0 { + micro10000 := micro / 10000 + micro100 := micro / 100 % 100 + micro1 := micro % 100 buf = append(buf, []byte{ - byte('.'), - byte('0' + micro/100000), - byte('0' + micro/10000%10), - byte('0' + micro/1000%10), - byte('0' + micro/100%10), - byte('0' + micro/10%10), - byte('0' + micro%10), + '.', + digits10[micro10000], digits01[micro10000], + digits10[micro100], digits01[micro100], + digits10[micro1], digits01[micro1], }...) } buf = append(buf, '\'') diff --git a/utils.go b/utils.go index 2b7d7f3dd..b1e163874 100644 --- a/utils.go +++ b/utils.go @@ -540,11 +540,12 @@ func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Va // The current behavior depends on database/sql copying the result. var zeroDateTime = []byte("0000-00-00 00:00:00.000000") +const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" +const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999" + func formatBinaryDateTime(src []byte, length uint8, justTime bool) (driver.Value, error) { // length expects the deterministic length of the zero value, // negative time and 100+ hours are automatically added if needed - const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" - const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999" if len(src) == 0 { if justTime { return zeroDateTime[11 : 11+length], nil From fcea44760c54714db186d4e677209cb8e14531f7 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Tue, 10 Feb 2015 16:37:05 +0900 Subject: [PATCH 25/37] Round under microsecond --- connection.go | 1 + 1 file changed, 1 insertion(+) diff --git a/connection.go b/connection.go index 1f004e13b..f018e3f1b 100644 --- a/connection.go +++ b/connection.go @@ -252,6 +252,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin buf = append(buf, "'0000-00-00'"...) } else { v := v.In(mc.cfg.loc) + v = v.Add(time.Nanosecond * 500) // To round under microsecond year := v.Year() year100 := year / 100 year1 := year % 100 From bfbe6c59bb5eb5f0498fa929296a50395949b196 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Thu, 12 Feb 2015 01:41:57 +0900 Subject: [PATCH 26/37] travis: Drop Go 1.1 and add Go 1.4 --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index dd29a7580..50eb041fa 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,8 @@ language: go go: - - 1.1 - 1.2 - 1.3 + - 1.4 - tip before_script: From d65f96afcc4a4d711e4bceff75d7c3db77d273b6 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Thu, 12 Feb 2015 18:36:52 +0900 Subject: [PATCH 27/37] Fix typo --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1d008cf7a..58c437a7e 100644 --- a/README.md +++ b/README.md @@ -190,7 +190,7 @@ Valid Values: true, false Default: false ``` -When `interpolateParams` is true, calls to `sql.Db.Query()` and `sql.Db.Exec()` with params interpolates placeholders (`?`) with given params. This reduces roundtrips to database compared with `interpolateParams=false` since it uses prapre, exec and close to support parameters. +When `interpolateParams` is true, calls to `sql.Db.Query()` and `sql.Db.Exec()` with params interpolates placeholders (`?`) with given params. This reduces roundtrips to database compared with `interpolateParams=false` since it uses prepare, exec and close to support parameters. NOTE: It make SQL injection vulnerability when connection encoding is not utf8. From e11c82531601d8eefe873df5d10147f8040cb3b1 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Thu, 12 Feb 2015 18:45:19 +0900 Subject: [PATCH 28/37] Inlining mysqlConn.escapeBytes and mysqlConn.escapeString --- connection.go | 37 ++++++++++++++----------------------- 1 file changed, 14 insertions(+), 23 deletions(-) diff --git a/connection.go b/connection.go index f018e3f1b..80da72d1d 100644 --- a/connection.go +++ b/connection.go @@ -165,27 +165,6 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { return stmt, err } -// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/libmysql/libmysql.c#L1150-L1156 -func (mc *mysqlConn) escapeBytes(buf, v []byte) []byte { - buf = append(buf, '\'') - if mc.status&statusNoBackslashEscapes == 0 { - buf = escapeBytesBackslash(buf, v) - } else { - buf = escapeBytesQuotes(buf, v) - } - return append(buf, '\'') -} - -func (mc *mysqlConn) escapeString(buf []byte, v string) []byte { - buf = append(buf, '\'') - if mc.status&statusNoBackslashEscapes == 0 { - buf = escapeStringBackslash(buf, v) - } else { - buf = escapeStringQuotes(buf, v) - } - return append(buf, '\'') -} - // estimateParamLength calculates upper bound of string length from types. func estimateParamLength(args []driver.Value) (int, bool) { l := 0 @@ -296,10 +275,22 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin if v == nil { buf = append(buf, "NULL"...) } else { - buf = mc.escapeBytes(buf, v) + buf = append(buf, '\'') + if mc.status&statusNoBackslashEscapes == 0 { + buf = escapeBytesBackslash(buf, v) + } else { + buf = escapeBytesQuotes(buf, v) + } + buf = append(buf, '\'') } case string: - buf = mc.escapeString(buf, v) + buf = append(buf, '\'') + if mc.status&statusNoBackslashEscapes == 0 { + buf = escapeStringBackslash(buf, v) + } else { + buf = escapeStringQuotes(buf, v) + } + buf = append(buf, '\'') default: return "", driver.ErrSkip } From b4f0315a646c33f324f42111722856d223469820 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Thu, 12 Feb 2015 18:47:53 +0900 Subject: [PATCH 29/37] Bit detailed info about vulnerability when using multibyte encoding. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 58c437a7e..8754ea2e4 100644 --- a/README.md +++ b/README.md @@ -192,7 +192,7 @@ Default: false When `interpolateParams` is true, calls to `sql.Db.Query()` and `sql.Db.Exec()` with params interpolates placeholders (`?`) with given params. This reduces roundtrips to database compared with `interpolateParams=false` since it uses prepare, exec and close to support parameters. -NOTE: It make SQL injection vulnerability when connection encoding is not utf8. +NOTE: It make SQL injection vulnerability when connection encoding is multibyte encoding except utf-8 (e.g. cp932). ##### `loc` From 1fd051484e211c89c8453bf59eeec54cc96151c7 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Thu, 12 Feb 2015 21:05:36 +0900 Subject: [PATCH 30/37] Add link to StackOverflow describe vulnerability using multibyte encoding --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 8754ea2e4..282d83dfc 100644 --- a/README.md +++ b/README.md @@ -192,7 +192,8 @@ Default: false When `interpolateParams` is true, calls to `sql.Db.Query()` and `sql.Db.Exec()` with params interpolates placeholders (`?`) with given params. This reduces roundtrips to database compared with `interpolateParams=false` since it uses prepare, exec and close to support parameters. -NOTE: It make SQL injection vulnerability when connection encoding is multibyte encoding except utf-8 (e.g. cp932). +NOTE: *This may introduce a SQL injection vulnerability when connection encoding is multibyte encoding except for UTF-8 (e.g. CP932)!* +(See http://stackoverflow.com/a/12118602/3430118) ##### `loc` From 20b75cd3d34c381d01d7e19f322fc4668975dfc2 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Thu, 12 Feb 2015 21:06:53 +0900 Subject: [PATCH 31/37] Fix comment --- connection.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/connection.go b/connection.go index 80da72d1d..82e39628e 100644 --- a/connection.go +++ b/connection.go @@ -314,7 +314,7 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err if !mc.cfg.interpolateParams { return nil, driver.ErrSkip } - // try client-side prepare to reduce roundtrip + // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement prepared, err := mc.interpolateParams(query, args) if err != nil { return nil, err From e517683745a9ed082fac6ada65cc61cc04519945 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Thu, 12 Feb 2015 22:17:15 +0900 Subject: [PATCH 32/37] Allow interpolateParams only with ascii, latin1 and utf8 collations --- driver_test.go | 20 +++++++++++--------- utils.go | 33 ++++++++++++++++++++++++++++++--- utils_test.go | 37 +++++++++++++++++++++++++++++++++++++ 3 files changed, 78 insertions(+), 12 deletions(-) diff --git a/driver_test.go b/driver_test.go index bda07d63a..bb8aa0848 100644 --- a/driver_test.go +++ b/driver_test.go @@ -87,19 +87,21 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { db.Exec("DROP TABLE IF EXISTS test") - dbp, err := sql.Open("mysql", dsn+"&interpolateParams=true") - if err != nil { - t.Fatalf("Error connecting: %s", err.Error()) + dsn2 := dsn + "&interpolateParams=true" + var db2 *sql.DB + if _, err := parseDSN(dsn2); err != errInvalidDSNUnsafeCollation { + db2, err = sql.Open("mysql", dsn2) } - defer dbp.Close() dbt := &DBTest{t, db} - dbtp := &DBTest{t, dbp} + dbt2 := &DBTest{t, db2} for _, test := range tests { test(dbt) dbt.db.Exec("DROP TABLE IF EXISTS test") - test(dbtp) - dbtp.db.Exec("DROP TABLE IF EXISTS test") + if db2 != nil { + test(dbt2) + dbt2.db.Exec("DROP TABLE IF EXISTS test") + } } } @@ -864,7 +866,7 @@ func TestLoadData(t *testing.T) { dbt.Fatalf("%d != %d", i, id) } if values[i-1] != value { - dbt.Fatalf("%s != %s", values[i-1], value) + dbt.Fatalf("%q != %q", values[i-1], value) } } err = rows.Err() @@ -889,7 +891,7 @@ func TestLoadData(t *testing.T) { // Local File RegisterLocalFile(file.Name()) - dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE '%q' INTO TABLE test", file.Name())) + dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE %q INTO TABLE test", file.Name())) verifyLoadDataResult() // negative test _, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'doesnotexist' INTO TABLE test") diff --git a/utils.go b/utils.go index b1e163874..4d5678d86 100644 --- a/utils.go +++ b/utils.go @@ -25,9 +25,10 @@ import ( var ( tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs - errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?") - errInvalidDSNAddr = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)") - errInvalidDSNNoSlash = errors.New("Invalid DSN: Missing the slash separating the database name") + errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?") + errInvalidDSNAddr = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)") + errInvalidDSNNoSlash = errors.New("Invalid DSN: Missing the slash separating the database name") + errInvalidDSNUnsafeCollation = errors.New("Invalid DSN: interpolateParams can be used with ascii, latin1, utf8 and utf8mb4 charset") ) func init() { @@ -147,6 +148,32 @@ func parseDSN(dsn string) (cfg *config, err error) { return nil, errInvalidDSNNoSlash } + if cfg.interpolateParams && cfg.collation != defaultCollation { + // A whitelist of collations which safe to interpolate parameters. + // ASCII and latin-1 are safe since they are single byte encoding. + // utf-8 is safe since it doesn't conatins ASCII characters in trailing bytes. + safeCollations := []string{"ascii_", "latin1_", "utf8_", "utf8mb4_"} + + var collationName string + for name, collation := range collations { + if collation == cfg.collation { + collationName = name + break + } + } + + safe := false + for _, p := range safeCollations { + if strings.HasPrefix(collationName, p) { + safe = true + break + } + } + if !safe { + return nil, errInvalidDSNUnsafeCollation + } + } + // Set default network if empty if cfg.net == "" { cfg.net = "tcp" diff --git a/utils_test.go b/utils_test.go index 34b3cdf7b..adb8dcbd1 100644 --- a/utils_test.go +++ b/utils_test.go @@ -116,6 +116,43 @@ func TestDSNWithCustomTLS(t *testing.T) { DeregisterTLSConfig("utils_test") } +func TestDSNUnsafeCollation(t *testing.T) { + _, err := parseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=true") + if err != errInvalidDSNUnsafeCollation { + t.Error("Expected %v, Got %v", errInvalidDSNUnsafeCollation, err) + } + + _, err = parseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=false") + if err != nil { + t.Error("Expected %v, Got %v", nil, err) + } + + _, err = parseDSN("/dbname?collation=gbk_chinese_ci") + if err != nil { + t.Error("Expected %v, Got %v", nil, err) + } + + _, err = parseDSN("/dbname?collation=ascii_bin&interpolateParams=true") + if err != nil { + t.Error("Expected %v, Got %v", nil, err) + } + + _, err = parseDSN("/dbname?collation=latin1_german1_ci&interpolateParams=true") + if err != nil { + t.Error("Expected %v, Got %v", nil, err) + } + + _, err = parseDSN("/dbname?collation=utf8_general_ci&interpolateParams=true") + if err != nil { + t.Error("Expected %v, Got %v", nil, err) + } + + _, err = parseDSN("/dbname?collation=utf8mb4_general_ci&interpolateParams=true") + if err != nil { + t.Error("Expected %v, Got %v", nil, err) + } +} + func BenchmarkParseDSN(b *testing.B) { b.ReportAllocs() From 0f22bc29c1def30c7926c5b1085e732cdb4267c7 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Thu, 12 Feb 2015 22:29:53 +0900 Subject: [PATCH 33/37] extract function to reserve buffer --- utils.go | 45 +++++++++++++++++---------------------------- 1 file changed, 17 insertions(+), 28 deletions(-) diff --git a/utils.go b/utils.go index 4d5678d86..881c9df0a 100644 --- a/utils.go +++ b/utils.go @@ -835,6 +835,19 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte { byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56)) } +// reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize. +// If cap(buf) is not enough, reallocate new buffer. +func reserveBuffer(buf []byte, appendSize int) []byte { + newSize := len(buf) + appendSize + if cap(buf) < newSize { + // Grow buffer exponentially + newBuf := make([]byte, len(buf)*2+appendSize) + copy(newBuf, buf) + buf = newBuf + } + return buf[:newSize] +} + // escapeBytesBackslash escapes []byte with backslashes (\) // This escapes the contents of a string (provided as []byte) by adding backslashes before special // characters, and turning others into specific escape sequences, such as @@ -842,13 +855,7 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte { // https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932 func escapeBytesBackslash(buf, v []byte) []byte { pos := len(buf) - end := pos + len(v)*2 - if cap(buf) < end { - n := make([]byte, pos+end) - copy(n, buf) - buf = n - } - buf = buf[0:end] + buf = reserveBuffer(buf, len(v)*2) for _, c := range v { switch c { @@ -892,13 +899,7 @@ func escapeBytesBackslash(buf, v []byte) []byte { // escapeStringBackslash is similar to escapeBytesBackslash but for string. func escapeStringBackslash(buf []byte, v string) []byte { pos := len(buf) - end := pos + len(v)*2 - if cap(buf) < end { - n := make([]byte, pos+end) - copy(n, buf) - buf = n - } - buf = buf[0:end] + buf = reserveBuffer(buf, len(v)*2) for i := 0; i < len(v); i++ { c := v[i] @@ -947,13 +948,7 @@ func escapeStringBackslash(buf []byte, v string) []byte { // https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038 func escapeBytesQuotes(buf, v []byte) []byte { pos := len(buf) - end := pos + len(v)*2 - if cap(buf) < end { - n := make([]byte, pos+end) - copy(n, buf) - buf = n - } - buf = buf[0:end] + buf = reserveBuffer(buf, len(v)*2) for _, c := range v { if c == '\'' { @@ -972,13 +967,7 @@ func escapeBytesQuotes(buf, v []byte) []byte { // escapeStringQuotes is similar to escapeBytesQuotes but for string. func escapeStringQuotes(buf []byte, v string) []byte { pos := len(buf) - end := pos + len(v)*2 - if cap(buf) < end { - n := make([]byte, pos+end) - copy(n, buf) - buf = n - } - buf = buf[0:end] + buf = reserveBuffer(buf, len(v)*2) for i := 0; i < len(v); i++ { c := v[i] From 52a5860d0b9f019d37285b9e7053fc9a9cdde6c6 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Thu, 12 Feb 2015 22:35:26 +0900 Subject: [PATCH 34/37] Fix missing db.Close() --- driver_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/driver_test.go b/driver_test.go index bb8aa0848..bda62eebc 100644 --- a/driver_test.go +++ b/driver_test.go @@ -91,6 +91,10 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { var db2 *sql.DB if _, err := parseDSN(dsn2); err != errInvalidDSNUnsafeCollation { db2, err = sql.Open("mysql", dsn2) + if err != nil { + t.Fatalf("Error connecting: %s", err.Error()) + } + defer db2.Close() } dbt := &DBTest{t, db} From 2a634df7834d92de65b0610980516aad660c2ec9 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Thu, 12 Feb 2015 22:42:31 +0900 Subject: [PATCH 35/37] Fix sentence in interpolateParams document. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 282d83dfc..74b224ce3 100644 --- a/README.md +++ b/README.md @@ -190,7 +190,7 @@ Valid Values: true, false Default: false ``` -When `interpolateParams` is true, calls to `sql.Db.Query()` and `sql.Db.Exec()` with params interpolates placeholders (`?`) with given params. This reduces roundtrips to database compared with `interpolateParams=false` since it uses prepare, exec and close to support parameters. +If `interpolateParams` is true, placeholders (`?`) in calls to `db.Query()` and `db.Exec()` are interpolated into a single query string with given parameters. This reduces the number of roundtrips, since the driver has to prepare a statement, execute it with given parameters and close the statement again with `interpolateParams=false`. NOTE: *This may introduce a SQL injection vulnerability when connection encoding is multibyte encoding except for UTF-8 (e.g. CP932)!* (See http://stackoverflow.com/a/12118602/3430118) From 90cb6c31d5ddfaec6fadf6d10ae08461e6bd4d35 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Fri, 13 Feb 2015 00:41:14 +0900 Subject: [PATCH 36/37] Use blacklist to avoid vulnerability with interpolation --- collations.go | 14 ++++++++++++++ utils.go | 26 ++------------------------ 2 files changed, 16 insertions(+), 24 deletions(-) diff --git a/collations.go b/collations.go index aabe0055d..6c1d613d5 100644 --- a/collations.go +++ b/collations.go @@ -234,3 +234,17 @@ var collations = map[string]byte{ "utf8mb4_unicode_520_ci": 246, "utf8mb4_vietnamese_ci": 247, } + +// A blacklist of collations which is unsafe to interpolate parameters. +// These multibyte encodings may contains 0x5c (`\`) in their trailing bytes. +var unsafeCollations = map[byte]bool{ + 1: true, // big5_chinese_ci + 13: true, // sjis_japanese_ci + 28: true, // gbk_chinese_ci + 84: true, // big5_bin + 86: true, // gb2312_bin + 87: true, // gbk_bin + 88: true, // sjis_bin + 95: true, // cp932_japanese_ci + 96: true, // cp932_bin +} diff --git a/utils.go b/utils.go index 881c9df0a..6693d2970 100644 --- a/utils.go +++ b/utils.go @@ -148,30 +148,8 @@ func parseDSN(dsn string) (cfg *config, err error) { return nil, errInvalidDSNNoSlash } - if cfg.interpolateParams && cfg.collation != defaultCollation { - // A whitelist of collations which safe to interpolate parameters. - // ASCII and latin-1 are safe since they are single byte encoding. - // utf-8 is safe since it doesn't conatins ASCII characters in trailing bytes. - safeCollations := []string{"ascii_", "latin1_", "utf8_", "utf8mb4_"} - - var collationName string - for name, collation := range collations { - if collation == cfg.collation { - collationName = name - break - } - } - - safe := false - for _, p := range safeCollations { - if strings.HasPrefix(collationName, p) { - safe = true - break - } - } - if !safe { - return nil, errInvalidDSNUnsafeCollation - } + if cfg.interpolateParams && unsafeCollations[cfg.collation] { + return nil, errInvalidDSNUnsafeCollation } // Set default network if empty From 9437b61eed49762ab7c01179599f8fd8e7170af6 Mon Sep 17 00:00:00 2001 From: arvenil Date: Fri, 13 Feb 2015 13:20:16 +0100 Subject: [PATCH 37/37] Adding myself to AUTHORS (however, 99% work done by @methane ;)) --- AUTHORS | 1 + 1 file changed, 1 insertion(+) diff --git a/AUTHORS b/AUTHORS index 672e088ae..b6869d4cb 100644 --- a/AUTHORS +++ b/AUTHORS @@ -24,6 +24,7 @@ INADA Naoki James Harr Jian Zhen Julien Schmidt +Kamil Dziedzic Leonardo YongUk Kim Lucas Liu Luke Scott