From c823aa0d52bc8a99856ea8a382264e3951e540da Mon Sep 17 00:00:00 2001 From: jszwec Date: Sun, 8 Jan 2017 20:46:50 -0500 Subject: [PATCH 1/4] Add Multi-Results support Fixes #420 --- AUTHORS | 1 + connection.go | 4 +- driver_go18_test.go | 100 ++++++++++++++++++++++++++++++++++++++++++++ packets.go | 44 +++++++++---------- rows.go | 69 ++++++++++++++++++++++++++---- statement.go | 6 +-- 6 files changed, 189 insertions(+), 35 deletions(-) create mode 100644 driver_go18_test.go diff --git a/AUTHORS b/AUTHORS index 100370758..f47d35d96 100644 --- a/AUTHORS +++ b/AUTHORS @@ -25,6 +25,7 @@ Hanno Braun Henri Yandell Hirotaka Yamamoto INADA Naoki +Jacek Szwec James Harr Jian Zhen Joshua Prunier diff --git a/connection.go b/connection.go index d82c728f3..c9d662daf 100644 --- a/connection.go +++ b/connection.go @@ -339,7 +339,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro return emptyRows{}, nil } // Columns - rows.columns, err = mc.readColumns(resLen) + rows.rs.columns, err = mc.readColumns(resLen) return rows, err } } @@ -359,7 +359,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { if err == nil { rows := new(textRows) rows.mc = mc - rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}} + rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}} if resLen > 0 { // Columns diff --git a/driver_go18_test.go b/driver_go18_test.go new file mode 100644 index 000000000..24ad51a55 --- /dev/null +++ b/driver_go18_test.go @@ -0,0 +1,100 @@ +// +build go1.8 + +package mysql + +import ( + "reflect" + "testing" +) + +func TestMultiResultSet(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + type result struct { + values [][]int + columns []string + } + + expected := []result{ + { + values: [][]int{{1, 2}, {3, 4}}, + columns: []string{"col1", "col2"}, + }, + { + values: [][]int{{1, 2, 3}, {4, 5, 6}}, + columns: []string{"col1", "col2", "col3"}, + }, + } + + query := ` +SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4; +SELECT 0 UNION SELECT 1; -- ignore this result set +SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;` + + rows := dbt.mustQuery(query) + defer rows.Close() + + var res1 result + for rows.Next() { + var res [2]int + if err := rows.Scan(&res[0], &res[1]); err != nil { + dbt.Fatal(err) + } + res1.values = append(res1.values, res[:]) + } + + if rows.Next() { + dbt.Error("unexpected row") + } + + cols, err := rows.Columns() + if err != nil { + dbt.Fatal(err) + } + res1.columns = cols + + if !reflect.DeepEqual(expected[0], res1) { + dbt.Error("want =", expected[0], "got =", res1) + } + + if !rows.NextResultSet() { + dbt.Fatal("expected next result set") + } + + // ignoring one result set + + if !rows.NextResultSet() { + dbt.Fatal("expected next result set") + } + + var res2 result + cols, err = rows.Columns() + if err != nil { + dbt.Fatal(err) + } + res2.columns = cols + + for rows.Next() { + var res [3]int + if err := rows.Scan(&res[0], &res[1], &res[2]); err != nil { + dbt.Fatal(err) + } + res2.values = append(res2.values, res[:]) + } + + if !reflect.DeepEqual(expected[1], res2) { + dbt.Error("want =", expected[1], "got =", res2) + } + + if rows.Next() { + dbt.Error("unexpected row") + } + + if rows.NextResultSet() { + dbt.Error("unexpected next result set") + } + + if err := rows.Err(); err != nil { + dbt.Error(err) + } + }) +} diff --git a/packets.go b/packets.go index aafe9793e..85289a2ef 100644 --- a/packets.go +++ b/packets.go @@ -231,6 +231,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { clientTransactions | clientLocalFiles | clientPluginAuth | + clientMultiStatements | clientMultiResults | mc.flags&clientLongFlag @@ -698,6 +699,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { func (rows *textRows) readRow(dest []driver.Value) error { mc := rows.mc + if rows.rs.done { + return io.EOF + } + data, err := mc.readPacket() if err != nil { return err @@ -707,15 +712,11 @@ func (rows *textRows) readRow(dest []driver.Value) error { if data[0] == iEOF && len(data) == 5 { // server_status [2 bytes] rows.mc.status = readStatus(data[3:]) - err = rows.mc.discardResults() - if err == nil { - err = io.EOF - } else { - // connection unusable - rows.mc.Close() + rows.rs.done = true + if !rows.HasNextResultSet() { + rows.mc = nil } - rows.mc = nil - return err + return io.EOF } if data[0] == iERR { rows.mc = nil @@ -736,7 +737,7 @@ func (rows *textRows) readRow(dest []driver.Value) error { if !mc.parseTime { continue } else { - switch rows.columns[i].fieldType { + switch rows.rs.columns[i].fieldType { case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeDate, fieldTypeNewDate: dest[i], err = parseDateTime( @@ -1116,6 +1117,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // EOF Packet if data[0] == iEOF && len(data) == 5 { rows.mc.status = readStatus(data[3:]) + rows.rs.done = true err = rows.mc.discardResults() if err == nil { err = io.EOF @@ -1145,14 +1147,14 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { } // Convert to byte-coded string - switch rows.columns[i].fieldType { + switch rows.rs.columns[i].fieldType { case fieldTypeNULL: dest[i] = nil continue // Numeric Types case fieldTypeTiny: - if rows.columns[i].flags&flagUnsigned != 0 { + if rows.rs.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(data[pos]) } else { dest[i] = int64(int8(data[pos])) @@ -1161,7 +1163,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeShort, fieldTypeYear: - if rows.columns[i].flags&flagUnsigned != 0 { + if rows.rs.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2])) } else { dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2]))) @@ -1170,7 +1172,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeInt24, fieldTypeLong: - if rows.columns[i].flags&flagUnsigned != 0 { + if rows.rs.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4])) } else { dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4]))) @@ -1179,7 +1181,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeLongLong: - if rows.columns[i].flags&flagUnsigned != 0 { + if rows.rs.columns[i].flags&flagUnsigned != 0 { val := binary.LittleEndian.Uint64(data[pos : pos+8]) if val > math.MaxInt64 { dest[i] = uint64ToString(val) @@ -1233,10 +1235,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { case isNull: dest[i] = nil continue - case rows.columns[i].fieldType == fieldTypeTime: + case rows.rs.columns[i].fieldType == fieldTypeTime: // database/sql does not support an equivalent to TIME, return a string var dstlen uint8 - switch decimals := rows.columns[i].decimals; decimals { + switch decimals := rows.rs.columns[i].decimals; decimals { case 0x00, 0x1f: dstlen = 8 case 1, 2, 3, 4, 5, 6: @@ -1244,7 +1246,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { default: return fmt.Errorf( "protocol error, illegal decimals value %d", - rows.columns[i].decimals, + rows.rs.columns[i].decimals, ) } dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true) @@ -1252,10 +1254,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc) default: var dstlen uint8 - if rows.columns[i].fieldType == fieldTypeDate { + if rows.rs.columns[i].fieldType == fieldTypeDate { dstlen = 10 } else { - switch decimals := rows.columns[i].decimals; decimals { + switch decimals := rows.rs.columns[i].decimals; decimals { case 0x00, 0x1f: dstlen = 19 case 1, 2, 3, 4, 5, 6: @@ -1263,7 +1265,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { default: return fmt.Errorf( "protocol error, illegal decimals value %d", - rows.columns[i].decimals, + rows.rs.columns[i].decimals, ) } } @@ -1279,7 +1281,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // Please report if this happens! default: - return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType) + return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType) } } diff --git a/rows.go b/rows.go index c08255eee..b230a7955 100644 --- a/rows.go +++ b/rows.go @@ -21,9 +21,14 @@ type mysqlField struct { decimals byte } -type mysqlRows struct { - mc *mysqlConn +type resultSet struct { columns []mysqlField + done bool +} + +type mysqlRows struct { + mc *mysqlConn + rs resultSet } type binaryRows struct { @@ -37,24 +42,24 @@ type textRows struct { type emptyRows struct{} func (rows *mysqlRows) Columns() []string { - columns := make([]string, len(rows.columns)) + columns := make([]string, len(rows.rs.columns)) if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias { for i := range columns { - if tableName := rows.columns[i].tableName; len(tableName) > 0 { - columns[i] = tableName + "." + rows.columns[i].name + if tableName := rows.rs.columns[i].tableName; len(tableName) > 0 { + columns[i] = tableName + "." + rows.rs.columns[i].name } else { - columns[i] = rows.columns[i].name + columns[i] = rows.rs.columns[i].name } } } else { for i := range columns { - columns[i] = rows.columns[i].name + columns[i] = rows.rs.columns[i].name } } return columns } -func (rows *mysqlRows) Close() error { +func (rows *mysqlRows) Close() (err error) { mc := rows.mc if mc == nil { return nil @@ -64,7 +69,9 @@ func (rows *mysqlRows) Close() error { } // Remove unread packets from stream - err := mc.readUntilEOF() + if !rows.rs.done { + err = mc.readUntilEOF() + } if err == nil { if err = mc.discardResults(); err != nil { return err @@ -99,6 +106,42 @@ func (rows *textRows) Next(dest []driver.Value) error { return io.EOF } +func (rows *textRows) HasNextResultSet() (b bool) { + if rows.mc == nil { + return false + } + return rows.mc.status&statusMoreResultsExists != 0 +} + +func (rows *textRows) NextResultSet() error { + if rows.mc == nil { + return io.EOF + } + if rows.mc.netConn == nil { + return ErrInvalidConn + } + + // Remove unread packets from stream + if !rows.rs.done { + if err := rows.mc.readUntilEOF(); err != nil { + return err + } + } + + if !rows.HasNextResultSet() { + return io.EOF + } + rows.rs = resultSet{} + + resLen, err := rows.mc.readResultSetHeaderPacket() + if err != nil { + return err + } + + rows.rs.columns, err = rows.mc.readColumns(resLen) + return err +} + func (rows emptyRows) Columns() []string { return nil } @@ -110,3 +153,11 @@ func (rows emptyRows) Close() error { func (rows emptyRows) Next(dest []driver.Value) error { return io.EOF } + +func (rows emptyRows) HasNextResultSet() bool { + return false +} + +func (rows emptyRows) NextResultSet() error { + return io.EOF +} diff --git a/statement.go b/statement.go index 7f9b04585..2e890107d 100644 --- a/statement.go +++ b/statement.go @@ -110,10 +110,10 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { // Columns // If not cached, read them and cache them if stmt.columns == nil { - rows.columns, err = mc.readColumns(resLen) - stmt.columns = rows.columns + rows.rs.columns, err = mc.readColumns(resLen) + stmt.columns = rows.rs.columns } else { - rows.columns = stmt.columns + rows.rs.columns = stmt.columns err = mc.readUntilEOF() } } From 0ac948351b843ef2f14f3f6d708aaa92980d1c34 Mon Sep 17 00:00:00 2001 From: jszwec Date: Sun, 15 Jan 2017 00:50:13 -0500 Subject: [PATCH 2/4] Multi-Results improvements - support for binary protocol - support statements returning no results - remove emptyRows --- connection.go | 30 ++++++++--- driver_go18_test.go | 121 ++++++++++++++++++++++++++++++++++---------- packets.go | 18 ++----- rows.go | 106 ++++++++++++++++++++++---------------- statement.go | 54 +++++++++++++------- 5 files changed, 219 insertions(+), 110 deletions(-) diff --git a/connection.go b/connection.go index c9d662daf..08e5fadeb 100644 --- a/connection.go +++ b/connection.go @@ -10,6 +10,7 @@ package mysql import ( "database/sql/driver" + "io" "net" "strconv" "strings" @@ -289,22 +290,29 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err // Internal function to execute commands func (mc *mysqlConn) exec(query string) error { // Send command - err := mc.writeCommandPacketStr(comQuery, query) - if err != nil { + if err := mc.writeCommandPacketStr(comQuery, query); err != nil { return err } // Read Result resLen, err := mc.readResultSetHeaderPacket() - if err == nil && resLen > 0 { - if err = mc.readUntilEOF(); err != nil { + if err != nil { + return err + } + + if resLen > 0 { + // columns + if err := mc.readUntilEOF(); err != nil { return err } - err = mc.readUntilEOF() + // rows + if err := mc.readUntilEOF(); err != nil { + return err + } } - return err + return mc.discardResults() } func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { @@ -335,8 +343,14 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro rows.mc = mc if resLen == 0 { - // no columns, no more data - return emptyRows{}, nil + rows.rs.done = true + + switch err := rows.NextResultSet(); err { + case nil, io.EOF: + return rows, nil + default: + return nil, err + } } // Columns rows.rs.columns, err = mc.readColumns(resLen) diff --git a/driver_go18_test.go b/driver_go18_test.go index 24ad51a55..c836c91da 100644 --- a/driver_go18_test.go +++ b/driver_go18_test.go @@ -3,17 +3,28 @@ package mysql import ( + "database/sql" + "fmt" "reflect" "testing" ) func TestMultiResultSet(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - type result struct { - values [][]int - columns []string - } - + type result struct { + values [][]int + columns []string + } + + // checkRows is a helper test function to validate rows containing 3 result + // sets with specific values and columns. The basic query would look like this: + // + // SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4; + // SELECT 0 UNION SELECT 1; + // SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6; + // + // to distinguish test cases the first string argument is put in front of + // every error or fatal message. + checkRows := func(desc string, rows *sql.Rows, dbt *DBTest) { expected := []result{ { values: [][]int{{1, 2}, {3, 4}}, @@ -25,14 +36,6 @@ func TestMultiResultSet(t *testing.T) { }, } - query := ` -SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4; -SELECT 0 UNION SELECT 1; -- ignore this result set -SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;` - - rows := dbt.mustQuery(query) - defer rows.Close() - var res1 result for rows.Next() { var res [2]int @@ -42,49 +45,115 @@ SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;` res1.values = append(res1.values, res[:]) } - if rows.Next() { - dbt.Error("unexpected row") - } - cols, err := rows.Columns() if err != nil { - dbt.Fatal(err) + dbt.Fatal(desc, err) } res1.columns = cols if !reflect.DeepEqual(expected[0], res1) { - dbt.Error("want =", expected[0], "got =", res1) + dbt.Error(desc, "want =", expected[0], "got =", res1) } if !rows.NextResultSet() { - dbt.Fatal("expected next result set") + dbt.Fatal(desc, "expected next result set") } // ignoring one result set if !rows.NextResultSet() { - dbt.Fatal("expected next result set") + dbt.Fatal(desc, "expected next result set") } var res2 result cols, err = rows.Columns() if err != nil { - dbt.Fatal(err) + dbt.Fatal(desc, err) } res2.columns = cols for rows.Next() { var res [3]int if err := rows.Scan(&res[0], &res[1], &res[2]); err != nil { - dbt.Fatal(err) + dbt.Fatal(desc, err) } res2.values = append(res2.values, res[:]) } if !reflect.DeepEqual(expected[1], res2) { - dbt.Error("want =", expected[1], "got =", res2) + dbt.Error(desc, "want =", expected[1], "got =", res2) + } + + if rows.NextResultSet() { + dbt.Error(desc, "unexpected next result set") + } + + if err := rows.Err(); err != nil { + dbt.Error(desc, err) + } + } + + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + rows := dbt.mustQuery(`DO 1; + SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4; + DO 1; + SELECT 0 UNION SELECT 1; + SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;`) + defer rows.Close() + checkRows("query: ", rows, dbt) + }) + + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + queries := []string{ + ` + DROP PROCEDURE IF EXISTS test_mrss; + CREATE PROCEDURE test_mrss() + BEGIN + DO 1; + SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4; + DO 1; + SELECT 0 UNION SELECT 1; + SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6; + END + `, + ` + DROP PROCEDURE IF EXISTS test_mrss; + CREATE PROCEDURE test_mrss() + BEGIN + SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4; + SELECT 0 UNION SELECT 1; + SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6; + END + `, } + defer dbt.mustExec("DROP PROCEDURE IF EXISTS test_mrss") + + for i, query := range queries { + dbt.mustExec(query) + + stmt, err := dbt.db.Prepare("CALL test_mrss()") + if err != nil { + dbt.Fatalf("%v (i=%d)", err, i) + } + defer stmt.Close() + + for j := 0; j < 2; j++ { + rows, err := stmt.Query() + if err != nil { + dbt.Fatalf("%v (i=%d) (j=%d)", err, i, j) + } + checkRows(fmt.Sprintf("prepared stmt query (i=%d) (j=%d): ", i, j), rows, dbt) + } + } + }) +} + +func TestMultiResultSetNoSelect(t *testing.T) { + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + rows := dbt.mustQuery("DO 1; DO 2;") + defer rows.Close() + if rows.Next() { dbt.Error("unexpected row") } @@ -94,7 +163,7 @@ SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;` } if err := rows.Err(); err != nil { - dbt.Error(err) + dbt.Error("expected nil; got ", err) } }) } diff --git a/packets.go b/packets.go index 85289a2ef..41b4d3d55 100644 --- a/packets.go +++ b/packets.go @@ -231,7 +231,6 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { clientTransactions | clientLocalFiles | clientPluginAuth | - clientMultiStatements | clientMultiResults | mc.flags&clientLongFlag @@ -585,8 +584,8 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { // server_status [2 bytes] mc.status = readStatus(data[1+n+m : 1+n+m+2]) - if err := mc.discardResults(); err != nil { - return err + if mc.status&statusMoreResultsExists != 0 { + return nil } // warning count [2 bytes] @@ -1098,8 +1097,6 @@ func (mc *mysqlConn) discardResults() error { if err := mc.readUntilEOF(); err != nil { return err } - } else { - mc.status &^= statusMoreResultsExists } } return nil @@ -1118,15 +1115,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { if data[0] == iEOF && len(data) == 5 { rows.mc.status = readStatus(data[3:]) rows.rs.done = true - err = rows.mc.discardResults() - if err == nil { - err = io.EOF - } else { - // connection unusable - rows.mc.Close() + if !rows.HasNextResultSet() { + rows.mc = nil } - rows.mc = nil - return err + return io.EOF } rows.mc = nil diff --git a/rows.go b/rows.go index b230a7955..b6403f211 100644 --- a/rows.go +++ b/rows.go @@ -33,14 +33,18 @@ type mysqlRows struct { type binaryRows struct { mysqlRows + // stmtCols is a pointer to the statement's cached columns for different + // result sets. + stmtCols *[][]mysqlField + // i is a number of the current result set. It is used to fetch proper + // columns from stmtCols. + i int } type textRows struct { mysqlRows } -type emptyRows struct{} - func (rows *mysqlRows) Columns() []string { columns := make([]string, len(rows.rs.columns)) if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias { @@ -82,82 +86,96 @@ func (rows *mysqlRows) Close() (err error) { return err } -func (rows *binaryRows) Next(dest []driver.Value) error { - if mc := rows.mc; mc != nil { - if mc.netConn == nil { - return ErrInvalidConn - } - - // Fetch next row from stream - return rows.readRow(dest) - } - return io.EOF -} - -func (rows *textRows) Next(dest []driver.Value) error { - if mc := rows.mc; mc != nil { - if mc.netConn == nil { - return ErrInvalidConn - } - - // Fetch next row from stream - return rows.readRow(dest) - } - return io.EOF -} - -func (rows *textRows) HasNextResultSet() (b bool) { +func (rows *mysqlRows) HasNextResultSet() (b bool) { if rows.mc == nil { return false } return rows.mc.status&statusMoreResultsExists != 0 } -func (rows *textRows) NextResultSet() error { +func (rows *mysqlRows) nextResultSet() (int, error) { if rows.mc == nil { - return io.EOF + return 0, io.EOF } if rows.mc.netConn == nil { - return ErrInvalidConn + return 0, ErrInvalidConn } // Remove unread packets from stream if !rows.rs.done { if err := rows.mc.readUntilEOF(); err != nil { - return err + return 0, err } } if !rows.HasNextResultSet() { - return io.EOF + return 0, io.EOF } rows.rs = resultSet{} + return rows.mc.readResultSetHeaderPacket() +} - resLen, err := rows.mc.readResultSetHeaderPacket() +func (rows *binaryRows) NextResultSet() error { + resLen, err := rows.nextResultSet() if err != nil { return err } - rows.rs.columns, err = rows.mc.readColumns(resLen) - return err -} + if resLen == 0 { + rows.rs.done = true + return rows.NextResultSet() + } -func (rows emptyRows) Columns() []string { - return nil -} + // get columns, if not cached, read them and cache them. + if rows.i >= len(*rows.stmtCols) { + rows.rs.columns, err = rows.mc.readColumns(resLen) + *rows.stmtCols = append(*rows.stmtCols, rows.rs.columns) + } else { + rows.rs.columns = (*rows.stmtCols)[rows.i] + if err := rows.mc.readUntilEOF(); err != nil { + return err + } + } -func (rows emptyRows) Close() error { + rows.i++ return nil } -func (rows emptyRows) Next(dest []driver.Value) error { +func (rows *binaryRows) Next(dest []driver.Value) error { + if mc := rows.mc; mc != nil { + if mc.netConn == nil { + return ErrInvalidConn + } + + // Fetch next row from stream + return rows.readRow(dest) + } return io.EOF } -func (rows emptyRows) HasNextResultSet() bool { - return false +func (rows *textRows) NextResultSet() error { + resLen, err := rows.nextResultSet() + if err != nil { + return err + } + + if resLen == 0 { + rows.rs.done = true + return rows.NextResultSet() + } + + rows.rs.columns, err = rows.mc.readColumns(resLen) + return err } -func (rows emptyRows) NextResultSet() error { +func (rows *textRows) Next(dest []driver.Value) error { + if mc := rows.mc; mc != nil { + if mc.netConn == nil { + return ErrInvalidConn + } + + // Fetch next row from stream + return rows.readRow(dest) + } return io.EOF } diff --git a/statement.go b/statement.go index 2e890107d..b88771674 100644 --- a/statement.go +++ b/statement.go @@ -11,6 +11,7 @@ package mysql import ( "database/sql/driver" "fmt" + "io" "reflect" "strconv" ) @@ -19,7 +20,7 @@ type mysqlStmt struct { mc *mysqlConn id uint32 paramCount int - columns []mysqlField // cached from the first query + columns [][]mysqlField // cached from the first query } func (stmt *mysqlStmt) Close() error { @@ -62,26 +63,30 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { // Read Result resLen, err := mc.readResultSetHeaderPacket() - if err == nil { - if resLen > 0 { - // Columns - err = mc.readUntilEOF() - if err != nil { - return nil, err - } + if err != nil { + return nil, err + } - // Rows - err = mc.readUntilEOF() + if resLen > 0 { + // Columns + if err = mc.readUntilEOF(); err != nil { + return nil, err } - if err == nil { - return &mysqlResult{ - affectedRows: int64(mc.affectedRows), - insertId: int64(mc.insertId), - }, nil + + // Rows + if err := mc.readUntilEOF(); err != nil { + return nil, err } } - return nil, err + if err := mc.discardResults(); err != nil { + return nil, err + } + + return &mysqlResult{ + affectedRows: int64(mc.affectedRows), + insertId: int64(mc.insertId), + }, nil } func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { @@ -104,18 +109,29 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { } rows := new(binaryRows) + rows.stmtCols = &stmt.columns if resLen > 0 { rows.mc = mc + rows.i++ // Columns // If not cached, read them and cache them - if stmt.columns == nil { + if len(stmt.columns) == 0 { rows.rs.columns, err = mc.readColumns(resLen) - stmt.columns = rows.rs.columns + stmt.columns = append(stmt.columns, rows.rs.columns) } else { - rows.rs.columns = stmt.columns + rows.rs.columns = stmt.columns[0] err = mc.readUntilEOF() } + } else { + rows.rs.done = true + + switch err := rows.NextResultSet(); err { + case nil, io.EOF: + return rows, nil + default: + return nil, err + } } return rows, err From 8b062e70ddf7da8877cdcee424cf6c88ff72c307 Mon Sep 17 00:00:00 2001 From: jszwec Date: Sun, 15 Jan 2017 11:50:06 -0500 Subject: [PATCH 3/4] Multi-Results: fix hanging rows.Close() * rows.Close() would hang on readUntilEOF if some results were ignored before calling NextResultSet() --- driver_go18_test.go | 21 +++++++++++++++++++++ rows.go | 2 ++ 2 files changed, 23 insertions(+) diff --git a/driver_go18_test.go b/driver_go18_test.go index c836c91da..93918ad46 100644 --- a/driver_go18_test.go +++ b/driver_go18_test.go @@ -167,3 +167,24 @@ func TestMultiResultSetNoSelect(t *testing.T) { } }) } + +// tests if rows are set in a proper state if some results were ignored before +// calling rows.NextResultSet. +func TestSkipResults(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + rows := dbt.mustQuery("SELECT 1, 2") + defer rows.Close() + + if !rows.Next() { + dbt.Error("expected row") + } + + if rows.NextResultSet() { + dbt.Error("unexpected next result set") + } + + if err := rows.Err(); err != nil { + dbt.Error("expected nil; got ", err) + } + }) +} diff --git a/rows.go b/rows.go index b6403f211..b37516891 100644 --- a/rows.go +++ b/rows.go @@ -106,9 +106,11 @@ func (rows *mysqlRows) nextResultSet() (int, error) { if err := rows.mc.readUntilEOF(); err != nil { return 0, err } + rows.rs.done = true } if !rows.HasNextResultSet() { + rows.mc = nil return 0, io.EOF } rows.rs = resultSet{} From 0e119f8774fa708b646edaa3220353a096b83e81 Mon Sep 17 00:00:00 2001 From: jszwec Date: Sun, 15 Jan 2017 13:41:56 -0500 Subject: [PATCH 4/4] Multi-Results: review suggestion - get rid of a recursive call --- rows.go | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/rows.go b/rows.go index b37516891..900f548ae 100644 --- a/rows.go +++ b/rows.go @@ -117,15 +117,25 @@ func (rows *mysqlRows) nextResultSet() (int, error) { return rows.mc.readResultSetHeaderPacket() } -func (rows *binaryRows) NextResultSet() error { - resLen, err := rows.nextResultSet() - if err != nil { - return err - } +func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) { + for { + resLen, err := rows.nextResultSet() + if err != nil { + return 0, err + } + + if resLen > 0 { + return resLen, nil + } - if resLen == 0 { rows.rs.done = true - return rows.NextResultSet() + } +} + +func (rows *binaryRows) NextResultSet() (err error) { + resLen, err := rows.nextNotEmptyResultSet() + if err != nil { + return err } // get columns, if not cached, read them and cache them. @@ -155,17 +165,12 @@ func (rows *binaryRows) Next(dest []driver.Value) error { return io.EOF } -func (rows *textRows) NextResultSet() error { - resLen, err := rows.nextResultSet() +func (rows *textRows) NextResultSet() (err error) { + resLen, err := rows.nextNotEmptyResultSet() if err != nil { return err } - if resLen == 0 { - rows.rs.done = true - return rows.NextResultSet() - } - rows.rs.columns, err = rows.mc.readColumns(resLen) return err }