Skip to content

Commit b432cc0

Browse files
committed
TestMultiQuery PASS
discard additional OK response after Multi Statement Exec Calls
1 parent 2ebd1ef commit b432cc0

File tree

3 files changed

+114
-4
lines changed

3 files changed

+114
-4
lines changed

AUTHORS

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ Runrioter Wung <runrioter at gmail.com>
3434
Soroush Pour <me at soroushjp.com>
3535
Xiaobing Jiang <s7v7nislands at gmail.com>
3636
Xiuming Chen <cc at cxm.cc>
37+
Stanley Gunawan <gunawan.stanley at gmail.com>
3738

3839
# Organizations
3940

driver_test.go

+84-3
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,10 @@ func init() {
5555
}
5656
return defaultValue
5757
}
58-
user = env("MYSQL_TEST_USER", "root")
59-
pass = env("MYSQL_TEST_PASS", "")
58+
user = env("MYSQL_TEST_USER", "galaxy")
59+
pass = env("MYSQL_TEST_PASS", "nPyOfX80gHQ2R7gFzo2t")
6060
prot = env("MYSQL_TEST_PROT", "tcp")
61-
addr = env("MYSQL_TEST_ADDR", "localhost:3306")
61+
addr = env("MYSQL_TEST_ADDR", "107.167.186.193:3306")
6262
dbname = env("MYSQL_TEST_DBNAME", "gotest")
6363
netAddr = fmt.Sprintf("%s(%s)", prot, addr)
6464
dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&strict=true", user, pass, netAddr, dbname)
@@ -74,6 +74,28 @@ type DBTest struct {
7474
db *sql.DB
7575
}
7676

77+
func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
78+
if !available {
79+
t.Skipf("MySQL-Server not running on %s", netAddr)
80+
}
81+
82+
dsn3 := dsn + "&multiStatements=true"
83+
var db3 *sql.DB
84+
if _, err := parseDSN(dsn3); err != errInvalidDSNUnsafeCollation {
85+
db3, err = sql.Open("mysql", dsn3)
86+
if err != nil {
87+
t.Fatalf("Error connecting: %s", err.Error())
88+
}
89+
defer db3.Close()
90+
}
91+
92+
dbt3 := &DBTest{t, db3}
93+
for _, test := range tests {
94+
test(dbt3)
95+
dbt3.db.Exec("DROP TABLE IF EXISTS test")
96+
}
97+
}
98+
7799
func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
78100
if !available {
79101
t.Skipf("MySQL-Server not running on %s", netAddr)
@@ -97,15 +119,30 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
97119
defer db2.Close()
98120
}
99121

122+
dsn3 := dsn + "&multiStatements=true"
123+
var db3 *sql.DB
124+
if _, err := parseDSN(dsn3); err != errInvalidDSNUnsafeCollation {
125+
db3, err = sql.Open("mysql", dsn3)
126+
if err != nil {
127+
t.Fatalf("Error connecting: %s", err.Error())
128+
}
129+
defer db3.Close()
130+
}
131+
100132
dbt := &DBTest{t, db}
101133
dbt2 := &DBTest{t, db2}
134+
dbt3 := &DBTest{t, db3}
102135
for _, test := range tests {
103136
test(dbt)
104137
dbt.db.Exec("DROP TABLE IF EXISTS test")
105138
if db2 != nil {
106139
test(dbt2)
107140
dbt2.db.Exec("DROP TABLE IF EXISTS test")
108141
}
142+
if db3 != nil {
143+
test(dbt3)
144+
dbt3.db.Exec("DROP TABLE IF EXISTS test")
145+
}
109146
}
110147
}
111148

@@ -235,6 +272,50 @@ func TestCRUD(t *testing.T) {
235272
})
236273
}
237274

275+
func TestMultiQuery(t *testing.T) {
276+
runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
277+
// Create Table
278+
dbt.mustExec("CREATE TABLE `test` (`id` int(11) NOT NULL, `value` int(11) NOT NULL) ")
279+
280+
// Create Data
281+
res := dbt.mustExec("INSERT INTO test VALUES (1, 1)")
282+
count, err := res.RowsAffected()
283+
if err != nil {
284+
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
285+
}
286+
if count != 1 {
287+
dbt.Fatalf("Expected 1 affected row, got %d", count)
288+
}
289+
290+
// Update
291+
res = dbt.mustExec("UPDATE test SET value = 3 WHERE id = 1; UPDATE test SET value = 4 WHERE id = 1;")
292+
count, err = res.RowsAffected()
293+
if err != nil {
294+
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
295+
}
296+
if count != 1 {
297+
dbt.Fatalf("Expected 1 affected row, got %d", count)
298+
}
299+
300+
// Read
301+
var out int
302+
rows := dbt.mustQuery("SELECT value FROM test WHERE id=1;")
303+
if rows.Next() {
304+
rows.Scan(&out)
305+
if 4 != out {
306+
dbt.Errorf("4 != %t", out)
307+
}
308+
309+
if rows.Next() {
310+
dbt.Error("unexpected data")
311+
}
312+
} else {
313+
dbt.Error("no data")
314+
}
315+
316+
})
317+
}
318+
238319
func TestInt(t *testing.T) {
239320
runTests(t, dsn, func(dbt *DBTest) {
240321
types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"}

packets.go

+29-1
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,32 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
474474
}
475475
}
476476

477+
func readStatus(b []byte) statusFlag {
478+
return statusFlag(b[0]) | statusFlag(b[1])<<8
479+
}
480+
481+
func (mc *mysqlConn) discardMoreResultsIfExists() error {
482+
for mc.status&statusMoreResultsExists != 0 {
483+
resLen, err := mc.readResultSetHeaderPacket()
484+
if err != nil {
485+
return err
486+
}
487+
if resLen > 0 {
488+
// columns
489+
if err := mc.readUntilEOF(); err != nil {
490+
return err
491+
}
492+
// rows
493+
if err := mc.readUntilEOF(); err != nil {
494+
return err
495+
}
496+
} else {
497+
mc.status &^= statusMoreResultsExists
498+
}
499+
}
500+
return nil
501+
}
502+
477503
// Ok Packet
478504
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
479505
func (mc *mysqlConn) handleOkPacket(data []byte) error {
@@ -488,7 +514,9 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
488514
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
489515

490516
// server_status [2 bytes]
491-
mc.status = statusFlag(data[1+n+m]) | statusFlag(data[1+n+m+1])<<8
517+
// mc.status = statusFlag(data[1+n+m]) | statusFlag(data[1+n+m+1])<<8
518+
mc.status = readStatus(data[1+n+m : 1+n+m+2])
519+
mc.discardMoreResultsIfExists()
492520

493521
// warning count [2 bytes]
494522
if !mc.strict {

0 commit comments

Comments
 (0)