Skip to content

Add Multi-Results support #537

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 23, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Hanno Braun <mail at hannobraun.com>
Henri Yandell <flamefew at gmail.com>
Hirotaka Yamamoto <ymmt2005 at gmail.com>
INADA Naoki <songofacandy at gmail.com>
Jacek Szwec <szwec.jacek at gmail.com>
James Harr <james.harr at gmail.com>
Jian Zhen <zhenjl at gmail.com>
Joshua Prunier <joshua.prunier at gmail.com>
Expand Down
34 changes: 24 additions & 10 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package mysql

import (
"database/sql/driver"
"io"
"net"
"strconv"
"strings"
Expand Down Expand Up @@ -289,22 +290,29 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
// Internal function to execute commands
func (mc *mysqlConn) exec(query string) error {
// Send command
err := mc.writeCommandPacketStr(comQuery, query)
if err != nil {
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
return err
}

// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err == nil && resLen > 0 {
if err = mc.readUntilEOF(); err != nil {
if err != nil {
return err
}

if resLen > 0 {
// columns
if err := mc.readUntilEOF(); err != nil {
return err
}

err = mc.readUntilEOF()
// rows
if err := mc.readUntilEOF(); err != nil {
return err
}
}

return err
return mc.discardResults()
}

func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
Expand Down Expand Up @@ -335,11 +343,17 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
rows.mc = mc

if resLen == 0 {
// no columns, no more data
return emptyRows{}, nil
rows.rs.done = true

switch err := rows.NextResultSet(); err {
case nil, io.EOF:
return rows, nil
default:
return nil, err
}
}
// Columns
rows.columns, err = mc.readColumns(resLen)
rows.rs.columns, err = mc.readColumns(resLen)
return rows, err
}
}
Expand All @@ -359,7 +373,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
if err == nil {
rows := new(textRows)
rows.mc = mc
rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}

if resLen > 0 {
// Columns
Expand Down
190 changes: 190 additions & 0 deletions driver_go18_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
// +build go1.8

package mysql

import (
"database/sql"
"fmt"
"reflect"
"testing"
)

func TestMultiResultSet(t *testing.T) {
type result struct {
values [][]int
columns []string
}

// checkRows is a helper test function to validate rows containing 3 result
// sets with specific values and columns. The basic query would look like this:
//
// SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
// SELECT 0 UNION SELECT 1;
// SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;
//
// to distinguish test cases the first string argument is put in front of
// every error or fatal message.
checkRows := func(desc string, rows *sql.Rows, dbt *DBTest) {
expected := []result{
{
values: [][]int{{1, 2}, {3, 4}},
columns: []string{"col1", "col2"},
},
{
values: [][]int{{1, 2, 3}, {4, 5, 6}},
columns: []string{"col1", "col2", "col3"},
},
}

var res1 result
for rows.Next() {
var res [2]int
if err := rows.Scan(&res[0], &res[1]); err != nil {
dbt.Fatal(err)
}
res1.values = append(res1.values, res[:])
}

cols, err := rows.Columns()
if err != nil {
dbt.Fatal(desc, err)
}
res1.columns = cols

if !reflect.DeepEqual(expected[0], res1) {
dbt.Error(desc, "want =", expected[0], "got =", res1)
}

if !rows.NextResultSet() {
dbt.Fatal(desc, "expected next result set")
}

// ignoring one result set

if !rows.NextResultSet() {
dbt.Fatal(desc, "expected next result set")
}

var res2 result
cols, err = rows.Columns()
if err != nil {
dbt.Fatal(desc, err)
}
res2.columns = cols

for rows.Next() {
var res [3]int
if err := rows.Scan(&res[0], &res[1], &res[2]); err != nil {
dbt.Fatal(desc, err)
}
res2.values = append(res2.values, res[:])
}

if !reflect.DeepEqual(expected[1], res2) {
dbt.Error(desc, "want =", expected[1], "got =", res2)
}

if rows.NextResultSet() {
dbt.Error(desc, "unexpected next result set")
}

if err := rows.Err(); err != nil {
dbt.Error(desc, err)
}
}

runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
rows := dbt.mustQuery(`DO 1;
SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
DO 1;
SELECT 0 UNION SELECT 1;
SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;`)
defer rows.Close()
checkRows("query: ", rows, dbt)
})

runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
queries := []string{
`
DROP PROCEDURE IF EXISTS test_mrss;
CREATE PROCEDURE test_mrss()
BEGIN
DO 1;
SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
DO 1;
SELECT 0 UNION SELECT 1;
SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;
END
`,
`
DROP PROCEDURE IF EXISTS test_mrss;
CREATE PROCEDURE test_mrss()
BEGIN
SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
SELECT 0 UNION SELECT 1;
SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;
END
`,
}

defer dbt.mustExec("DROP PROCEDURE IF EXISTS test_mrss")

for i, query := range queries {
dbt.mustExec(query)

stmt, err := dbt.db.Prepare("CALL test_mrss()")
if err != nil {
dbt.Fatalf("%v (i=%d)", err, i)
}
defer stmt.Close()

for j := 0; j < 2; j++ {
rows, err := stmt.Query()
if err != nil {
dbt.Fatalf("%v (i=%d) (j=%d)", err, i, j)
}
checkRows(fmt.Sprintf("prepared stmt query (i=%d) (j=%d): ", i, j), rows, dbt)
}
}
})
}

func TestMultiResultSetNoSelect(t *testing.T) {
runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
rows := dbt.mustQuery("DO 1; DO 2;")
defer rows.Close()

if rows.Next() {
dbt.Error("unexpected row")
}

if rows.NextResultSet() {
dbt.Error("unexpected next result set")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why next result set is unexpected?
What happens for DO 1; DO 2; DO 3; SELECT 42?
I expect

if !rows.NextResultSet() { dbt.Error("expect 4 resultset, got 1") }
if !rows.NextResultSet() { dbt.Error("expect 4 resultset, got 2") }
if !rows.NextResultSet() { dbt.Error("expect 4 resultset, got 3") }
var v int64
if err := rows.Scan(&v); err != nil {
    dbt.Error(err)
}
if v != 42 {
   dbt.Error("expected 42; got ", v)
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I followed the example from the std:
https://tip.golang.org/pkg/database/sql/#example_DB_Query_multipleResultSets

which seems to ignore the empty results sets.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I got it.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So what would happen if I have SELECT 1 FROM DUAL WHERE 1 = 0; SELECT 1,2 FROM DUAL WHERE 1 = 1;? That is the first query returns no rows, but the second query does.

I would expect

if !rows.NextResultSet() { dbt.Error("expect 2 resultset, got 1") }

That is, I might not know which query has no resultset. So, I have to go through them all to map processing them correctly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@joegrasse, sorry - "empty results sets" was inaccurate. Result set is present if column count > 0.
So, in the example that you provided it would work as you expected - it would return two result sets, first would be empty, the other one would have one row.

}

if err := rows.Err(); err != nil {
dbt.Error("expected nil; got ", err)
}
})
}

// tests if rows are set in a proper state if some results were ignored before
// calling rows.NextResultSet.
func TestSkipResults(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
rows := dbt.mustQuery("SELECT 1, 2")
defer rows.Close()

if !rows.Next() {
dbt.Error("expected row")
}

if rows.NextResultSet() {
dbt.Error("unexpected next result set")
}

if err := rows.Err(); err != nil {
dbt.Error("expected nil; got ", err)
}
})
}
Loading