diff --git a/.travis.yml b/.travis.yml index dd29a7580..50eb041fa 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,8 @@ language: go go: - - 1.1 - 1.2 - 1.3 + - 1.4 - tip before_script: diff --git a/AUTHORS b/AUTHORS index 672e088ae..b6869d4cb 100644 --- a/AUTHORS +++ b/AUTHORS @@ -24,6 +24,7 @@ INADA Naoki James Harr Jian Zhen Julien Schmidt +Kamil Dziedzic Leonardo YongUk Kim Lucas Liu Luke Scott diff --git a/README.md b/README.md index a42d3eb39..74b224ce3 100644 --- a/README.md +++ b/README.md @@ -182,6 +182,19 @@ SELECT u.id FROM users as u will return `u.id` instead of just `id` if `columnsWithAlias=true`. +##### `interpolateParams` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + +If `interpolateParams` is true, placeholders (`?`) in calls to `db.Query()` and `db.Exec()` are interpolated into a single query string with given parameters. This reduces the number of roundtrips, since the driver has to prepare a statement, execute it with given parameters and close the statement again with `interpolateParams=false`. + +NOTE: *This may introduce a SQL injection vulnerability when connection encoding is multibyte encoding except for UTF-8 (e.g. CP932)!* +(See http://stackoverflow.com/a/12118602/3430118) + ##### `loc` ``` diff --git a/benchmark_test.go b/benchmark_test.go index d72a4183f..623ba28fa 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -11,10 +11,13 @@ package mysql import ( "bytes" "database/sql" + "database/sql/driver" + "math" "strings" "sync" "sync/atomic" "testing" + "time" ) type TB testing.B @@ -45,7 +48,11 @@ func initDB(b *testing.B, queries ...string) *sql.DB { db := tb.checkDB(sql.Open("mysql", dsn)) for _, query := range queries { if _, err := db.Exec(query); err != nil { - b.Fatalf("Error on %q: %v", query, err) + if w, ok := err.(MySQLWarnings); ok { + b.Logf("Warning on %q: %v", query, w) + } else { + b.Fatalf("Error on %q: %v", query, err) + } } } return db @@ -206,3 +213,33 @@ func BenchmarkRoundtripBin(b *testing.B) { rows.Close() } } + +func BenchmarkInterpolation(b *testing.B) { + mc := &mysqlConn{ + cfg: &config{ + interpolateParams: true, + loc: time.UTC, + }, + maxPacketAllowed: maxPacketSize, + maxWriteSize: maxPacketSize - 1, + } + + args := []driver.Value{ + int64(42424242), + float64(math.Pi), + false, + time.Unix(1423411542, 807015000), + []byte("bytes containing special chars ' \" \a \x00"), + "string containing special chars ' \" \a \x00", + } + q := "SELECT ?, ?, ?, ?, ?, ?" + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := mc.interpolateParams(q, args) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/collations.go b/collations.go index aabe0055d..6c1d613d5 100644 --- a/collations.go +++ b/collations.go @@ -234,3 +234,17 @@ var collations = map[string]byte{ "utf8mb4_unicode_520_ci": 246, "utf8mb4_vietnamese_ci": 247, } + +// A blacklist of collations which is unsafe to interpolate parameters. +// These multibyte encodings may contains 0x5c (`\`) in their trailing bytes. +var unsafeCollations = map[byte]bool{ + 1: true, // big5_chinese_ci + 13: true, // sjis_japanese_ci + 28: true, // gbk_chinese_ci + 84: true, // big5_bin + 86: true, // gb2312_bin + 87: true, // gbk_bin + 88: true, // sjis_bin + 95: true, // cp932_japanese_ci + 96: true, // cp932_bin +} diff --git a/connection.go b/connection.go index 264735632..82e39628e 100644 --- a/connection.go +++ b/connection.go @@ -13,6 +13,7 @@ import ( "database/sql/driver" "errors" "net" + "strconv" "strings" "time" ) @@ -26,6 +27,7 @@ type mysqlConn struct { maxPacketAllowed int maxWriteSize int flags clientFlag + status statusFlag sequence uint8 parseTime bool strict bool @@ -46,6 +48,7 @@ type config struct { allowOldPasswords bool clientFoundRows bool columnsWithAlias bool + interpolateParams bool } // Handles parameters set in DSN after the connection is established @@ -162,28 +165,174 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { return stmt, err } +// estimateParamLength calculates upper bound of string length from types. +func estimateParamLength(args []driver.Value) (int, bool) { + l := 0 + for _, a := range args { + switch v := a.(type) { + case int64, float64: + // 24 (-1.7976931348623157e+308) may be upper bound. But I'm not sure. + l += 25 + case bool: + l += 1 // 0 or 1 + case time.Time: + l += 30 // '1234-12-23 12:34:56.777777' + case string: + l += len(v)*2 + 2 + case []byte: + l += len(v)*2 + 2 + default: + return 0, false + } + } + return l, true +} + +func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) { + estimated, ok := estimateParamLength(args) + if !ok { + return "", driver.ErrSkip + } + estimated += len(query) + + buf := make([]byte, 0, estimated) + argPos := 0 + + for i := 0; i < len(query); i++ { + q := strings.IndexByte(query[i:], '?') + if q == -1 { + buf = append(buf, query[i:]...) + break + } + buf = append(buf, query[i:i+q]...) + i += q + + arg := args[argPos] + argPos++ + + if arg == nil { + buf = append(buf, "NULL"...) + continue + } + + switch v := arg.(type) { + case int64: + buf = strconv.AppendInt(buf, v, 10) + case float64: + buf = strconv.AppendFloat(buf, v, 'g', -1, 64) + case bool: + if v { + buf = append(buf, '1') + } else { + buf = append(buf, '0') + } + case time.Time: + if v.IsZero() { + buf = append(buf, "'0000-00-00'"...) + } else { + v := v.In(mc.cfg.loc) + v = v.Add(time.Nanosecond * 500) // To round under microsecond + year := v.Year() + year100 := year / 100 + year1 := year % 100 + month := v.Month() + day := v.Day() + hour := v.Hour() + minute := v.Minute() + second := v.Second() + micro := v.Nanosecond() / 1000 + + buf = append(buf, []byte{ + '\'', + digits10[year100], digits01[year100], + digits10[year1], digits01[year1], + '-', + digits10[month], digits01[month], + '-', + digits10[day], digits01[day], + ' ', + digits10[hour], digits01[hour], + ':', + digits10[minute], digits01[minute], + ':', + digits10[second], digits01[second], + }...) + + if micro != 0 { + micro10000 := micro / 10000 + micro100 := micro / 100 % 100 + micro1 := micro % 100 + buf = append(buf, []byte{ + '.', + digits10[micro10000], digits01[micro10000], + digits10[micro100], digits01[micro100], + digits10[micro1], digits01[micro1], + }...) + } + buf = append(buf, '\'') + } + case []byte: + if v == nil { + buf = append(buf, "NULL"...) + } else { + buf = append(buf, '\'') + if mc.status&statusNoBackslashEscapes == 0 { + buf = escapeBytesBackslash(buf, v) + } else { + buf = escapeBytesQuotes(buf, v) + } + buf = append(buf, '\'') + } + case string: + buf = append(buf, '\'') + if mc.status&statusNoBackslashEscapes == 0 { + buf = escapeStringBackslash(buf, v) + } else { + buf = escapeStringQuotes(buf, v) + } + buf = append(buf, '\'') + default: + return "", driver.ErrSkip + } + + if len(buf)+4 > mc.maxPacketAllowed { + return "", driver.ErrSkip + } + } + if argPos != len(args) { + return "", driver.ErrSkip + } + return string(buf), nil +} + func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { if mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } - if len(args) == 0 { // no args, fastpath - mc.affectedRows = 0 - mc.insertId = 0 - - err := mc.exec(query) - if err == nil { - return &mysqlResult{ - affectedRows: int64(mc.affectedRows), - insertId: int64(mc.insertId), - }, err + if len(args) != 0 { + if !mc.cfg.interpolateParams { + return nil, driver.ErrSkip } - return nil, err + // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement + prepared, err := mc.interpolateParams(query, args) + if err != nil { + return nil, err + } + query = prepared + args = nil } + mc.affectedRows = 0 + mc.insertId = 0 - // with args, must use prepared stmt - return nil, driver.ErrSkip - + err := mc.exec(query) + if err == nil { + return &mysqlResult{ + affectedRows: int64(mc.affectedRows), + insertId: int64(mc.insertId), + }, err + } + return nil, err } // Internal function to execute commands @@ -212,31 +361,38 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } - if len(args) == 0 { // no args, fastpath - // Send command - err := mc.writeCommandPacketStr(comQuery, query) + if len(args) != 0 { + if !mc.cfg.interpolateParams { + return nil, driver.ErrSkip + } + // try client-side prepare to reduce roundtrip + prepared, err := mc.interpolateParams(query, args) + if err != nil { + return nil, err + } + query = prepared + args = nil + } + // Send command + err := mc.writeCommandPacketStr(comQuery, query) + if err == nil { + // Read Result + var resLen int + resLen, err = mc.readResultSetHeaderPacket() if err == nil { - // Read Result - var resLen int - resLen, err = mc.readResultSetHeaderPacket() - if err == nil { - rows := new(textRows) - rows.mc = mc - - if resLen == 0 { - // no columns, no more data - return emptyRows{}, nil - } - // Columns - rows.columns, err = mc.readColumns(resLen) - return rows, err + rows := new(textRows) + rows.mc = mc + + if resLen == 0 { + // no columns, no more data + return emptyRows{}, nil } + // Columns + rows.columns, err = mc.readColumns(resLen) + return rows, err } - return nil, err } - - // with args, must use prepared stmt - return nil, driver.ErrSkip + return nil, err } // Gets the value of the given MySQL System Variable diff --git a/const.go b/const.go index 5fcc3e98b..7bf5cea3d 100644 --- a/const.go +++ b/const.go @@ -130,3 +130,25 @@ const ( flagUnknown3 flagUnknown4 ) + +// http://dev.mysql.com/doc/internals/en/status-flags.html + +type statusFlag uint16 + +const ( + statusInTrans statusFlag = 1 << iota + statusInAutocommit + statusReserved // Not in documentation + statusMoreResultsExists + statusNoGoodIndexUsed + statusNoIndexUsed + statusCursorExists + statusLastRowSent + statusDbDropped + statusNoBackslashEscapes + statusMetadataChanged + statusQueryWasSlow + statusPsOutParams + statusInTransReadonly + statusSessionStateChanged +) diff --git a/driver_test.go b/driver_test.go index a52cc5cd0..bda62eebc 100644 --- a/driver_test.go +++ b/driver_test.go @@ -87,10 +87,25 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { db.Exec("DROP TABLE IF EXISTS test") + dsn2 := dsn + "&interpolateParams=true" + var db2 *sql.DB + if _, err := parseDSN(dsn2); err != errInvalidDSNUnsafeCollation { + db2, err = sql.Open("mysql", dsn2) + if err != nil { + t.Fatalf("Error connecting: %s", err.Error()) + } + defer db2.Close() + } + dbt := &DBTest{t, db} + dbt2 := &DBTest{t, db2} for _, test := range tests { test(dbt) dbt.db.Exec("DROP TABLE IF EXISTS test") + if db2 != nil { + test(dbt2) + dbt2.db.Exec("DROP TABLE IF EXISTS test") + } } } @@ -855,7 +870,7 @@ func TestLoadData(t *testing.T) { dbt.Fatalf("%d != %d", i, id) } if values[i-1] != value { - dbt.Fatalf("%s != %s", values[i-1], value) + dbt.Fatalf("%q != %q", values[i-1], value) } } err = rows.Err() @@ -880,7 +895,7 @@ func TestLoadData(t *testing.T) { // Local File RegisterLocalFile(file.Name()) - dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE '%q' INTO TABLE test", file.Name())) + dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE %q INTO TABLE test", file.Name())) verifyLoadDataResult() // negative test _, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'doesnotexist' INTO TABLE test") @@ -1538,3 +1553,62 @@ func TestCustomDial(t *testing.T) { t.Fatalf("Connection failed: %s", err.Error()) } } + +func TestSqlInjection(t *testing.T) { + createTest := func(arg string) func(dbt *DBTest) { + return func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + dbt.mustExec("INSERT INTO test VALUES (?)", 1) + + var v int + // NULL can't be equal to anything, the idea here is to inject query so it returns row + // This test verifies that escapeQuotes and escapeBackslash are working properly + err := dbt.db.QueryRow("SELECT v FROM test WHERE NULL = ?", arg).Scan(&v) + if err == sql.ErrNoRows { + return // success, sql injection failed + } else if err == nil { + dbt.Errorf("Sql injection successful with arg: %s", arg) + } else { + dbt.Errorf("Error running query with arg: %s; err: %s", arg, err.Error()) + } + } + } + + dsns := []string{ + dsn, + dsn + "&sql_mode=NO_BACKSLASH_ESCAPES", + } + for _, testdsn := range dsns { + runTests(t, testdsn, createTest("1 OR 1=1")) + runTests(t, testdsn, createTest("' OR '1'='1")) + } +} + +// Test if inserted data is correctly retrieved after being escaped +func TestInsertRetrieveEscapedData(t *testing.T) { + testData := func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v VARCHAR(255))") + + // All sequences that are escaped by escapeQuotes and escapeBackslash + v := "foo \x00\n\r\x1a\"'\\" + dbt.mustExec("INSERT INTO test VALUES (?)", v) + + var out string + err := dbt.db.QueryRow("SELECT v FROM test").Scan(&out) + if err != nil { + dbt.Fatalf("%s", err.Error()) + } + + if out != v { + dbt.Errorf("%q != %q", out, v) + } + } + + dsns := []string{ + dsn, + dsn + "&sql_mode=NO_BACKSLASH_ESCAPES", + } + for _, testdsn := range dsns { + runTests(t, testdsn, testData) + } +} diff --git a/packets.go b/packets.go index 255301e81..f39d65c70 100644 --- a/packets.go +++ b/packets.go @@ -484,6 +484,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { mc.insertId, _, m = readLengthEncodedInteger(data[1+n:]) // server_status [2 bytes] + mc.status = statusFlag(data[1+n+m]) | statusFlag(data[1+n+m+1])<<8 // warning count [2 bytes] if !mc.strict { diff --git a/utils.go b/utils.go index 56ed45fb3..6693d2970 100644 --- a/utils.go +++ b/utils.go @@ -25,9 +25,10 @@ import ( var ( tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs - errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?") - errInvalidDSNAddr = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)") - errInvalidDSNNoSlash = errors.New("Invalid DSN: Missing the slash separating the database name") + errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?") + errInvalidDSNAddr = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)") + errInvalidDSNNoSlash = errors.New("Invalid DSN: Missing the slash separating the database name") + errInvalidDSNUnsafeCollation = errors.New("Invalid DSN: interpolateParams can be used with ascii, latin1, utf8 and utf8mb4 charset") ) func init() { @@ -147,6 +148,10 @@ func parseDSN(dsn string) (cfg *config, err error) { return nil, errInvalidDSNNoSlash } + if cfg.interpolateParams && unsafeCollations[cfg.collation] { + return nil, errInvalidDSNUnsafeCollation + } + // Set default network if empty if cfg.net == "" { cfg.net = "tcp" @@ -180,6 +185,14 @@ func parseDSNParams(cfg *config, params string) (err error) { // cfg params switch value := param[1]; param[0] { + // Enable client side placeholder substitution + case "interpolateParams": + var isBool bool + cfg.interpolateParams, isBool = readBool(value) + if !isBool { + return fmt.Errorf("Invalid Bool value: %s", value) + } + // Disable INFILE whitelist / enable all files case "allowAllFiles": var isBool bool @@ -216,7 +229,7 @@ func parseDSNParams(cfg *config, params string) (err error) { } cfg.collation = collation break - + case "columnsWithAlias": var isBool bool cfg.columnsWithAlias, isBool = readBool(value) @@ -532,11 +545,12 @@ func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Va // The current behavior depends on database/sql copying the result. var zeroDateTime = []byte("0000-00-00 00:00:00.000000") +const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" +const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999" + func formatBinaryDateTime(src []byte, length uint8, justTime bool) (driver.Value, error) { // length expects the deterministic length of the zero value, // negative time and 100+ hours are automatically added if needed - const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" - const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999" if len(src) == 0 { if justTime { return zeroDateTime[11 : 11+length], nil @@ -798,3 +812,152 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte { return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24), byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56)) } + +// reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize. +// If cap(buf) is not enough, reallocate new buffer. +func reserveBuffer(buf []byte, appendSize int) []byte { + newSize := len(buf) + appendSize + if cap(buf) < newSize { + // Grow buffer exponentially + newBuf := make([]byte, len(buf)*2+appendSize) + copy(newBuf, buf) + buf = newBuf + } + return buf[:newSize] +} + +// escapeBytesBackslash escapes []byte with backslashes (\) +// This escapes the contents of a string (provided as []byte) by adding backslashes before special +// characters, and turning others into specific escape sequences, such as +// turning newlines into \n and null bytes into \0. +// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932 +func escapeBytesBackslash(buf, v []byte) []byte { + pos := len(buf) + buf = reserveBuffer(buf, len(v)*2) + + for _, c := range v { + switch c { + case '\x00': + buf[pos] = '\\' + buf[pos+1] = '0' + pos += 2 + case '\n': + buf[pos] = '\\' + buf[pos+1] = 'n' + pos += 2 + case '\r': + buf[pos] = '\\' + buf[pos+1] = 'r' + pos += 2 + case '\x1a': + buf[pos] = '\\' + buf[pos+1] = 'Z' + pos += 2 + case '\'': + buf[pos] = '\\' + buf[pos+1] = '\'' + pos += 2 + case '"': + buf[pos] = '\\' + buf[pos+1] = '"' + pos += 2 + case '\\': + buf[pos] = '\\' + buf[pos+1] = '\\' + pos += 2 + default: + buf[pos] = c + pos += 1 + } + } + + return buf[:pos] +} + +// escapeStringBackslash is similar to escapeBytesBackslash but for string. +func escapeStringBackslash(buf []byte, v string) []byte { + pos := len(buf) + buf = reserveBuffer(buf, len(v)*2) + + for i := 0; i < len(v); i++ { + c := v[i] + switch c { + case '\x00': + buf[pos] = '\\' + buf[pos+1] = '0' + pos += 2 + case '\n': + buf[pos] = '\\' + buf[pos+1] = 'n' + pos += 2 + case '\r': + buf[pos] = '\\' + buf[pos+1] = 'r' + pos += 2 + case '\x1a': + buf[pos] = '\\' + buf[pos+1] = 'Z' + pos += 2 + case '\'': + buf[pos] = '\\' + buf[pos+1] = '\'' + pos += 2 + case '"': + buf[pos] = '\\' + buf[pos+1] = '"' + pos += 2 + case '\\': + buf[pos] = '\\' + buf[pos+1] = '\\' + pos += 2 + default: + buf[pos] = c + pos += 1 + } + } + + return buf[:pos] +} + +// escapeBytesQuotes escapes apostrophes in []byte by doubling them up. +// This escapes the contents of a string by doubling up any apostrophes that +// it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in +// effect on the server. +// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038 +func escapeBytesQuotes(buf, v []byte) []byte { + pos := len(buf) + buf = reserveBuffer(buf, len(v)*2) + + for _, c := range v { + if c == '\'' { + buf[pos] = '\'' + buf[pos+1] = '\'' + pos += 2 + } else { + buf[pos] = c + pos++ + } + } + + return buf[:pos] +} + +// escapeStringQuotes is similar to escapeBytesQuotes but for string. +func escapeStringQuotes(buf []byte, v string) []byte { + pos := len(buf) + buf = reserveBuffer(buf, len(v)*2) + + for i := 0; i < len(v); i++ { + c := v[i] + if c == '\'' { + buf[pos] = '\'' + buf[pos+1] = '\'' + pos += 2 + } else { + buf[pos] = c + pos++ + } + } + + return buf[:pos] +} diff --git a/utils_test.go b/utils_test.go index d45f8c360..adb8dcbd1 100644 --- a/utils_test.go +++ b/utils_test.go @@ -22,19 +22,19 @@ var testDSNs = []struct { out string loc *time.Location }{ - {"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC}, - {"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:true}", time.UTC}, - {"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC}, - {"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC}, - {"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC}, - {"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls: timeout:30000000000 collation:224 allowAllFiles:true allowOldPasswords:true clientFoundRows:true columnsWithAlias:false}", time.UTC}, - {"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.Local}, - {"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC}, - {"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC}, - {"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC}, - {"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC}, - {"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC}, - {"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC}, + {"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, + {"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:true interpolateParams:false}", time.UTC}, + {"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, + {"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, + {"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, + {"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls: timeout:30000000000 collation:224 allowAllFiles:true allowOldPasswords:true clientFoundRows:true columnsWithAlias:false interpolateParams:false}", time.UTC}, + {"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.Local}, + {"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, + {"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, + {"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, + {"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, + {"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, + {"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, } func TestDSNParser(t *testing.T) { @@ -116,6 +116,43 @@ func TestDSNWithCustomTLS(t *testing.T) { DeregisterTLSConfig("utils_test") } +func TestDSNUnsafeCollation(t *testing.T) { + _, err := parseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=true") + if err != errInvalidDSNUnsafeCollation { + t.Error("Expected %v, Got %v", errInvalidDSNUnsafeCollation, err) + } + + _, err = parseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=false") + if err != nil { + t.Error("Expected %v, Got %v", nil, err) + } + + _, err = parseDSN("/dbname?collation=gbk_chinese_ci") + if err != nil { + t.Error("Expected %v, Got %v", nil, err) + } + + _, err = parseDSN("/dbname?collation=ascii_bin&interpolateParams=true") + if err != nil { + t.Error("Expected %v, Got %v", nil, err) + } + + _, err = parseDSN("/dbname?collation=latin1_german1_ci&interpolateParams=true") + if err != nil { + t.Error("Expected %v, Got %v", nil, err) + } + + _, err = parseDSN("/dbname?collation=utf8_general_ci&interpolateParams=true") + if err != nil { + t.Error("Expected %v, Got %v", nil, err) + } + + _, err = parseDSN("/dbname?collation=utf8mb4_general_ci&interpolateParams=true") + if err != nil { + t.Error("Expected %v, Got %v", nil, err) + } +} + func BenchmarkParseDSN(b *testing.B) { b.ReportAllocs() @@ -252,3 +289,58 @@ func TestFormatBinaryDateTime(t *testing.T) { expect("1978-12-30 15:46:23", 7, 19) expect("1978-12-30 15:46:23.987654", 11, 26) } + +func TestEscapeBackslash(t *testing.T) { + expect := func(expected, value string) { + actual := string(escapeBytesBackslash([]byte{}, []byte(value))) + if actual != expected { + t.Errorf( + "expected %s, got %s", + expected, actual, + ) + } + + actual = string(escapeStringBackslash([]byte{}, value)) + if actual != expected { + t.Errorf( + "expected %s, got %s", + expected, actual, + ) + } + } + + expect("foo\\0bar", "foo\x00bar") + expect("foo\\nbar", "foo\nbar") + expect("foo\\rbar", "foo\rbar") + expect("foo\\Zbar", "foo\x1abar") + expect("foo\\\"bar", "foo\"bar") + expect("foo\\\\bar", "foo\\bar") + expect("foo\\'bar", "foo'bar") +} + +func TestEscapeQuotes(t *testing.T) { + expect := func(expected, value string) { + actual := string(escapeBytesQuotes([]byte{}, []byte(value))) + if actual != expected { + t.Errorf( + "expected %s, got %s", + expected, actual, + ) + } + + actual = string(escapeStringQuotes([]byte{}, value)) + if actual != expected { + t.Errorf( + "expected %s, got %s", + expected, actual, + ) + } + } + + expect("foo\x00bar", "foo\x00bar") // not affected + expect("foo\nbar", "foo\nbar") // not affected + expect("foo\rbar", "foo\rbar") // not affected + expect("foo\x1abar", "foo\x1abar") // not affected + expect("foo''bar", "foo'bar") // affected + expect("foo\"bar", "foo\"bar") // not affected +}