Skip to content

Commit 8c632dd

Browse files
committed
Add Multi-Results support
Fixes go-sql-driver#420
1 parent 2e00b5c commit 8c632dd

File tree

6 files changed

+167
-36
lines changed

6 files changed

+167
-36
lines changed

AUTHORS

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ Hanno Braun <mail at hannobraun.com>
2525
Henri Yandell <flamefew at gmail.com>
2626
Hirotaka Yamamoto <ymmt2005 at gmail.com>
2727
INADA Naoki <songofacandy at gmail.com>
28+
Jacek Szwec <szwec.jacek at gmail.com>
2829
James Harr <james.harr at gmail.com>
2930
Jian Zhen <zhenjl at gmail.com>
3031
Joshua Prunier <joshua.prunier at gmail.com>

connection.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -333,13 +333,12 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
333333
if err == nil {
334334
rows := new(textRows)
335335
rows.mc = mc
336-
337336
if resLen == 0 {
338337
// no columns, no more data
339338
return emptyRows{}, nil
340339
}
341340
// Columns
342-
rows.columns, err = mc.readColumns(resLen)
341+
rows.rs.columns, err = mc.readColumns(resLen)
343342
return rows, err
344343
}
345344
}
@@ -359,7 +358,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
359358
if err == nil {
360359
rows := new(textRows)
361360
rows.mc = mc
362-
rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
361+
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
363362

364363
if resLen > 0 {
365364
// Columns

driver_18_test.go

+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// +build go1.8
2+
3+
package mysql
4+
5+
import (
6+
"reflect"
7+
"testing"
8+
)
9+
10+
func TestMultiResultSet(t *testing.T) {
11+
runTests(t, dsn, func(dbt *DBTest) {
12+
type result struct {
13+
values [][]int
14+
columns []string
15+
}
16+
17+
expected := []result{
18+
{
19+
values: [][]int{{1, 2}, {3, 4}},
20+
columns: []string{"col1", "col2"},
21+
},
22+
{
23+
values: [][]int{{1, 2, 3}, {4, 5, 6}},
24+
columns: []string{"col1", "col2", "col3"},
25+
},
26+
}
27+
28+
query := `
29+
SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
30+
SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6`
31+
32+
rows := dbt.mustQuery(query)
33+
defer rows.Close()
34+
35+
var res1 result
36+
for rows.Next() {
37+
var res [2]int
38+
if err := rows.Scan(&res[0], &res[1]); err != nil {
39+
dbt.Fatal(err)
40+
}
41+
res1.values = append(res1.values, res[:])
42+
}
43+
44+
if rows.Next() {
45+
dbt.Error("unexpected row")
46+
}
47+
48+
cols, err := rows.Columns()
49+
if err != nil {
50+
dbt.Fatal(err)
51+
}
52+
res1.columns = cols
53+
54+
if !reflect.DeepEqual(expected[0], res1) {
55+
dbt.Error("want =", expected[0], "got =", res1)
56+
}
57+
58+
if !rows.NextResultSet() {
59+
dbt.Fatal("expected next result set")
60+
}
61+
62+
var res2 result
63+
cols, err = rows.Columns()
64+
if err != nil {
65+
dbt.Fatal(err)
66+
}
67+
res2.columns = cols
68+
69+
for rows.Next() {
70+
var res [3]int
71+
if err := rows.Scan(&res[0], &res[1], &res[2]); err != nil {
72+
dbt.Fatal(err)
73+
}
74+
res2.values = append(res2.values, res[:])
75+
}
76+
77+
if !reflect.DeepEqual(expected[1], res2) {
78+
dbt.Error("want =", expected[1], "got =", res2)
79+
}
80+
81+
if rows.Next() {
82+
dbt.Error("unexpected row")
83+
}
84+
85+
if rows.NextResultSet() {
86+
dbt.Error("unexpected next result set")
87+
}
88+
89+
if err := rows.Err(); err != nil {
90+
dbt.Error(err)
91+
}
92+
})
93+
}

packets.go

+22-21
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
231231
clientTransactions |
232232
clientLocalFiles |
233233
clientPluginAuth |
234+
clientMultiStatements |
234235
clientMultiResults |
235236
mc.flags&clientLongFlag
236237

@@ -698,6 +699,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
698699
func (rows *textRows) readRow(dest []driver.Value) error {
699700
mc := rows.mc
700701

702+
if rows.rs.done {
703+
return io.EOF
704+
}
705+
701706
data, err := mc.readPacket()
702707
if err != nil {
703708
return err
@@ -707,15 +712,11 @@ func (rows *textRows) readRow(dest []driver.Value) error {
707712
if data[0] == iEOF && len(data) == 5 {
708713
// server_status [2 bytes]
709714
rows.mc.status = readStatus(data[3:])
710-
err = rows.mc.discardResults()
711-
if err == nil {
712-
err = io.EOF
713-
} else {
714-
// connection unusable
715-
rows.mc.Close()
715+
rows.rs.done = true
716+
if !rows.HasNextResultSet() {
717+
rows.mc = nil
716718
}
717-
rows.mc = nil
718-
return err
719+
return io.EOF
719720
}
720721
if data[0] == iERR {
721722
rows.mc = nil
@@ -736,7 +737,7 @@ func (rows *textRows) readRow(dest []driver.Value) error {
736737
if !mc.parseTime {
737738
continue
738739
} else {
739-
switch rows.columns[i].fieldType {
740+
switch rows.rs.columns[i].fieldType {
740741
case fieldTypeTimestamp, fieldTypeDateTime,
741742
fieldTypeDate, fieldTypeNewDate:
742743
dest[i], err = parseDateTime(
@@ -1145,14 +1146,14 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
11451146
}
11461147

11471148
// Convert to byte-coded string
1148-
switch rows.columns[i].fieldType {
1149+
switch rows.rs.columns[i].fieldType {
11491150
case fieldTypeNULL:
11501151
dest[i] = nil
11511152
continue
11521153

11531154
// Numeric Types
11541155
case fieldTypeTiny:
1155-
if rows.columns[i].flags&flagUnsigned != 0 {
1156+
if rows.rs.columns[i].flags&flagUnsigned != 0 {
11561157
dest[i] = int64(data[pos])
11571158
} else {
11581159
dest[i] = int64(int8(data[pos]))
@@ -1161,7 +1162,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
11611162
continue
11621163

11631164
case fieldTypeShort, fieldTypeYear:
1164-
if rows.columns[i].flags&flagUnsigned != 0 {
1165+
if rows.rs.columns[i].flags&flagUnsigned != 0 {
11651166
dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
11661167
} else {
11671168
dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
@@ -1170,7 +1171,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
11701171
continue
11711172

11721173
case fieldTypeInt24, fieldTypeLong:
1173-
if rows.columns[i].flags&flagUnsigned != 0 {
1174+
if rows.rs.columns[i].flags&flagUnsigned != 0 {
11741175
dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
11751176
} else {
11761177
dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
@@ -1179,7 +1180,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
11791180
continue
11801181

11811182
case fieldTypeLongLong:
1182-
if rows.columns[i].flags&flagUnsigned != 0 {
1183+
if rows.rs.columns[i].flags&flagUnsigned != 0 {
11831184
val := binary.LittleEndian.Uint64(data[pos : pos+8])
11841185
if val > math.MaxInt64 {
11851186
dest[i] = uint64ToString(val)
@@ -1233,37 +1234,37 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
12331234
case isNull:
12341235
dest[i] = nil
12351236
continue
1236-
case rows.columns[i].fieldType == fieldTypeTime:
1237+
case rows.rs.columns[i].fieldType == fieldTypeTime:
12371238
// database/sql does not support an equivalent to TIME, return a string
12381239
var dstlen uint8
1239-
switch decimals := rows.columns[i].decimals; decimals {
1240+
switch decimals := rows.rs.columns[i].decimals; decimals {
12401241
case 0x00, 0x1f:
12411242
dstlen = 8
12421243
case 1, 2, 3, 4, 5, 6:
12431244
dstlen = 8 + 1 + decimals
12441245
default:
12451246
return fmt.Errorf(
12461247
"protocol error, illegal decimals value %d",
1247-
rows.columns[i].decimals,
1248+
rows.rs.columns[i].decimals,
12481249
)
12491250
}
12501251
dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true)
12511252
case rows.mc.parseTime:
12521253
dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
12531254
default:
12541255
var dstlen uint8
1255-
if rows.columns[i].fieldType == fieldTypeDate {
1256+
if rows.rs.columns[i].fieldType == fieldTypeDate {
12561257
dstlen = 10
12571258
} else {
1258-
switch decimals := rows.columns[i].decimals; decimals {
1259+
switch decimals := rows.rs.columns[i].decimals; decimals {
12591260
case 0x00, 0x1f:
12601261
dstlen = 19
12611262
case 1, 2, 3, 4, 5, 6:
12621263
dstlen = 19 + 1 + decimals
12631264
default:
12641265
return fmt.Errorf(
12651266
"protocol error, illegal decimals value %d",
1266-
rows.columns[i].decimals,
1267+
rows.rs.columns[i].decimals,
12671268
)
12681269
}
12691270
}
@@ -1279,7 +1280,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
12791280

12801281
// Please report if this happens!
12811282
default:
1282-
return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType)
1283+
return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType)
12831284
}
12841285
}
12851286

rows.go

+46-9
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,14 @@ type mysqlField struct {
2121
decimals byte
2222
}
2323

24-
type mysqlRows struct {
25-
mc *mysqlConn
24+
type resultSet struct {
2625
columns []mysqlField
26+
done bool
27+
}
28+
29+
type mysqlRows struct {
30+
mc *mysqlConn
31+
rs resultSet
2732
}
2833

2934
type binaryRows struct {
@@ -37,24 +42,24 @@ type textRows struct {
3742
type emptyRows struct{}
3843

3944
func (rows *mysqlRows) Columns() []string {
40-
columns := make([]string, len(rows.columns))
45+
columns := make([]string, len(rows.rs.columns))
4146
if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias {
4247
for i := range columns {
43-
if tableName := rows.columns[i].tableName; len(tableName) > 0 {
44-
columns[i] = tableName + "." + rows.columns[i].name
48+
if tableName := rows.rs.columns[i].tableName; len(tableName) > 0 {
49+
columns[i] = tableName + "." + rows.rs.columns[i].name
4550
} else {
46-
columns[i] = rows.columns[i].name
51+
columns[i] = rows.rs.columns[i].name
4752
}
4853
}
4954
} else {
5055
for i := range columns {
51-
columns[i] = rows.columns[i].name
56+
columns[i] = rows.rs.columns[i].name
5257
}
5358
}
5459
return columns
5560
}
5661

57-
func (rows *mysqlRows) Close() error {
62+
func (rows *mysqlRows) Close() (err error) {
5863
mc := rows.mc
5964
if mc == nil {
6065
return nil
@@ -64,7 +69,9 @@ func (rows *mysqlRows) Close() error {
6469
}
6570

6671
// Remove unread packets from stream
67-
err := mc.readUntilEOF()
72+
if !rows.rs.done {
73+
err = mc.readUntilEOF()
74+
}
6875
if err == nil {
6976
if err = mc.discardResults(); err != nil {
7077
return err
@@ -99,6 +106,28 @@ func (rows *textRows) Next(dest []driver.Value) error {
99106
return io.EOF
100107
}
101108

109+
func (rows *textRows) HasNextResultSet() (b bool) {
110+
if rows.mc == nil {
111+
return false
112+
}
113+
return rows.mc.status&statusMoreResultsExists != 0
114+
}
115+
116+
func (rows *textRows) NextResultSet() error {
117+
if !rows.HasNextResultSet() {
118+
return io.EOF
119+
}
120+
rows.rs = resultSet{}
121+
122+
resLen, err := rows.mc.readResultSetHeaderPacket()
123+
if err != nil {
124+
return err
125+
}
126+
127+
rows.rs.columns, err = rows.mc.readColumns(resLen)
128+
return err
129+
}
130+
102131
func (rows emptyRows) Columns() []string {
103132
return nil
104133
}
@@ -110,3 +139,11 @@ func (rows emptyRows) Close() error {
110139
func (rows emptyRows) Next(dest []driver.Value) error {
111140
return io.EOF
112141
}
142+
143+
func (rows emptyRows) HasNextResultSet() bool {
144+
return false
145+
}
146+
147+
func (rows emptyRows) NextResultSet() error {
148+
return io.EOF
149+
}

statement.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,10 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
110110
// Columns
111111
// If not cached, read them and cache them
112112
if stmt.columns == nil {
113-
rows.columns, err = mc.readColumns(resLen)
114-
stmt.columns = rows.columns
113+
rows.rs.columns, err = mc.readColumns(resLen)
114+
stmt.columns = rows.rs.columns
115115
} else {
116-
rows.columns = stmt.columns
116+
rows.rs.columns = stmt.columns
117117
err = mc.readUntilEOF()
118118
}
119119
}

0 commit comments

Comments
 (0)