Skip to content
This repository was archived by the owner on Jan 21, 2022. It is now read-only.

Commit 1984511

Browse files
committed
Merge remote-tracking branch 'origin/pr/631'
2 parents bf7f34f + a03abe1 commit 1984511

File tree

4 files changed

+102
-3
lines changed

4 files changed

+102
-3
lines changed

AUTHORS

+3
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Asta Xie <xiemengjun at gmail.com>
1818
Bulat Gaifullin <gaifullinbf at gmail.com>
1919
Carlos Nieto <jose.carlos at menteslibres.net>
2020
Chris Moos <chris at tech9computers.com>
21+
Chris Piraino <piraino.chris at gmail.com>
2122
Daniel Nichter <nil at codenode.com>
2223
Daniël van Eeden <git at myname.nl>
2324
Dave Protasowski <dprotaso at gmail.com>
@@ -34,6 +35,7 @@ INADA Naoki <songofacandy at gmail.com>
3435
Jacek Szwec <szwec.jacek at gmail.com>
3536
James Harr <james.harr at gmail.com>
3637
Jian Zhen <zhenjl at gmail.com>
38+
John Shahid <jvshahid at gmail.com>
3739
Joshua Prunier <joshua.prunier at gmail.com>
3840
Julien Lefevre <julien.lefevr at gmail.com>
3941
Julien Schmidt <go-sql-driver at julienschmidt.com>
@@ -58,6 +60,7 @@ Runrioter Wung <runrioter at gmail.com>
5860
Soroush Pour <me at soroushjp.com>
5961
Stan Putrya <root.vagner at gmail.com>
6062
Stanley Gunawan <gunawan.stanley at gmail.com>
63+
Thomas Parrott <tomp at tomp.uk>
6164
Xiangyu Hu <xiangyu.hu at outlook.com>
6265
Xiaobing Jiang <s7v7nislands at gmail.com>
6366
Xiuming Chen <cc at cxm.cc>

driver_test.go

+82
Original file line numberDiff line numberDiff line change
@@ -1773,6 +1773,88 @@ func TestCustomDial(t *testing.T) {
17731773
}
17741774
}
17751775

1776+
// Ensure mariadb's ER_CONNECTION_KILLED will cause the query to be restarted
1777+
func TestConnectionLost(t *testing.T) {
1778+
if !available {
1779+
t.Skipf("MySQL server not running on %s", netAddr)
1780+
}
1781+
1782+
var proxyConn net.Conn
1783+
1784+
killCh := make(chan struct{})
1785+
1786+
// our custom dial function which justs wraps net.Dial here
1787+
RegisterDial("mydial", func(addr string) (net.Conn, error) {
1788+
conn, err := net.Dial(prot, addr)
1789+
if err != nil {
1790+
return nil, err
1791+
}
1792+
1793+
var clientConn net.Conn
1794+
proxyConn, clientConn = net.Pipe()
1795+
go io.Copy(conn, proxyConn)
1796+
1797+
bytesCh := make(chan []byte)
1798+
go func() {
1799+
for {
1800+
bs := make([]byte, 1024)
1801+
n, err := conn.Read(bs)
1802+
if err == io.EOF {
1803+
return
1804+
}
1805+
if err != nil {
1806+
panic(err)
1807+
}
1808+
bytesCh <- bs[:n]
1809+
}
1810+
}()
1811+
go func() {
1812+
for {
1813+
select {
1814+
case bs := <-bytesCh:
1815+
_, err := proxyConn.Write(bs)
1816+
if err == io.ErrClosedPipe {
1817+
return
1818+
}
1819+
if err != nil {
1820+
panic(err)
1821+
}
1822+
case <-killCh:
1823+
go func() {
1824+
proxyConn.Write([]byte{
1825+
0x08, // packet size
1826+
0x00,
1827+
0x00,
1828+
0x00, // sequence 0
1829+
0xFF, // err_packet
1830+
0x87, // ER_CONNECTION_KILLED error
1831+
0x07,
1832+
0x00, // sql_state_marker
1833+
})
1834+
}()
1835+
}
1836+
}
1837+
}()
1838+
return clientConn, err
1839+
})
1840+
1841+
db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s&strict=true", user, pass, addr, dbname))
1842+
if err != nil {
1843+
t.Fatalf("error connecting: %s", err.Error())
1844+
}
1845+
defer db.Close()
1846+
1847+
if _, err = db.Exec("DO 1"); err != nil {
1848+
t.Fatalf("connection failed: %s", err.Error())
1849+
}
1850+
1851+
killCh <- struct{}{}
1852+
1853+
if _, err = db.Exec("DO 1"); err != nil {
1854+
t.Fatalf("connection failed: %s", err.Error())
1855+
}
1856+
}
1857+
17761858
func TestSQLInjection(t *testing.T) {
17771859
createTest := func(arg string) func(dbt *DBTest) {
17781860
return func(dbt *DBTest) {

packets.go

+15-1
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,15 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
4646
if data[3] > mc.sequence {
4747
return nil, ErrPktSyncMul
4848
}
49-
return nil, ErrPktSync
49+
50+
// The MariaDB server sends an error packet with sequence numer 0 during
51+
// server shutdown. Continue to process it so the specific error can be
52+
// detected.
53+
if data[3] != 0 {
54+
return nil, ErrPktSync
55+
}
5056
}
57+
5158
mc.sequence++
5259

5360
// packets with length 0 terminate a previous packet which is a
@@ -585,6 +592,13 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
585592
pos = 9
586593
}
587594

595+
// If error code is for ER_CONNECTION_KILLED, then mark connection as bad.
596+
// https://mariadb.com/kb/en/mariadb/mariadb-error-codes/
597+
if errno == 1927 {
598+
errLog.Print("Error ", errno, ": ", string(data[pos:]))
599+
return driver.ErrBadConn
600+
}
601+
588602
// Error Message [string]
589603
return &MySQLError{
590604
Number: errno,

packets_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,9 @@ func TestReadPacketWrongSequenceID(t *testing.T) {
115115
}
116116

117117
// too low sequence id
118-
conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
118+
conn.data = []byte{0x01, 0x00, 0x00, 0x01, 0xff}
119119
conn.maxReads = 1
120-
mc.sequence = 1
120+
mc.sequence = 2
121121
_, err := mc.readPacket()
122122
if err != ErrPktSync {
123123
t.Errorf("expected ErrPktSync, got %v", err)

0 commit comments

Comments
 (0)