Skip to content

Commit d707776

Browse files
committed
Add Multi-Results support
Fixes go-sql-driver#420
1 parent 2e00b5c commit d707776

File tree

6 files changed

+175
-39
lines changed

6 files changed

+175
-39
lines changed

AUTHORS

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ Hanno Braun <mail at hannobraun.com>
2525
Henri Yandell <flamefew at gmail.com>
2626
Hirotaka Yamamoto <ymmt2005 at gmail.com>
2727
INADA Naoki <songofacandy at gmail.com>
28+
Jacek Szwec <[email protected]>
2829
James Harr <james.harr at gmail.com>
2930
Jian Zhen <zhenjl at gmail.com>
3031
Joshua Prunier <joshua.prunier at gmail.com>

connection.go

+4-7
Original file line numberDiff line numberDiff line change
@@ -331,15 +331,13 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
331331
var resLen int
332332
resLen, err = mc.readResultSetHeaderPacket()
333333
if err == nil {
334-
rows := new(textRows)
335-
rows.mc = mc
336-
334+
rows := newTextRows(mc)
337335
if resLen == 0 {
338336
// no columns, no more data
339337
return emptyRows{}, nil
340338
}
341339
// Columns
342-
rows.columns, err = mc.readColumns(resLen)
340+
rows.rs.columns, err = mc.readColumns(resLen)
343341
return rows, err
344342
}
345343
}
@@ -357,9 +355,8 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
357355
// Read Result
358356
resLen, err := mc.readResultSetHeaderPacket()
359357
if err == nil {
360-
rows := new(textRows)
361-
rows.mc = mc
362-
rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
358+
rows := newTextRows(mc)
359+
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
363360

364361
if resLen > 0 {
365362
// Columns

driver_go18_test.go

+90
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
// +build go1.8
2+
3+
package mysql
4+
5+
import (
6+
"reflect"
7+
"testing"
8+
)
9+
10+
func TestMultiResultSet(t *testing.T) {
11+
runTests(t, dsn, func(dbt *DBTest) {
12+
type result struct {
13+
values [][]int
14+
columns []string
15+
}
16+
17+
expected := []result{
18+
{
19+
values: [][]int{{1, 2}, {3, 4}},
20+
columns: []string{"col1", "col2"},
21+
},
22+
{
23+
values: [][]int{{1, 2, 3}, {4, 5, 6}},
24+
columns: []string{"col1", "col2", "col3"},
25+
},
26+
}
27+
28+
query := `
29+
SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
30+
SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6`
31+
32+
rows := dbt.mustQuery(query)
33+
defer rows.Close()
34+
35+
var res1 result
36+
for rows.Next() {
37+
var res [2]int
38+
if err := rows.Scan(&res[0], &res[1]); err != nil {
39+
dbt.Fatal(err)
40+
}
41+
res1.values = append(res1.values, res[:])
42+
}
43+
if rows.Next() {
44+
dbt.Error("unexpected row")
45+
}
46+
cols, err := rows.Columns()
47+
if err != nil {
48+
dbt.Fatal(err)
49+
}
50+
res1.columns = cols
51+
if !reflect.DeepEqual(expected[0], res1) {
52+
dbt.Error("want =", expected[0], "got =", res1)
53+
}
54+
55+
if !rows.NextResultSet() {
56+
dbt.Fatal("expected next result set")
57+
}
58+
59+
var res2 result
60+
cols, err = rows.Columns()
61+
if err != nil {
62+
dbt.Fatal(err)
63+
}
64+
res2.columns = cols
65+
66+
for rows.Next() {
67+
var res [3]int
68+
if err := rows.Scan(&res[0], &res[1], &res[2]); err != nil {
69+
dbt.Fatal(err)
70+
}
71+
res2.values = append(res2.values, res[:])
72+
}
73+
74+
if !reflect.DeepEqual(expected[1], res2) {
75+
dbt.Error("want =", expected[1], "got =", res2)
76+
}
77+
78+
if rows.Next() {
79+
dbt.Error("unexpected row")
80+
}
81+
82+
if rows.NextResultSet() {
83+
dbt.Error("unexpected next result set")
84+
}
85+
86+
if err := rows.Err(); err != nil {
87+
dbt.Error(err)
88+
}
89+
})
90+
}

packets.go

+22-21
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
231231
clientTransactions |
232232
clientLocalFiles |
233233
clientPluginAuth |
234+
clientMultiStatements |
234235
clientMultiResults |
235236
mc.flags&clientLongFlag
236237

@@ -698,6 +699,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
698699
func (rows *textRows) readRow(dest []driver.Value) error {
699700
mc := rows.mc
700701

702+
if rows.rs.done {
703+
return io.EOF
704+
}
705+
701706
data, err := mc.readPacket()
702707
if err != nil {
703708
return err
@@ -707,15 +712,11 @@ func (rows *textRows) readRow(dest []driver.Value) error {
707712
if data[0] == iEOF && len(data) == 5 {
708713
// server_status [2 bytes]
709714
rows.mc.status = readStatus(data[3:])
710-
err = rows.mc.discardResults()
711-
if err == nil {
712-
err = io.EOF
713-
} else {
714-
// connection unusable
715-
rows.mc.Close()
715+
rows.rs.done = true
716+
if !rows.HasNextResultSet() {
717+
rows.mc = nil
716718
}
717-
rows.mc = nil
718-
return err
719+
return io.EOF
719720
}
720721
if data[0] == iERR {
721722
rows.mc = nil
@@ -736,7 +737,7 @@ func (rows *textRows) readRow(dest []driver.Value) error {
736737
if !mc.parseTime {
737738
continue
738739
} else {
739-
switch rows.columns[i].fieldType {
740+
switch rows.rs.columns[i].fieldType {
740741
case fieldTypeTimestamp, fieldTypeDateTime,
741742
fieldTypeDate, fieldTypeNewDate:
742743
dest[i], err = parseDateTime(
@@ -1145,14 +1146,14 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
11451146
}
11461147

11471148
// Convert to byte-coded string
1148-
switch rows.columns[i].fieldType {
1149+
switch rows.rs.columns[i].fieldType {
11491150
case fieldTypeNULL:
11501151
dest[i] = nil
11511152
continue
11521153

11531154
// Numeric Types
11541155
case fieldTypeTiny:
1155-
if rows.columns[i].flags&flagUnsigned != 0 {
1156+
if rows.rs.columns[i].flags&flagUnsigned != 0 {
11561157
dest[i] = int64(data[pos])
11571158
} else {
11581159
dest[i] = int64(int8(data[pos]))
@@ -1161,7 +1162,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
11611162
continue
11621163

11631164
case fieldTypeShort, fieldTypeYear:
1164-
if rows.columns[i].flags&flagUnsigned != 0 {
1165+
if rows.rs.columns[i].flags&flagUnsigned != 0 {
11651166
dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
11661167
} else {
11671168
dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
@@ -1170,7 +1171,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
11701171
continue
11711172

11721173
case fieldTypeInt24, fieldTypeLong:
1173-
if rows.columns[i].flags&flagUnsigned != 0 {
1174+
if rows.rs.columns[i].flags&flagUnsigned != 0 {
11741175
dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
11751176
} else {
11761177
dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
@@ -1179,7 +1180,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
11791180
continue
11801181

11811182
case fieldTypeLongLong:
1182-
if rows.columns[i].flags&flagUnsigned != 0 {
1183+
if rows.rs.columns[i].flags&flagUnsigned != 0 {
11831184
val := binary.LittleEndian.Uint64(data[pos : pos+8])
11841185
if val > math.MaxInt64 {
11851186
dest[i] = uint64ToString(val)
@@ -1233,37 +1234,37 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
12331234
case isNull:
12341235
dest[i] = nil
12351236
continue
1236-
case rows.columns[i].fieldType == fieldTypeTime:
1237+
case rows.rs.columns[i].fieldType == fieldTypeTime:
12371238
// database/sql does not support an equivalent to TIME, return a string
12381239
var dstlen uint8
1239-
switch decimals := rows.columns[i].decimals; decimals {
1240+
switch decimals := rows.rs.columns[i].decimals; decimals {
12401241
case 0x00, 0x1f:
12411242
dstlen = 8
12421243
case 1, 2, 3, 4, 5, 6:
12431244
dstlen = 8 + 1 + decimals
12441245
default:
12451246
return fmt.Errorf(
12461247
"protocol error, illegal decimals value %d",
1247-
rows.columns[i].decimals,
1248+
rows.rs.columns[i].decimals,
12481249
)
12491250
}
12501251
dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true)
12511252
case rows.mc.parseTime:
12521253
dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
12531254
default:
12541255
var dstlen uint8
1255-
if rows.columns[i].fieldType == fieldTypeDate {
1256+
if rows.rs.columns[i].fieldType == fieldTypeDate {
12561257
dstlen = 10
12571258
} else {
1258-
switch decimals := rows.columns[i].decimals; decimals {
1259+
switch decimals := rows.rs.columns[i].decimals; decimals {
12591260
case 0x00, 0x1f:
12601261
dstlen = 19
12611262
case 1, 2, 3, 4, 5, 6:
12621263
dstlen = 19 + 1 + decimals
12631264
default:
12641265
return fmt.Errorf(
12651266
"protocol error, illegal decimals value %d",
1266-
rows.columns[i].decimals,
1267+
rows.rs.columns[i].decimals,
12671268
)
12681269
}
12691270
}
@@ -1279,7 +1280,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
12791280

12801281
// Please report if this happens!
12811282
default:
1282-
return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType)
1283+
return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType)
12831284
}
12841285
}
12851286

rows.go

+52-7
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,14 @@ type mysqlField struct {
2121
decimals byte
2222
}
2323

24-
type mysqlRows struct {
25-
mc *mysqlConn
24+
type resultSet struct {
2625
columns []mysqlField
26+
done bool
27+
}
28+
29+
type mysqlRows struct {
30+
mc *mysqlConn
31+
rs *resultSet
2732
}
2833

2934
type binaryRows struct {
@@ -34,21 +39,33 @@ type textRows struct {
3439
mysqlRows
3540
}
3641

42+
func newTextRows(mc *mysqlConn) *textRows {
43+
return &textRows{
44+
mysqlRows{
45+
mc: mc,
46+
rs: new(resultSet),
47+
},
48+
}
49+
}
50+
3751
type emptyRows struct{}
3852

3953
func (rows *mysqlRows) Columns() []string {
40-
columns := make([]string, len(rows.columns))
54+
if rows.rs == nil {
55+
return []string{}
56+
}
57+
columns := make([]string, len(rows.rs.columns))
4158
if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias {
4259
for i := range columns {
43-
if tableName := rows.columns[i].tableName; len(tableName) > 0 {
44-
columns[i] = tableName + "." + rows.columns[i].name
60+
if tableName := rows.rs.columns[i].tableName; len(tableName) > 0 {
61+
columns[i] = tableName + "." + rows.rs.columns[i].name
4562
} else {
46-
columns[i] = rows.columns[i].name
63+
columns[i] = rows.rs.columns[i].name
4764
}
4865
}
4966
} else {
5067
for i := range columns {
51-
columns[i] = rows.columns[i].name
68+
columns[i] = rows.rs.columns[i].name
5269
}
5370
}
5471
return columns
@@ -99,6 +116,26 @@ func (rows *textRows) Next(dest []driver.Value) error {
99116
return io.EOF
100117
}
101118

119+
func (rows *textRows) HasNextResultSet() (b bool) {
120+
if rows.mc == nil {
121+
return false
122+
}
123+
return rows.mc.status&statusMoreResultsExists != 0
124+
}
125+
126+
func (rows *textRows) NextResultSet() error {
127+
if !rows.HasNextResultSet() {
128+
return io.EOF
129+
}
130+
rows.rs = new(resultSet)
131+
resLen, err := rows.mc.readResultSetHeaderPacket()
132+
if err != nil {
133+
return err
134+
}
135+
rows.rs.columns, err = rows.mc.readColumns(resLen)
136+
return err
137+
}
138+
102139
func (rows emptyRows) Columns() []string {
103140
return nil
104141
}
@@ -110,3 +147,11 @@ func (rows emptyRows) Close() error {
110147
func (rows emptyRows) Next(dest []driver.Value) error {
111148
return io.EOF
112149
}
150+
151+
func (rows emptyRows) HasNextResultSet() bool {
152+
return false
153+
}
154+
155+
func (rows emptyRows) NextResultSet() error {
156+
return io.EOF
157+
}

statement.go

+6-4
Original file line numberDiff line numberDiff line change
@@ -103,17 +103,19 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
103103
return nil, err
104104
}
105105

106-
rows := new(binaryRows)
106+
rows := &binaryRows{
107+
mysqlRows{rs: new(resultSet)},
108+
}
107109

108110
if resLen > 0 {
109111
rows.mc = mc
110112
// Columns
111113
// If not cached, read them and cache them
112114
if stmt.columns == nil {
113-
rows.columns, err = mc.readColumns(resLen)
114-
stmt.columns = rows.columns
115+
rows.rs.columns, err = mc.readColumns(resLen)
116+
stmt.columns = rows.rs.columns
115117
} else {
116-
rows.columns = stmt.columns
118+
rows.rs.columns = stmt.columns
117119
err = mc.readUntilEOF()
118120
}
119121
}

0 commit comments

Comments
 (0)