diff --git a/connection.go b/connection.go index 67d3dbee8..b51616cce 100644 --- a/connection.go +++ b/connection.go @@ -13,6 +13,7 @@ import ( "database/sql/driver" "errors" "net" + "strconv" "strings" "time" ) @@ -26,25 +27,27 @@ type mysqlConn struct { maxPacketAllowed int maxWriteSize int flags clientFlag + status statusFlag sequence uint8 parseTime bool strict bool } type config struct { - user string - passwd string - net string - addr string - dbname string - params map[string]string - loc *time.Location - tls *tls.Config - timeout time.Duration - collation uint8 - allowAllFiles bool - allowOldPasswords bool - clientFoundRows bool + user string + passwd string + net string + addr string + dbname string + params map[string]string + loc *time.Location + tls *tls.Config + timeout time.Duration + collation uint8 + allowAllFiles bool + allowOldPasswords bool + clientFoundRows bool + substitutePlaceholder bool } // Handles parameters set in DSN after the connection is established @@ -161,28 +164,146 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { return stmt, err } +func (mc *mysqlConn) escapeBytes(v []byte) string { + buf := make([]byte, len(v)*2+2) + buf[0] = '\'' + pos := 1 + if mc.status&statusNoBackslashEscapes == 0 { + for _, c := range v { + switch c { + case '\x00': + buf[pos] = '\\' + buf[pos+1] = '0' + pos += 2 + case '\n': + buf[pos] = '\\' + buf[pos+1] = 'n' + pos += 2 + case '\r': + buf[pos] = '\\' + buf[pos+1] = 'r' + pos += 2 + case '\x1a': + buf[pos] = '\\' + buf[pos+1] = 'Z' + pos += 2 + case '\'': + buf[pos] = '\\' + buf[pos+1] = '\'' + pos += 2 + case '"': + buf[pos] = '\\' + buf[pos+1] = '"' + pos += 2 + case '\\': + buf[pos] = '\\' + buf[pos+1] = '\\' + pos += 2 + default: + buf[pos] = c + pos += 1 + } + } + } else { + for _, c := range v { + if c == '\'' { + buf[pos] = '\'' + buf[pos+1] = '\'' + pos += 2 + } else { + buf[pos] = c + pos++ + } + } + } + buf[pos] = '\'' + return string(buf[:pos+1]) +} + +func (mc *mysqlConn) buildQuery(query string, args []driver.Value) (string, error) { + chunks := strings.Split(query, "?") + if len(chunks) != len(args)+1 { + return "", driver.ErrSkip + } + + parts := make([]string, len(chunks)+len(args)) + parts[0] = chunks[0] + + for i, arg := range args { + pos := i*2 + 1 + parts[pos+1] = chunks[i+1] + if arg == nil { + parts[pos] = "NULL" + continue + } + switch v := arg.(type) { + case int64: + parts[pos] = strconv.FormatInt(v, 10) + case float64: + parts[pos] = strconv.FormatFloat(v, 'f', -1, 64) + case bool: + if v { + parts[pos] = "1" + } else { + parts[pos] = "0" + } + case time.Time: + if v.IsZero() { + parts[pos] = "'0000-00-00'" + } else { + fmt := "'2006-01-02 15:04:05.999999'" + parts[pos] = v.In(mc.cfg.loc).Format(fmt) + } + case []byte: + if v == nil { + parts[pos] = "NULL" + } else { + parts[pos] = mc.escapeBytes(v) + } + case string: + parts[pos] = mc.escapeBytes([]byte(v)) + default: + return "", driver.ErrSkip + } + } + pktSize := len(query) + 4 // 4 bytes for header. + for _, p := range parts { + pktSize += len(p) + } + if pktSize > mc.maxPacketAllowed { + return "", driver.ErrSkip + } + return strings.Join(parts, ""), nil +} + func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { if mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } - if len(args) == 0 { // no args, fastpath - mc.affectedRows = 0 - mc.insertId = 0 - - err := mc.exec(query) - if err == nil { - return &mysqlResult{ - affectedRows: int64(mc.affectedRows), - insertId: int64(mc.insertId), - }, err + if len(args) != 0 { + if !mc.cfg.substitutePlaceholder { + return nil, driver.ErrSkip } - return nil, err + // try client-side prepare to reduce roundtrip + prepared, err := mc.buildQuery(query, args) + if err != nil { + return nil, err + } + query = prepared + args = nil } + mc.affectedRows = 0 + mc.insertId = 0 - // with args, must use prepared stmt - return nil, driver.ErrSkip - + err := mc.exec(query) + if err == nil { + return &mysqlResult{ + affectedRows: int64(mc.affectedRows), + insertId: int64(mc.insertId), + }, err + } + return nil, err } // Internal function to execute commands @@ -211,31 +332,38 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } - if len(args) == 0 { // no args, fastpath - // Send command - err := mc.writeCommandPacketStr(comQuery, query) + if len(args) != 0 { + if !mc.cfg.substitutePlaceholder { + return nil, driver.ErrSkip + } + // try client-side prepare to reduce roundtrip + prepared, err := mc.buildQuery(query, args) + if err != nil { + return nil, err + } + query = prepared + args = nil + } + // Send command + err := mc.writeCommandPacketStr(comQuery, query) + if err == nil { + // Read Result + var resLen int + resLen, err = mc.readResultSetHeaderPacket() if err == nil { - // Read Result - var resLen int - resLen, err = mc.readResultSetHeaderPacket() - if err == nil { - rows := new(textRows) - rows.mc = mc - - if resLen == 0 { - // no columns, no more data - return emptyRows{}, nil - } - // Columns - rows.columns, err = mc.readColumns(resLen) - return rows, err + rows := new(textRows) + rows.mc = mc + + if resLen == 0 { + // no columns, no more data + return emptyRows{}, nil } + // Columns + rows.columns, err = mc.readColumns(resLen) + return rows, err } - return nil, err } - - // with args, must use prepared stmt - return nil, driver.ErrSkip + return nil, err } // Gets the value of the given MySQL System Variable diff --git a/const.go b/const.go index 5fcc3e98b..3aeaf1b1b 100644 --- a/const.go +++ b/const.go @@ -130,3 +130,25 @@ const ( flagUnknown3 flagUnknown4 ) + +// http://dev.mysql.com/doc/internals/en/status-flags.html + +type statusFlag uint16 + +const ( + statusInTrans statusFlag = 1 << iota + statusInAutocommit + statusUnknown1 + statusMoreResultsExists + statusNoGoodIndexUsed + statusNoIndexUsed + statusCursorExists + statusLastRowSent + statusDbDropped + statusNoBackslashEscapes + statusMetadataChanged + statusQueryWasSlow + statusPsOutParams + statusInTransReadonly + statusSessionStateChanged +) diff --git a/driver_test.go b/driver_test.go index a52cc5cd0..f0777d79d 100644 --- a/driver_test.go +++ b/driver_test.go @@ -87,10 +87,19 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { db.Exec("DROP TABLE IF EXISTS test") + dbp, err := sql.Open("mysql", dsn+"&substitutePlaceholder=true") + if err != nil { + t.Fatalf("Error connecting: %s", err.Error()) + } + defer dbp.Close() + dbt := &DBTest{t, db} + dbtp := &DBTest{t, dbp} for _, test := range tests { test(dbt) dbt.db.Exec("DROP TABLE IF EXISTS test") + test(dbtp) + dbtp.db.Exec("DROP TABLE IF EXISTS test") } } diff --git a/packets.go b/packets.go index f2e385bf8..49c6f6966 100644 --- a/packets.go +++ b/packets.go @@ -484,6 +484,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { mc.insertId, _, m = readLengthEncodedInteger(data[1+n:]) // server_status [2 bytes] + mc.status = statusFlag(data[1+n+m]) | statusFlag(data[1+n+m+1])<<8 // warning count [2 bytes] if !mc.strict { diff --git a/utils.go b/utils.go index 98dfc6f5e..008e9abc1 100644 --- a/utils.go +++ b/utils.go @@ -180,6 +180,14 @@ func parseDSNParams(cfg *config, params string) (err error) { // cfg params switch value := param[1]; param[0] { + // Enable client side placeholder substitution + case "substitutePlaceholder": + var isBool bool + cfg.substitutePlaceholder, isBool = readBool(value) + if !isBool { + return fmt.Errorf("Invalid Bool value: %s", value) + } + // Disable INFILE whitelist / enable all files case "allowAllFiles": var isBool bool diff --git a/utils_test.go b/utils_test.go index 0855374b7..4f666cb96 100644 --- a/utils_test.go +++ b/utils_test.go @@ -22,18 +22,18 @@ var testDSNs = []struct { out string loc *time.Location }{ - {"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls: timeout:30000000000 collation:224 allowAllFiles:true allowOldPasswords:true clientFoundRows:true}", time.UTC}, - {"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.Local}, - {"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, - {"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC}, + {"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC}, + {"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC}, + {"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC}, + {"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC}, + {"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls: timeout:30000000000 collation:224 allowAllFiles:true allowOldPasswords:true clientFoundRows:true substitutePlaceholder:false}", time.UTC}, + {"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.Local}, + {"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC}, + {"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC}, + {"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC}, + {"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC}, + {"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC}, + {"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC}, } func TestDSNParser(t *testing.T) {