From 8cbeffa8f656fc645d53357c4705af7c4e2af53b Mon Sep 17 00:00:00 2001 From: Idhor Date: Sun, 12 Apr 2015 16:38:45 +0200 Subject: [PATCH 01/10] Enable Multi Results support and discard additional results MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - packets.go: flag clientMultiResults, update status when receiving an EOF packet, discard additional results on readRow when EOF is reached - statement.go: currently a nil rows.mc is used as an eof, don’t set it if there are no columns to avoid that Next() waits indefinitely - rows.go: discard additional results on close and avoid panic on Columns() --- packets.go | 45 +++++++++++++++++++++++++++++++++++++++++++-- rows.go | 8 +++++++- statement.go | 2 +- 3 files changed, 51 insertions(+), 4 deletions(-) diff --git a/packets.go b/packets.go index 6ac1cccea..c6589109c 100644 --- a/packets.go +++ b/packets.go @@ -224,6 +224,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { clientTransactions | clientLocalFiles | clientPluginAuth | + clientMultiResults | mc.flags&clientLongFlag if mc.cfg.ClientFoundRows { @@ -519,6 +520,10 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error { } } +func readStatus(b []byte) statusFlag { + return statusFlag(b[0]) | statusFlag(b[1])<<8 +} + // Ok Packet // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet func (mc *mysqlConn) handleOkPacket(data []byte) error { @@ -533,7 +538,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { mc.insertId, _, m = readLengthEncodedInteger(data[1+n:]) // server_status [2 bytes] - mc.status = statusFlag(data[1+n+m]) | statusFlag(data[1+n+m+1])<<8 + mc.status = readStatus(data[1+n+m : 1+n+m+2]) // warning count [2 bytes] if !mc.strict { @@ -652,6 +657,11 @@ func (rows *textRows) readRow(dest []driver.Value) error { // EOF Packet if data[0] == iEOF && len(data) == 5 { + // server_status [2 bytes] + rows.mc.status = readStatus(data[3:]) + if err := rows.mc.discardMoreResultsIfExists(); err != nil { + return err + } rows.mc = nil return io.EOF } @@ -709,6 +719,10 @@ func (mc *mysqlConn) readUntilEOF() error { if err == nil && data[0] != iEOF { continue } + if err == nil && data[0] == iEOF && len(data) == 5 { + mc.status = readStatus(data[3:]) + } + return err // Err or EOF } } @@ -1013,6 +1027,28 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { return mc.writePacket(data) } +func (mc *mysqlConn) discardMoreResultsIfExists() error { + for mc.status&statusMoreResultsExists != 0 { + resLen, err := mc.readResultSetHeaderPacket() + if err != nil { + return err + } + if resLen > 0 { + // columns + if err := mc.readUntilEOF(); err != nil { + return err + } + // rows + if err := mc.readUntilEOF(); err != nil { + return err + } + } else { + mc.status &^= statusMoreResultsExists + } + } + return nil +} + // http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html func (rows *binaryRows) readRow(dest []driver.Value) error { data, err := rows.mc.readPacket() @@ -1022,11 +1058,16 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // packet indicator [1 byte] if data[0] != iOK { - rows.mc = nil // EOF Packet if data[0] == iEOF && len(data) == 5 { + rows.mc.status = readStatus(data[3:]) + if err := rows.mc.discardMoreResultsIfExists(); err != nil { + return err + } + rows.mc = nil return io.EOF } + rows.mc = nil // Error otherwise return rows.mc.handleErrorPacket(data) diff --git a/rows.go b/rows.go index 5d21948ad..9853f8323 100644 --- a/rows.go +++ b/rows.go @@ -38,7 +38,7 @@ type emptyRows struct{} func (rows *mysqlRows) Columns() []string { columns := make([]string, len(rows.columns)) - if rows.mc.cfg.ColumnsWithAlias { + if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias { for i := range columns { if tableName := rows.columns[i].tableName; len(tableName) > 0 { columns[i] = tableName + "." + rows.columns[i].name @@ -65,6 +65,12 @@ func (rows *mysqlRows) Close() error { // Remove unread packets from stream err := mc.readUntilEOF() + if err == nil { + if err = mc.discardMoreResultsIfExists(); err != nil { + return err + } + } + rows.mc = nil return err } diff --git a/statement.go b/statement.go index 6e869b340..ead9a6bf4 100644 --- a/statement.go +++ b/statement.go @@ -101,9 +101,9 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { } rows := new(binaryRows) - rows.mc = mc if resLen > 0 { + rows.mc = mc // Columns // If not cached, read them and cache them if stmt.columns == nil { From 5ce0b98124bb085dde89a228f8ba7e72c6cffdfa Mon Sep 17 00:00:00 2001 From: Idhor Date: Wed, 20 Jan 2016 16:51:34 +0100 Subject: [PATCH 02/10] Add Idhor (Luca Looz) to AUTHORS for multi-results --- AUTHORS | 1 + 1 file changed, 1 insertion(+) diff --git a/AUTHORS b/AUTHORS index e86676ff4..3c61a1a87 100644 --- a/AUTHORS +++ b/AUTHORS @@ -31,6 +31,7 @@ Julien Schmidt Kamil Dziedzic Kevin Malachowski Leonardo YongUk Kim +Luca Looz Lucas Liu Luke Scott Michael Woolnough From 1bdf5bdf85466821c9868c489be4bfc43c803b62 Mon Sep 17 00:00:00 2001 From: Badoet Endoet Date: Mon, 11 May 2015 17:20:05 +0800 Subject: [PATCH 03/10] added the multiStatements param to the dsn parameter --- dsn.go | 9 +++++++++ packets.go | 4 ++++ 2 files changed, 13 insertions(+) diff --git a/dsn.go b/dsn.go index 31fd530ee..2203e2291 100644 --- a/dsn.go +++ b/dsn.go @@ -46,6 +46,7 @@ type Config struct { ClientFoundRows bool // Return number of matching rows instead of rows changed ColumnsWithAlias bool // Prepend table alias to column names InterpolateParams bool // Interpolate placeholders into query string + MultiStatements bool // Allow multiple statements in one query ParseTime bool // Parse time values to time.Time Strict bool // Return warnings as errors } @@ -235,6 +236,14 @@ func parseDSNParams(cfg *Config, params string) (err error) { return } + // multiple statements in one query + case "multiStatements": + var isBool bool + cfg.MultiStatements, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + // time.Time parsing case "parseTime": var isBool bool diff --git a/packets.go b/packets.go index c6589109c..1b13013cf 100644 --- a/packets.go +++ b/packets.go @@ -236,6 +236,10 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { clientFlags |= clientSSL } + if mc.cfg.MultiStatements { + clientFlags |= clientMultiStatements + } + // User Password scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) From 71c5db63045271a268cbf59b3a548af98e69f5a5 Mon Sep 17 00:00:00 2001 From: Badoet Endoet Date: Mon, 11 May 2015 19:48:44 +0800 Subject: [PATCH 04/10] add test for the new dsn param --- dsn_test.go | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/dsn_test.go b/dsn_test.go index 3e2a4b37a..d825069fc 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -19,19 +19,20 @@ var testDSNs = []struct { in string out string }{ - {"username:password@protocol(address)/dbname?param=value", "&{User:username Passwd:password Net:protocol Addr:address DBName:dbname Params:map[param:value] Loc:UTC TLS: Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"}, - {"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", "&{User:username Passwd:password Net:protocol Addr:address DBName:dbname Params:map[param:value] Loc:UTC TLS: Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:true InterpolateParams:false ParseTime:false Strict:false}"}, - {"user@unix(/path/to/socket)/dbname?charset=utf8", "&{User:user Passwd: Net:unix Addr:/path/to/socket DBName:dbname Params:map[charset:utf8] Loc:UTC TLS: Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"}, - {"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{User:user Passwd:password Net:tcp Addr:localhost:5555 DBName:dbname Params:map[charset:utf8] Loc:UTC TLS: Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"}, - {"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{User:user Passwd:password Net:tcp Addr:localhost:5555 DBName:dbname Params:map[charset:utf8mb4,utf8] Loc:UTC TLS: Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"}, - {"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{User:user Passwd:password Net:tcp Addr:127.0.0.1:3306 DBName:dbname Params:map[] Loc:UTC TLS: Timeout:30s ReadTimeout:1s WriteTimeout:1s Collation:224 AllowAllFiles:true AllowCleartextPasswords:false AllowOldPasswords:true ClientFoundRows:true ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"}, - {"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{User:user Passwd:p@ss(word) Net:tcp Addr:[de:ad:be:ef::ca:fe]:80 DBName:dbname Params:map[] Loc:Local TLS: Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"}, - {"/dbname", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName:dbname Params:map[] Loc:UTC TLS: Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"}, - {"@/", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS: Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"}, - {"/", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS: Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"}, - {"", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS: Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"}, - {"user:p@/ssword@/", "&{User:user Passwd:p@/ssword Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS: Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"}, - {"unix/?arg=%2Fsome%2Fpath.ext", "&{User: Passwd: Net:unix Addr:/tmp/mysql.sock DBName: Params:map[arg:/some/path.ext] Loc:UTC TLS: Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"}, + {"username:password@protocol(address)/dbname?param=value", "&{User:username Passwd:password Net:protocol Addr:address DBName:dbname Params:map[param:value] Loc:UTC TLS: Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"}, + {"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", "&{User:username Passwd:password Net:protocol Addr:address DBName:dbname Params:map[param:value] Loc:UTC TLS: Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:true InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"}, + {"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true&multiStatements=true", "&{User:username Passwd:password Net:protocol Addr:address DBName:dbname Params:map[param:value] Loc:UTC TLS: Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:true InterpolateParams:false MultiStatements:true ParseTime:false Strict:false}"}, + {"user@unix(/path/to/socket)/dbname?charset=utf8", "&{User:user Passwd: Net:unix Addr:/path/to/socket DBName:dbname Params:map[charset:utf8] Loc:UTC TLS: Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"}, + {"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{User:user Passwd:password Net:tcp Addr:localhost:5555 DBName:dbname Params:map[charset:utf8] Loc:UTC TLS: Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"}, + {"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{User:user Passwd:password Net:tcp Addr:localhost:5555 DBName:dbname Params:map[charset:utf8mb4,utf8] Loc:UTC TLS: Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"}, + {"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{User:user Passwd:password Net:tcp Addr:127.0.0.1:3306 DBName:dbname Params:map[] Loc:UTC TLS: Timeout:30s ReadTimeout:1s WriteTimeout:1s Collation:224 AllowAllFiles:true AllowCleartextPasswords:false AllowOldPasswords:true ClientFoundRows:true ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"}, + {"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{User:user Passwd:p@ss(word) Net:tcp Addr:[de:ad:be:ef::ca:fe]:80 DBName:dbname Params:map[] Loc:Local TLS: Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"}, + {"/dbname", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName:dbname Params:map[] Loc:UTC TLS: Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"}, + {"@/", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS: Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"}, + {"/", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS: Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"}, + {"", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS: Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"}, + {"user:p@/ssword@/", "&{User:user Passwd:p@/ssword Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS: Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"}, + {"unix/?arg=%2Fsome%2Fpath.ext", "&{User: Passwd: Net:unix Addr:/tmp/mysql.sock DBName: Params:map[arg:/some/path.ext] Loc:UTC TLS: Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"}, } func TestDSNParser(t *testing.T) { From 4aa920ddb859e542ef720fcd7291fbadc8404a27 Mon Sep 17 00:00:00 2001 From: Badoet Endoet Date: Tue, 12 May 2015 12:05:02 +0800 Subject: [PATCH 05/10] TestMultiQuery discard additional OK response after Multi Statement Exec Calls --- AUTHORS | 1 + driver_test.go | 81 ++++++++++++++++++++++++++++++++++++++++++++++++++ packets.go | 1 + 3 files changed, 83 insertions(+) diff --git a/AUTHORS b/AUTHORS index 3c61a1a87..986f017cb 100644 --- a/AUTHORS +++ b/AUTHORS @@ -39,6 +39,7 @@ Nicola Peduzzi Runrioter Wung Soroush Pour Stan Putrya +Stanley Gunawan Xiaobing Jiang Xiuming Chen diff --git a/driver_test.go b/driver_test.go index 319d34f91..a902fe323 100644 --- a/driver_test.go +++ b/driver_test.go @@ -76,6 +76,28 @@ type DBTest struct { db *sql.DB } +func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { + if !available { + t.Skipf("MySQL-Server not running on %s", netAddr) + } + + dsn3 := dsn + "&multiStatements=true" + var db3 *sql.DB + if _, err := parseDSN(dsn3); err != errInvalidDSNUnsafeCollation { + db3, err = sql.Open("mysql", dsn3) + if err != nil { + t.Fatalf("Error connecting: %s", err.Error()) + } + defer db3.Close() + } + + dbt3 := &DBTest{t, db3} + for _, test := range tests { + test(dbt3) + dbt3.db.Exec("DROP TABLE IF EXISTS test") + } +} + func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { if !available { t.Skipf("MySQL server not running on %s", netAddr) @@ -99,8 +121,19 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { defer db2.Close() } + dsn3 := dsn + "&multiStatements=true" + var db3 *sql.DB + if _, err := parseDSN(dsn3); err != errInvalidDSNUnsafeCollation { + db3, err = sql.Open("mysql", dsn3) + if err != nil { + t.Fatalf("Error connecting: %s", err.Error()) + } + defer db3.Close() + } + dbt := &DBTest{t, db} dbt2 := &DBTest{t, db2} + dbt3 := &DBTest{t, db3} for _, test := range tests { test(dbt) dbt.db.Exec("DROP TABLE IF EXISTS test") @@ -108,6 +141,10 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { test(dbt2) dbt2.db.Exec("DROP TABLE IF EXISTS test") } + if db3 != nil { + test(dbt3) + dbt3.db.Exec("DROP TABLE IF EXISTS test") + } } } @@ -237,6 +274,50 @@ func TestCRUD(t *testing.T) { }) } +func TestMultiQuery(t *testing.T) { + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + // Create Table + dbt.mustExec("CREATE TABLE `test` (`id` int(11) NOT NULL, `value` int(11) NOT NULL) ") + + // Create Data + res := dbt.mustExec("INSERT INTO test VALUES (1, 1)") + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("Expected 1 affected row, got %d", count) + } + + // Update + res = dbt.mustExec("UPDATE test SET value = 3 WHERE id = 1; UPDATE test SET value = 4 WHERE id = 1; UPDATE test SET value = 5 WHERE id = 1;") + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("Expected 1 affected row, got %d", count) + } + + // Read + var out int + rows := dbt.mustQuery("SELECT value FROM test WHERE id=1;") + if rows.Next() { + rows.Scan(&out) + if 5 != out { + dbt.Errorf("5 != %t", out) + } + + if rows.Next() { + dbt.Error("unexpected data") + } + } else { + dbt.Error("no data") + } + + }) +} + func TestInt(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"} diff --git a/packets.go b/packets.go index 1b13013cf..91e9a85b1 100644 --- a/packets.go +++ b/packets.go @@ -543,6 +543,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { // server_status [2 bytes] mc.status = readStatus(data[1+n+m : 1+n+m+2]) + mc.discardMoreResultsIfExists() // warning count [2 bytes] if !mc.strict { From acb3ebdd872afb15b2c5824f23ae1781a284a427 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Wed, 20 Jan 2016 22:52:29 +0100 Subject: [PATCH 06/10] Fix driver tests --- driver_test.go | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/driver_test.go b/driver_test.go index a902fe323..cd2a9168e 100644 --- a/driver_test.go +++ b/driver_test.go @@ -78,23 +78,23 @@ type DBTest struct { func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { if !available { - t.Skipf("MySQL-Server not running on %s", netAddr) + t.Skipf("MySQL server not running on %s", netAddr) } - dsn3 := dsn + "&multiStatements=true" - var db3 *sql.DB - if _, err := parseDSN(dsn3); err != errInvalidDSNUnsafeCollation { - db3, err = sql.Open("mysql", dsn3) + dsn += "&multiStatements=true" + var db *sql.DB + if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { + db, err = sql.Open("mysql", dsn) if err != nil { - t.Fatalf("Error connecting: %s", err.Error()) + t.Fatalf("error connecting: %s", err.Error()) } - defer db3.Close() + defer db.Close() } - dbt3 := &DBTest{t, db3} + dbt := &DBTest{t, db} for _, test := range tests { - test(dbt3) - dbt3.db.Exec("DROP TABLE IF EXISTS test") + test(dbt) + dbt.db.Exec("DROP TABLE IF EXISTS test") } } @@ -123,10 +123,10 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { dsn3 := dsn + "&multiStatements=true" var db3 *sql.DB - if _, err := parseDSN(dsn3); err != errInvalidDSNUnsafeCollation { + if _, err := ParseDSN(dsn3); err != errInvalidDSNUnsafeCollation { db3, err = sql.Open("mysql", dsn3) if err != nil { - t.Fatalf("Error connecting: %s", err.Error()) + t.Fatalf("error connecting: %s", err.Error()) } defer db3.Close() } @@ -286,7 +286,7 @@ func TestMultiQuery(t *testing.T) { dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) } if count != 1 { - dbt.Fatalf("Expected 1 affected row, got %d", count) + dbt.Fatalf("expected 1 affected row, got %d", count) } // Update @@ -296,7 +296,7 @@ func TestMultiQuery(t *testing.T) { dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) } if count != 1 { - dbt.Fatalf("Expected 1 affected row, got %d", count) + dbt.Fatalf("expected 1 affected row, got %d", count) } // Read From c1e44c429be3daae425c329724f2f4739c4742dc Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Wed, 20 Jan 2016 22:58:45 +0100 Subject: [PATCH 07/10] check discardMoreResultsIfExists err --- packets.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/packets.go b/packets.go index 91e9a85b1..f58d12533 100644 --- a/packets.go +++ b/packets.go @@ -543,7 +543,9 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { // server_status [2 bytes] mc.status = readStatus(data[1+n+m : 1+n+m+2]) - mc.discardMoreResultsIfExists() + if err := mc.discardMoreResultsIfExists(); err != nil { + return err + } // warning count [2 bytes] if !mc.strict { From 023343e2a99eb9bbe70a867202bacf462ae68844 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Thu, 21 Jan 2016 00:23:17 +0100 Subject: [PATCH 08/10] README: add multiStatements --- README.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/README.md b/README.md index f3190e8ab..da3ace49f 100644 --- a/README.md +++ b/README.md @@ -219,6 +219,16 @@ Note that this sets the location for time.Time values but does not change MySQL' Please keep in mind, that param values must be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `loc=US%2FPacific`. +##### `multiStatements` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + +Allow multiple statements in one query. While this allows batch queries, it also greatly increases the risk of SQL injections. + ##### `parseTime` From 416bd115d629893dabcb41b95f42a5c37b9f216a Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Sat, 23 Jan 2016 00:58:37 +0100 Subject: [PATCH 09/10] README: Add note on discarded results --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index da3ace49f..78a6bddf4 100644 --- a/README.md +++ b/README.md @@ -227,7 +227,7 @@ Valid Values: true, false Default: false ``` -Allow multiple statements in one query. While this allows batch queries, it also greatly increases the risk of SQL injections. +Allow multiple statements in one query. While this allows batch queries, it also greatly increases the risk of SQL injections. Only the result of the first query is returned, all other results are silently discarded. ##### `parseTime` From bba2f88dafaa2266c1ad8477e823dcb6253bd242 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Sat, 30 Jan 2016 00:10:55 +0100 Subject: [PATCH 10/10] Rename discardMoreResultsIfExists --- packets.go | 8 ++++---- rows.go | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/packets.go b/packets.go index f58d12533..fab128e46 100644 --- a/packets.go +++ b/packets.go @@ -543,7 +543,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { // server_status [2 bytes] mc.status = readStatus(data[1+n+m : 1+n+m+2]) - if err := mc.discardMoreResultsIfExists(); err != nil { + if err := mc.discardResults(); err != nil { return err } @@ -666,7 +666,7 @@ func (rows *textRows) readRow(dest []driver.Value) error { if data[0] == iEOF && len(data) == 5 { // server_status [2 bytes] rows.mc.status = readStatus(data[3:]) - if err := rows.mc.discardMoreResultsIfExists(); err != nil { + if err := rows.mc.discardResults(); err != nil { return err } rows.mc = nil @@ -1034,7 +1034,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { return mc.writePacket(data) } -func (mc *mysqlConn) discardMoreResultsIfExists() error { +func (mc *mysqlConn) discardResults() error { for mc.status&statusMoreResultsExists != 0 { resLen, err := mc.readResultSetHeaderPacket() if err != nil { @@ -1068,7 +1068,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // EOF Packet if data[0] == iEOF && len(data) == 5 { rows.mc.status = readStatus(data[3:]) - if err := rows.mc.discardMoreResultsIfExists(); err != nil { + if err := rows.mc.discardResults(); err != nil { return err } rows.mc = nil diff --git a/rows.go b/rows.go index 9853f8323..c08255eee 100644 --- a/rows.go +++ b/rows.go @@ -66,7 +66,7 @@ func (rows *mysqlRows) Close() error { // Remove unread packets from stream err := mc.readUntilEOF() if err == nil { - if err = mc.discardMoreResultsIfExists(); err != nil { + if err = mc.discardResults(); err != nil { return err } }