Skip to content

Commit 232e28e

Browse files
committed
Add Multi-Results support
Fixes go-sql-driver#420
1 parent 2e00b5c commit 232e28e

File tree

6 files changed

+189
-35
lines changed

6 files changed

+189
-35
lines changed

AUTHORS

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ Hanno Braun <mail at hannobraun.com>
2525
Henri Yandell <flamefew at gmail.com>
2626
Hirotaka Yamamoto <ymmt2005 at gmail.com>
2727
INADA Naoki <songofacandy at gmail.com>
28+
Jacek Szwec <szwec.jacek at gmail.com>
2829
James Harr <james.harr at gmail.com>
2930
Jian Zhen <zhenjl at gmail.com>
3031
Joshua Prunier <joshua.prunier at gmail.com>

connection.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
339339
return emptyRows{}, nil
340340
}
341341
// Columns
342-
rows.columns, err = mc.readColumns(resLen)
342+
rows.rs.columns, err = mc.readColumns(resLen)
343343
return rows, err
344344
}
345345
}
@@ -359,7 +359,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
359359
if err == nil {
360360
rows := new(textRows)
361361
rows.mc = mc
362-
rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
362+
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
363363

364364
if resLen > 0 {
365365
// Columns

driver_18_test.go

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
// +build go1.8
2+
3+
package mysql
4+
5+
import (
6+
"reflect"
7+
"testing"
8+
)
9+
10+
func TestMultiResultSet(t *testing.T) {
11+
runTests(t, dsn, func(dbt *DBTest) {
12+
type result struct {
13+
values [][]int
14+
columns []string
15+
}
16+
17+
expected := []result{
18+
{
19+
values: [][]int{{1, 2}, {3, 4}},
20+
columns: []string{"col1", "col2"},
21+
},
22+
{
23+
values: [][]int{{1, 2, 3}, {4, 5, 6}},
24+
columns: []string{"col1", "col2", "col3"},
25+
},
26+
}
27+
28+
query := `
29+
SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
30+
SELECT 0 UNION SELECT 1; -- ignore this result set
31+
SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;`
32+
33+
rows := dbt.mustQuery(query)
34+
defer rows.Close()
35+
36+
var res1 result
37+
for rows.Next() {
38+
var res [2]int
39+
if err := rows.Scan(&res[0], &res[1]); err != nil {
40+
dbt.Fatal(err)
41+
}
42+
res1.values = append(res1.values, res[:])
43+
}
44+
45+
if rows.Next() {
46+
dbt.Error("unexpected row")
47+
}
48+
49+
cols, err := rows.Columns()
50+
if err != nil {
51+
dbt.Fatal(err)
52+
}
53+
res1.columns = cols
54+
55+
if !reflect.DeepEqual(expected[0], res1) {
56+
dbt.Error("want =", expected[0], "got =", res1)
57+
}
58+
59+
if !rows.NextResultSet() {
60+
dbt.Fatal("expected next result set")
61+
}
62+
63+
// ignoring one result set
64+
65+
if !rows.NextResultSet() {
66+
dbt.Fatal("expected next result set")
67+
}
68+
69+
var res2 result
70+
cols, err = rows.Columns()
71+
if err != nil {
72+
dbt.Fatal(err)
73+
}
74+
res2.columns = cols
75+
76+
for rows.Next() {
77+
var res [3]int
78+
if err := rows.Scan(&res[0], &res[1], &res[2]); err != nil {
79+
dbt.Fatal(err)
80+
}
81+
res2.values = append(res2.values, res[:])
82+
}
83+
84+
if !reflect.DeepEqual(expected[1], res2) {
85+
dbt.Error("want =", expected[1], "got =", res2)
86+
}
87+
88+
if rows.Next() {
89+
dbt.Error("unexpected row")
90+
}
91+
92+
if rows.NextResultSet() {
93+
dbt.Error("unexpected next result set")
94+
}
95+
96+
if err := rows.Err(); err != nil {
97+
dbt.Error(err)
98+
}
99+
})
100+
}

packets.go

+23-21
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
231231
clientTransactions |
232232
clientLocalFiles |
233233
clientPluginAuth |
234+
clientMultiStatements |
234235
clientMultiResults |
235236
mc.flags&clientLongFlag
236237

@@ -698,6 +699,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
698699
func (rows *textRows) readRow(dest []driver.Value) error {
699700
mc := rows.mc
700701

702+
if rows.rs.done {
703+
return io.EOF
704+
}
705+
701706
data, err := mc.readPacket()
702707
if err != nil {
703708
return err
@@ -707,15 +712,11 @@ func (rows *textRows) readRow(dest []driver.Value) error {
707712
if data[0] == iEOF && len(data) == 5 {
708713
// server_status [2 bytes]
709714
rows.mc.status = readStatus(data[3:])
710-
err = rows.mc.discardResults()
711-
if err == nil {
712-
err = io.EOF
713-
} else {
714-
// connection unusable
715-
rows.mc.Close()
715+
rows.rs.done = true
716+
if !rows.HasNextResultSet() {
717+
rows.mc = nil
716718
}
717-
rows.mc = nil
718-
return err
719+
return io.EOF
719720
}
720721
if data[0] == iERR {
721722
rows.mc = nil
@@ -736,7 +737,7 @@ func (rows *textRows) readRow(dest []driver.Value) error {
736737
if !mc.parseTime {
737738
continue
738739
} else {
739-
switch rows.columns[i].fieldType {
740+
switch rows.rs.columns[i].fieldType {
740741
case fieldTypeTimestamp, fieldTypeDateTime,
741742
fieldTypeDate, fieldTypeNewDate:
742743
dest[i], err = parseDateTime(
@@ -1116,6 +1117,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
11161117
// EOF Packet
11171118
if data[0] == iEOF && len(data) == 5 {
11181119
rows.mc.status = readStatus(data[3:])
1120+
rows.rs.done = true
11191121
err = rows.mc.discardResults()
11201122
if err == nil {
11211123
err = io.EOF
@@ -1145,14 +1147,14 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
11451147
}
11461148

11471149
// Convert to byte-coded string
1148-
switch rows.columns[i].fieldType {
1150+
switch rows.rs.columns[i].fieldType {
11491151
case fieldTypeNULL:
11501152
dest[i] = nil
11511153
continue
11521154

11531155
// Numeric Types
11541156
case fieldTypeTiny:
1155-
if rows.columns[i].flags&flagUnsigned != 0 {
1157+
if rows.rs.columns[i].flags&flagUnsigned != 0 {
11561158
dest[i] = int64(data[pos])
11571159
} else {
11581160
dest[i] = int64(int8(data[pos]))
@@ -1161,7 +1163,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
11611163
continue
11621164

11631165
case fieldTypeShort, fieldTypeYear:
1164-
if rows.columns[i].flags&flagUnsigned != 0 {
1166+
if rows.rs.columns[i].flags&flagUnsigned != 0 {
11651167
dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
11661168
} else {
11671169
dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
@@ -1170,7 +1172,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
11701172
continue
11711173

11721174
case fieldTypeInt24, fieldTypeLong:
1173-
if rows.columns[i].flags&flagUnsigned != 0 {
1175+
if rows.rs.columns[i].flags&flagUnsigned != 0 {
11741176
dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
11751177
} else {
11761178
dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
@@ -1179,7 +1181,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
11791181
continue
11801182

11811183
case fieldTypeLongLong:
1182-
if rows.columns[i].flags&flagUnsigned != 0 {
1184+
if rows.rs.columns[i].flags&flagUnsigned != 0 {
11831185
val := binary.LittleEndian.Uint64(data[pos : pos+8])
11841186
if val > math.MaxInt64 {
11851187
dest[i] = uint64ToString(val)
@@ -1233,37 +1235,37 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
12331235
case isNull:
12341236
dest[i] = nil
12351237
continue
1236-
case rows.columns[i].fieldType == fieldTypeTime:
1238+
case rows.rs.columns[i].fieldType == fieldTypeTime:
12371239
// database/sql does not support an equivalent to TIME, return a string
12381240
var dstlen uint8
1239-
switch decimals := rows.columns[i].decimals; decimals {
1241+
switch decimals := rows.rs.columns[i].decimals; decimals {
12401242
case 0x00, 0x1f:
12411243
dstlen = 8
12421244
case 1, 2, 3, 4, 5, 6:
12431245
dstlen = 8 + 1 + decimals
12441246
default:
12451247
return fmt.Errorf(
12461248
"protocol error, illegal decimals value %d",
1247-
rows.columns[i].decimals,
1249+
rows.rs.columns[i].decimals,
12481250
)
12491251
}
12501252
dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true)
12511253
case rows.mc.parseTime:
12521254
dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
12531255
default:
12541256
var dstlen uint8
1255-
if rows.columns[i].fieldType == fieldTypeDate {
1257+
if rows.rs.columns[i].fieldType == fieldTypeDate {
12561258
dstlen = 10
12571259
} else {
1258-
switch decimals := rows.columns[i].decimals; decimals {
1260+
switch decimals := rows.rs.columns[i].decimals; decimals {
12591261
case 0x00, 0x1f:
12601262
dstlen = 19
12611263
case 1, 2, 3, 4, 5, 6:
12621264
dstlen = 19 + 1 + decimals
12631265
default:
12641266
return fmt.Errorf(
12651267
"protocol error, illegal decimals value %d",
1266-
rows.columns[i].decimals,
1268+
rows.rs.columns[i].decimals,
12671269
)
12681270
}
12691271
}
@@ -1279,7 +1281,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
12791281

12801282
// Please report if this happens!
12811283
default:
1282-
return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType)
1284+
return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType)
12831285
}
12841286
}
12851287

rows.go

+60-9
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,14 @@ type mysqlField struct {
2121
decimals byte
2222
}
2323

24-
type mysqlRows struct {
25-
mc *mysqlConn
24+
type resultSet struct {
2625
columns []mysqlField
26+
done bool
27+
}
28+
29+
type mysqlRows struct {
30+
mc *mysqlConn
31+
rs resultSet
2732
}
2833

2934
type binaryRows struct {
@@ -37,24 +42,24 @@ type textRows struct {
3742
type emptyRows struct{}
3843

3944
func (rows *mysqlRows) Columns() []string {
40-
columns := make([]string, len(rows.columns))
45+
columns := make([]string, len(rows.rs.columns))
4146
if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias {
4247
for i := range columns {
43-
if tableName := rows.columns[i].tableName; len(tableName) > 0 {
44-
columns[i] = tableName + "." + rows.columns[i].name
48+
if tableName := rows.rs.columns[i].tableName; len(tableName) > 0 {
49+
columns[i] = tableName + "." + rows.rs.columns[i].name
4550
} else {
46-
columns[i] = rows.columns[i].name
51+
columns[i] = rows.rs.columns[i].name
4752
}
4853
}
4954
} else {
5055
for i := range columns {
51-
columns[i] = rows.columns[i].name
56+
columns[i] = rows.rs.columns[i].name
5257
}
5358
}
5459
return columns
5560
}
5661

57-
func (rows *mysqlRows) Close() error {
62+
func (rows *mysqlRows) Close() (err error) {
5863
mc := rows.mc
5964
if mc == nil {
6065
return nil
@@ -64,7 +69,9 @@ func (rows *mysqlRows) Close() error {
6469
}
6570

6671
// Remove unread packets from stream
67-
err := mc.readUntilEOF()
72+
if !rows.rs.done {
73+
err = mc.readUntilEOF()
74+
}
6875
if err == nil {
6976
if err = mc.discardResults(); err != nil {
7077
return err
@@ -99,6 +106,42 @@ func (rows *textRows) Next(dest []driver.Value) error {
99106
return io.EOF
100107
}
101108

109+
func (rows *textRows) HasNextResultSet() (b bool) {
110+
if rows.mc == nil {
111+
return false
112+
}
113+
return rows.mc.status&statusMoreResultsExists != 0
114+
}
115+
116+
func (rows *textRows) NextResultSet() error {
117+
if rows.mc == nil {
118+
return io.EOF
119+
}
120+
if rows.mc.netConn == nil {
121+
return ErrInvalidConn
122+
}
123+
124+
// Remove unread packets from stream
125+
if !rows.rs.done {
126+
if err := rows.mc.readUntilEOF(); err != nil {
127+
return err
128+
}
129+
}
130+
131+
if !rows.HasNextResultSet() {
132+
return io.EOF
133+
}
134+
rows.rs = resultSet{}
135+
136+
resLen, err := rows.mc.readResultSetHeaderPacket()
137+
if err != nil {
138+
return err
139+
}
140+
141+
rows.rs.columns, err = rows.mc.readColumns(resLen)
142+
return err
143+
}
144+
102145
func (rows emptyRows) Columns() []string {
103146
return nil
104147
}
@@ -110,3 +153,11 @@ func (rows emptyRows) Close() error {
110153
func (rows emptyRows) Next(dest []driver.Value) error {
111154
return io.EOF
112155
}
156+
157+
func (rows emptyRows) HasNextResultSet() bool {
158+
return false
159+
}
160+
161+
func (rows emptyRows) NextResultSet() error {
162+
return io.EOF
163+
}

0 commit comments

Comments
 (0)