Skip to content

Commit 0c50208

Browse files
committed
rows: Invalidate connection on error in discardResults()
Fixes #422
1 parent 0926834 commit 0c50208

File tree

2 files changed

+61
-6
lines changed

2 files changed

+61
-6
lines changed

driver_test.go

+47
Original file line numberDiff line numberDiff line change
@@ -1855,3 +1855,50 @@ func TestUnixSocketAuthFail(t *testing.T) {
18551855
}
18561856
})
18571857
}
1858+
1859+
// See Issue #422
1860+
func TestInterruptBySignal(t *testing.T) {
1861+
runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
1862+
dbt.mustExec(`
1863+
DROP PROCEDURE IF EXISTS test_signal;
1864+
CREATE PROCEDURE test_signal(ret INT)
1865+
BEGIN
1866+
SELECT ret;
1867+
SIGNAL SQLSTATE
1868+
'45001'
1869+
SET
1870+
MESSAGE_TEXT = "an error",
1871+
MYSQL_ERRNO = 45001;
1872+
END
1873+
`)
1874+
defer dbt.mustExec("DROP PROCEDURE test_signal")
1875+
1876+
var val int
1877+
1878+
// text protocol
1879+
rows, err := dbt.db.Query("CALL test_signal(42)")
1880+
if err != nil {
1881+
dbt.Fatalf("error on text query: %s", err.Error())
1882+
}
1883+
for rows.Next() {
1884+
if err := rows.Scan(&val); err != nil {
1885+
dbt.Error(err)
1886+
} else if val != 42 {
1887+
dbt.Errorf("expected val to be 42")
1888+
}
1889+
}
1890+
1891+
// binary protocol
1892+
rows, err = dbt.db.Query("CALL test_signal(?)", 42)
1893+
if err != nil {
1894+
dbt.Fatalf("error on binary query: %s", err.Error())
1895+
}
1896+
for rows.Next() {
1897+
if err := rows.Scan(&val); err != nil {
1898+
dbt.Error(err)
1899+
} else if val != 42 {
1900+
dbt.Errorf("expected val to be 42")
1901+
}
1902+
}
1903+
})
1904+
}

packets.go

+14-6
Original file line numberDiff line numberDiff line change
@@ -700,11 +700,15 @@ func (rows *textRows) readRow(dest []driver.Value) error {
700700
if data[0] == iEOF && len(data) == 5 {
701701
// server_status [2 bytes]
702702
rows.mc.status = readStatus(data[3:])
703-
if err := rows.mc.discardResults(); err != nil {
704-
return err
703+
err = rows.mc.discardResults()
704+
if err == nil {
705+
err = io.EOF
706+
} else {
707+
// connection unusable
708+
rows.mc.Close()
705709
}
706710
rows.mc = nil
707-
return io.EOF
711+
return err
708712
}
709713
if data[0] == iERR {
710714
rows.mc = nil
@@ -1105,11 +1109,15 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
11051109
// EOF Packet
11061110
if data[0] == iEOF && len(data) == 5 {
11071111
rows.mc.status = readStatus(data[3:])
1108-
if err := rows.mc.discardResults(); err != nil {
1109-
return err
1112+
err = rows.mc.discardResults()
1113+
if err == nil {
1114+
err = io.EOF
1115+
} else {
1116+
// connection unusable
1117+
rows.mc.Close()
11101118
}
11111119
rows.mc = nil
1112-
return io.EOF
1120+
return err
11131121
}
11141122
rows.mc = nil
11151123

0 commit comments

Comments
 (0)