Skip to content

Commit 0ac9483

Browse files
committed
Multi-Results improvements
- support for binary protocol - support statements returning no results - remove emptyRows
1 parent c823aa0 commit 0ac9483

File tree

5 files changed

+219
-110
lines changed

5 files changed

+219
-110
lines changed

connection.go

+22-8
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ package mysql
1010

1111
import (
1212
"database/sql/driver"
13+
"io"
1314
"net"
1415
"strconv"
1516
"strings"
@@ -289,22 +290,29 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
289290
// Internal function to execute commands
290291
func (mc *mysqlConn) exec(query string) error {
291292
// Send command
292-
err := mc.writeCommandPacketStr(comQuery, query)
293-
if err != nil {
293+
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
294294
return err
295295
}
296296

297297
// Read Result
298298
resLen, err := mc.readResultSetHeaderPacket()
299-
if err == nil && resLen > 0 {
300-
if err = mc.readUntilEOF(); err != nil {
299+
if err != nil {
300+
return err
301+
}
302+
303+
if resLen > 0 {
304+
// columns
305+
if err := mc.readUntilEOF(); err != nil {
301306
return err
302307
}
303308

304-
err = mc.readUntilEOF()
309+
// rows
310+
if err := mc.readUntilEOF(); err != nil {
311+
return err
312+
}
305313
}
306314

307-
return err
315+
return mc.discardResults()
308316
}
309317

310318
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
@@ -335,8 +343,14 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
335343
rows.mc = mc
336344

337345
if resLen == 0 {
338-
// no columns, no more data
339-
return emptyRows{}, nil
346+
rows.rs.done = true
347+
348+
switch err := rows.NextResultSet(); err {
349+
case nil, io.EOF:
350+
return rows, nil
351+
default:
352+
return nil, err
353+
}
340354
}
341355
// Columns
342356
rows.rs.columns, err = mc.readColumns(resLen)

driver_go18_test.go

+95-26
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,28 @@
33
package mysql
44

55
import (
6+
"database/sql"
7+
"fmt"
68
"reflect"
79
"testing"
810
)
911

1012
func TestMultiResultSet(t *testing.T) {
11-
runTests(t, dsn, func(dbt *DBTest) {
12-
type result struct {
13-
values [][]int
14-
columns []string
15-
}
16-
13+
type result struct {
14+
values [][]int
15+
columns []string
16+
}
17+
18+
// checkRows is a helper test function to validate rows containing 3 result
19+
// sets with specific values and columns. The basic query would look like this:
20+
//
21+
// SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
22+
// SELECT 0 UNION SELECT 1;
23+
// SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;
24+
//
25+
// to distinguish test cases the first string argument is put in front of
26+
// every error or fatal message.
27+
checkRows := func(desc string, rows *sql.Rows, dbt *DBTest) {
1728
expected := []result{
1829
{
1930
values: [][]int{{1, 2}, {3, 4}},
@@ -25,14 +36,6 @@ func TestMultiResultSet(t *testing.T) {
2536
},
2637
}
2738

28-
query := `
29-
SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
30-
SELECT 0 UNION SELECT 1; -- ignore this result set
31-
SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;`
32-
33-
rows := dbt.mustQuery(query)
34-
defer rows.Close()
35-
3639
var res1 result
3740
for rows.Next() {
3841
var res [2]int
@@ -42,49 +45,115 @@ SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;`
4245
res1.values = append(res1.values, res[:])
4346
}
4447

45-
if rows.Next() {
46-
dbt.Error("unexpected row")
47-
}
48-
4948
cols, err := rows.Columns()
5049
if err != nil {
51-
dbt.Fatal(err)
50+
dbt.Fatal(desc, err)
5251
}
5352
res1.columns = cols
5453

5554
if !reflect.DeepEqual(expected[0], res1) {
56-
dbt.Error("want =", expected[0], "got =", res1)
55+
dbt.Error(desc, "want =", expected[0], "got =", res1)
5756
}
5857

5958
if !rows.NextResultSet() {
60-
dbt.Fatal("expected next result set")
59+
dbt.Fatal(desc, "expected next result set")
6160
}
6261

6362
// ignoring one result set
6463

6564
if !rows.NextResultSet() {
66-
dbt.Fatal("expected next result set")
65+
dbt.Fatal(desc, "expected next result set")
6766
}
6867

6968
var res2 result
7069
cols, err = rows.Columns()
7170
if err != nil {
72-
dbt.Fatal(err)
71+
dbt.Fatal(desc, err)
7372
}
7473
res2.columns = cols
7574

7675
for rows.Next() {
7776
var res [3]int
7877
if err := rows.Scan(&res[0], &res[1], &res[2]); err != nil {
79-
dbt.Fatal(err)
78+
dbt.Fatal(desc, err)
8079
}
8180
res2.values = append(res2.values, res[:])
8281
}
8382

8483
if !reflect.DeepEqual(expected[1], res2) {
85-
dbt.Error("want =", expected[1], "got =", res2)
84+
dbt.Error(desc, "want =", expected[1], "got =", res2)
85+
}
86+
87+
if rows.NextResultSet() {
88+
dbt.Error(desc, "unexpected next result set")
89+
}
90+
91+
if err := rows.Err(); err != nil {
92+
dbt.Error(desc, err)
93+
}
94+
}
95+
96+
runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
97+
rows := dbt.mustQuery(`DO 1;
98+
SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
99+
DO 1;
100+
SELECT 0 UNION SELECT 1;
101+
SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;`)
102+
defer rows.Close()
103+
checkRows("query: ", rows, dbt)
104+
})
105+
106+
runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
107+
queries := []string{
108+
`
109+
DROP PROCEDURE IF EXISTS test_mrss;
110+
CREATE PROCEDURE test_mrss()
111+
BEGIN
112+
DO 1;
113+
SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
114+
DO 1;
115+
SELECT 0 UNION SELECT 1;
116+
SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;
117+
END
118+
`,
119+
`
120+
DROP PROCEDURE IF EXISTS test_mrss;
121+
CREATE PROCEDURE test_mrss()
122+
BEGIN
123+
SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
124+
SELECT 0 UNION SELECT 1;
125+
SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;
126+
END
127+
`,
86128
}
87129

130+
defer dbt.mustExec("DROP PROCEDURE IF EXISTS test_mrss")
131+
132+
for i, query := range queries {
133+
dbt.mustExec(query)
134+
135+
stmt, err := dbt.db.Prepare("CALL test_mrss()")
136+
if err != nil {
137+
dbt.Fatalf("%v (i=%d)", err, i)
138+
}
139+
defer stmt.Close()
140+
141+
for j := 0; j < 2; j++ {
142+
rows, err := stmt.Query()
143+
if err != nil {
144+
dbt.Fatalf("%v (i=%d) (j=%d)", err, i, j)
145+
}
146+
checkRows(fmt.Sprintf("prepared stmt query (i=%d) (j=%d): ", i, j), rows, dbt)
147+
}
148+
}
149+
})
150+
}
151+
152+
func TestMultiResultSetNoSelect(t *testing.T) {
153+
runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
154+
rows := dbt.mustQuery("DO 1; DO 2;")
155+
defer rows.Close()
156+
88157
if rows.Next() {
89158
dbt.Error("unexpected row")
90159
}
@@ -94,7 +163,7 @@ SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;`
94163
}
95164

96165
if err := rows.Err(); err != nil {
97-
dbt.Error(err)
166+
dbt.Error("expected nil; got ", err)
98167
}
99168
})
100169
}

packets.go

+5-13
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,6 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
231231
clientTransactions |
232232
clientLocalFiles |
233233
clientPluginAuth |
234-
clientMultiStatements |
235234
clientMultiResults |
236235
mc.flags&clientLongFlag
237236

@@ -585,8 +584,8 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
585584

586585
// server_status [2 bytes]
587586
mc.status = readStatus(data[1+n+m : 1+n+m+2])
588-
if err := mc.discardResults(); err != nil {
589-
return err
587+
if mc.status&statusMoreResultsExists != 0 {
588+
return nil
590589
}
591590

592591
// warning count [2 bytes]
@@ -1098,8 +1097,6 @@ func (mc *mysqlConn) discardResults() error {
10981097
if err := mc.readUntilEOF(); err != nil {
10991098
return err
11001099
}
1101-
} else {
1102-
mc.status &^= statusMoreResultsExists
11031100
}
11041101
}
11051102
return nil
@@ -1118,15 +1115,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
11181115
if data[0] == iEOF && len(data) == 5 {
11191116
rows.mc.status = readStatus(data[3:])
11201117
rows.rs.done = true
1121-
err = rows.mc.discardResults()
1122-
if err == nil {
1123-
err = io.EOF
1124-
} else {
1125-
// connection unusable
1126-
rows.mc.Close()
1118+
if !rows.HasNextResultSet() {
1119+
rows.mc = nil
11271120
}
1128-
rows.mc = nil
1129-
return err
1121+
return io.EOF
11301122
}
11311123
rows.mc = nil
11321124

0 commit comments

Comments
 (0)