Skip to content

Commit b4db83c

Browse files
committed
Merge pull request #411 from go-sql-driver/multistmt
Multistatements and multi results
2 parents 7c7f556 + bba2f88 commit b4db83c

File tree

8 files changed

+174
-17
lines changed

8 files changed

+174
-17
lines changed

Diff for: AUTHORS

+2
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,15 @@ Julien Schmidt <go-sql-driver at julienschmidt.com>
3131
Kamil Dziedzic <kamil at klecza.pl>
3232
Kevin Malachowski <kevin at chowski.com>
3333
Leonardo YongUk Kim <dalinaum at gmail.com>
34+
Luca Looz <luca.looz92 at gmail.com>
3435
Lucas Liu <extrafliu at gmail.com>
3536
Luke Scott <luke at webconnex.com>
3637
Michael Woolnough <michael.woolnough at gmail.com>
3738
Nicola Peduzzi <thenikso at gmail.com>
3839
Runrioter Wung <runrioter at gmail.com>
3940
Soroush Pour <me at soroushjp.com>
4041
Stan Putrya <root.vagner at gmail.com>
42+
Stanley Gunawan <gunawan.stanley at gmail.com>
4143
Xiaobing Jiang <s7v7nislands at gmail.com>
4244
Xiuming Chen <cc at cxm.cc>
4345

Diff for: README.md

+10
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,16 @@ Note that this sets the location for time.Time values but does not change MySQL'
219219

220220
Please keep in mind, that param values must be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `loc=US%2FPacific`.
221221

222+
##### `multiStatements`
223+
224+
```
225+
Type: bool
226+
Valid Values: true, false
227+
Default: false
228+
```
229+
230+
Allow multiple statements in one query. While this allows batch queries, it also greatly increases the risk of SQL injections. Only the result of the first query is returned, all other results are silently discarded.
231+
222232

223233
##### `parseTime`
224234

Diff for: driver_test.go

+81
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,28 @@ type DBTest struct {
7676
db *sql.DB
7777
}
7878

79+
func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
80+
if !available {
81+
t.Skipf("MySQL server not running on %s", netAddr)
82+
}
83+
84+
dsn += "&multiStatements=true"
85+
var db *sql.DB
86+
if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation {
87+
db, err = sql.Open("mysql", dsn)
88+
if err != nil {
89+
t.Fatalf("error connecting: %s", err.Error())
90+
}
91+
defer db.Close()
92+
}
93+
94+
dbt := &DBTest{t, db}
95+
for _, test := range tests {
96+
test(dbt)
97+
dbt.db.Exec("DROP TABLE IF EXISTS test")
98+
}
99+
}
100+
79101
func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
80102
if !available {
81103
t.Skipf("MySQL server not running on %s", netAddr)
@@ -99,15 +121,30 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
99121
defer db2.Close()
100122
}
101123

124+
dsn3 := dsn + "&multiStatements=true"
125+
var db3 *sql.DB
126+
if _, err := ParseDSN(dsn3); err != errInvalidDSNUnsafeCollation {
127+
db3, err = sql.Open("mysql", dsn3)
128+
if err != nil {
129+
t.Fatalf("error connecting: %s", err.Error())
130+
}
131+
defer db3.Close()
132+
}
133+
102134
dbt := &DBTest{t, db}
103135
dbt2 := &DBTest{t, db2}
136+
dbt3 := &DBTest{t, db3}
104137
for _, test := range tests {
105138
test(dbt)
106139
dbt.db.Exec("DROP TABLE IF EXISTS test")
107140
if db2 != nil {
108141
test(dbt2)
109142
dbt2.db.Exec("DROP TABLE IF EXISTS test")
110143
}
144+
if db3 != nil {
145+
test(dbt3)
146+
dbt3.db.Exec("DROP TABLE IF EXISTS test")
147+
}
111148
}
112149
}
113150

@@ -237,6 +274,50 @@ func TestCRUD(t *testing.T) {
237274
})
238275
}
239276

277+
func TestMultiQuery(t *testing.T) {
278+
runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
279+
// Create Table
280+
dbt.mustExec("CREATE TABLE `test` (`id` int(11) NOT NULL, `value` int(11) NOT NULL) ")
281+
282+
// Create Data
283+
res := dbt.mustExec("INSERT INTO test VALUES (1, 1)")
284+
count, err := res.RowsAffected()
285+
if err != nil {
286+
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
287+
}
288+
if count != 1 {
289+
dbt.Fatalf("expected 1 affected row, got %d", count)
290+
}
291+
292+
// Update
293+
res = dbt.mustExec("UPDATE test SET value = 3 WHERE id = 1; UPDATE test SET value = 4 WHERE id = 1; UPDATE test SET value = 5 WHERE id = 1;")
294+
count, err = res.RowsAffected()
295+
if err != nil {
296+
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
297+
}
298+
if count != 1 {
299+
dbt.Fatalf("expected 1 affected row, got %d", count)
300+
}
301+
302+
// Read
303+
var out int
304+
rows := dbt.mustQuery("SELECT value FROM test WHERE id=1;")
305+
if rows.Next() {
306+
rows.Scan(&out)
307+
if 5 != out {
308+
dbt.Errorf("5 != %t", out)
309+
}
310+
311+
if rows.Next() {
312+
dbt.Error("unexpected data")
313+
}
314+
} else {
315+
dbt.Error("no data")
316+
}
317+
318+
})
319+
}
320+
240321
func TestInt(t *testing.T) {
241322
runTests(t, dsn, func(dbt *DBTest) {
242323
types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"}

Diff for: dsn.go

+9
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ type Config struct {
4646
ClientFoundRows bool // Return number of matching rows instead of rows changed
4747
ColumnsWithAlias bool // Prepend table alias to column names
4848
InterpolateParams bool // Interpolate placeholders into query string
49+
MultiStatements bool // Allow multiple statements in one query
4950
ParseTime bool // Parse time values to time.Time
5051
Strict bool // Return warnings as errors
5152
}
@@ -235,6 +236,14 @@ func parseDSNParams(cfg *Config, params string) (err error) {
235236
return
236237
}
237238

239+
// multiple statements in one query
240+
case "multiStatements":
241+
var isBool bool
242+
cfg.MultiStatements, isBool = readBool(value)
243+
if !isBool {
244+
return errors.New("invalid bool value: " + value)
245+
}
246+
238247
// time.Time parsing
239248
case "parseTime":
240249
var isBool bool

Diff for: dsn_test.go

+14-13
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,20 @@ var testDSNs = []struct {
1919
in string
2020
out string
2121
}{
22-
{"username:password@protocol(address)/dbname?param=value", "&{User:username Passwd:password Net:protocol Addr:address DBName:dbname Params:map[param:value] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"},
23-
{"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", "&{User:username Passwd:password Net:protocol Addr:address DBName:dbname Params:map[param:value] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:true InterpolateParams:false ParseTime:false Strict:false}"},
24-
{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{User:user Passwd: Net:unix Addr:/path/to/socket DBName:dbname Params:map[charset:utf8] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"},
25-
{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{User:user Passwd:password Net:tcp Addr:localhost:5555 DBName:dbname Params:map[charset:utf8] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"},
26-
{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{User:user Passwd:password Net:tcp Addr:localhost:5555 DBName:dbname Params:map[charset:utf8mb4,utf8] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"},
27-
{"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{User:user Passwd:password Net:tcp Addr:127.0.0.1:3306 DBName:dbname Params:map[] Loc:UTC TLS:<nil> Timeout:30s ReadTimeout:1s WriteTimeout:1s Collation:224 AllowAllFiles:true AllowCleartextPasswords:false AllowOldPasswords:true ClientFoundRows:true ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"},
28-
{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{User:user Passwd:p@ss(word) Net:tcp Addr:[de:ad:be:ef::ca:fe]:80 DBName:dbname Params:map[] Loc:Local TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"},
29-
{"/dbname", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName:dbname Params:map[] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"},
30-
{"@/", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"},
31-
{"/", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"},
32-
{"", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"},
33-
{"user:p@/ssword@/", "&{User:user Passwd:p@/ssword Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"},
34-
{"unix/?arg=%2Fsome%2Fpath.ext", "&{User: Passwd: Net:unix Addr:/tmp/mysql.sock DBName: Params:map[arg:/some/path.ext] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"},
22+
{"username:password@protocol(address)/dbname?param=value", "&{User:username Passwd:password Net:protocol Addr:address DBName:dbname Params:map[param:value] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
23+
{"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", "&{User:username Passwd:password Net:protocol Addr:address DBName:dbname Params:map[param:value] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:true InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
24+
{"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true&multiStatements=true", "&{User:username Passwd:password Net:protocol Addr:address DBName:dbname Params:map[param:value] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:true InterpolateParams:false MultiStatements:true ParseTime:false Strict:false}"},
25+
{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{User:user Passwd: Net:unix Addr:/path/to/socket DBName:dbname Params:map[charset:utf8] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
26+
{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{User:user Passwd:password Net:tcp Addr:localhost:5555 DBName:dbname Params:map[charset:utf8] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
27+
{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{User:user Passwd:password Net:tcp Addr:localhost:5555 DBName:dbname Params:map[charset:utf8mb4,utf8] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
28+
{"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{User:user Passwd:password Net:tcp Addr:127.0.0.1:3306 DBName:dbname Params:map[] Loc:UTC TLS:<nil> Timeout:30s ReadTimeout:1s WriteTimeout:1s Collation:224 AllowAllFiles:true AllowCleartextPasswords:false AllowOldPasswords:true ClientFoundRows:true ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
29+
{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{User:user Passwd:p@ss(word) Net:tcp Addr:[de:ad:be:ef::ca:fe]:80 DBName:dbname Params:map[] Loc:Local TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
30+
{"/dbname", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName:dbname Params:map[] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
31+
{"@/", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
32+
{"/", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
33+
{"", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
34+
{"user:p@/ssword@/", "&{User:user Passwd:p@/ssword Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
35+
{"unix/?arg=%2Fsome%2Fpath.ext", "&{User: Passwd: Net:unix Addr:/tmp/mysql.sock DBName: Params:map[arg:/some/path.ext] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
3536
}
3637

3738
func TestDSNParser(t *testing.T) {

Diff for: packets.go

+50-2
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
224224
clientTransactions |
225225
clientLocalFiles |
226226
clientPluginAuth |
227+
clientMultiResults |
227228
mc.flags&clientLongFlag
228229

229230
if mc.cfg.ClientFoundRows {
@@ -235,6 +236,10 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
235236
clientFlags |= clientSSL
236237
}
237238

239+
if mc.cfg.MultiStatements {
240+
clientFlags |= clientMultiStatements
241+
}
242+
238243
// User Password
239244
scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd))
240245

@@ -519,6 +524,10 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
519524
}
520525
}
521526

527+
func readStatus(b []byte) statusFlag {
528+
return statusFlag(b[0]) | statusFlag(b[1])<<8
529+
}
530+
522531
// Ok Packet
523532
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
524533
func (mc *mysqlConn) handleOkPacket(data []byte) error {
@@ -533,7 +542,10 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
533542
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
534543

535544
// server_status [2 bytes]
536-
mc.status = statusFlag(data[1+n+m]) | statusFlag(data[1+n+m+1])<<8
545+
mc.status = readStatus(data[1+n+m : 1+n+m+2])
546+
if err := mc.discardResults(); err != nil {
547+
return err
548+
}
537549

538550
// warning count [2 bytes]
539551
if !mc.strict {
@@ -652,6 +664,11 @@ func (rows *textRows) readRow(dest []driver.Value) error {
652664

653665
// EOF Packet
654666
if data[0] == iEOF && len(data) == 5 {
667+
// server_status [2 bytes]
668+
rows.mc.status = readStatus(data[3:])
669+
if err := rows.mc.discardResults(); err != nil {
670+
return err
671+
}
655672
rows.mc = nil
656673
return io.EOF
657674
}
@@ -709,6 +726,10 @@ func (mc *mysqlConn) readUntilEOF() error {
709726
if err == nil && data[0] != iEOF {
710727
continue
711728
}
729+
if err == nil && data[0] == iEOF && len(data) == 5 {
730+
mc.status = readStatus(data[3:])
731+
}
732+
712733
return err // Err or EOF
713734
}
714735
}
@@ -1013,6 +1034,28 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
10131034
return mc.writePacket(data)
10141035
}
10151036

1037+
func (mc *mysqlConn) discardResults() error {
1038+
for mc.status&statusMoreResultsExists != 0 {
1039+
resLen, err := mc.readResultSetHeaderPacket()
1040+
if err != nil {
1041+
return err
1042+
}
1043+
if resLen > 0 {
1044+
// columns
1045+
if err := mc.readUntilEOF(); err != nil {
1046+
return err
1047+
}
1048+
// rows
1049+
if err := mc.readUntilEOF(); err != nil {
1050+
return err
1051+
}
1052+
} else {
1053+
mc.status &^= statusMoreResultsExists
1054+
}
1055+
}
1056+
return nil
1057+
}
1058+
10161059
// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
10171060
func (rows *binaryRows) readRow(dest []driver.Value) error {
10181061
data, err := rows.mc.readPacket()
@@ -1022,11 +1065,16 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
10221065

10231066
// packet indicator [1 byte]
10241067
if data[0] != iOK {
1025-
rows.mc = nil
10261068
// EOF Packet
10271069
if data[0] == iEOF && len(data) == 5 {
1070+
rows.mc.status = readStatus(data[3:])
1071+
if err := rows.mc.discardResults(); err != nil {
1072+
return err
1073+
}
1074+
rows.mc = nil
10281075
return io.EOF
10291076
}
1077+
rows.mc = nil
10301078

10311079
// Error otherwise
10321080
return rows.mc.handleErrorPacket(data)

Diff for: rows.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ type emptyRows struct{}
3838

3939
func (rows *mysqlRows) Columns() []string {
4040
columns := make([]string, len(rows.columns))
41-
if rows.mc.cfg.ColumnsWithAlias {
41+
if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias {
4242
for i := range columns {
4343
if tableName := rows.columns[i].tableName; len(tableName) > 0 {
4444
columns[i] = tableName + "." + rows.columns[i].name
@@ -65,6 +65,12 @@ func (rows *mysqlRows) Close() error {
6565

6666
// Remove unread packets from stream
6767
err := mc.readUntilEOF()
68+
if err == nil {
69+
if err = mc.discardResults(); err != nil {
70+
return err
71+
}
72+
}
73+
6874
rows.mc = nil
6975
return err
7076
}

Diff for: statement.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,9 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
101101
}
102102

103103
rows := new(binaryRows)
104-
rows.mc = mc
105104

106105
if resLen > 0 {
106+
rows.mc = mc
107107
// Columns
108108
// If not cached, read them and cache them
109109
if stmt.columns == nil {

0 commit comments

Comments
 (0)