Skip to content
This repository was archived by the owner on Jan 21, 2022. It is now read-only.

Commit 27dbdf1

Browse files
crhinojvshahid
authored andcommitted
Add a test to ensure queries are retried on mariadb's ER_CONNECTION_KILLED
Signed-off-by: John Shahid <[email protected]>
1 parent 8fefef0 commit 27dbdf1

File tree

2 files changed

+84
-0
lines changed

2 files changed

+84
-0
lines changed

AUTHORS

+2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Asta Xie <xiemengjun at gmail.com>
1818
Bulat Gaifullin <gaifullinbf at gmail.com>
1919
Carlos Nieto <jose.carlos at menteslibres.net>
2020
Chris Moos <chris at tech9computers.com>
21+
Chris Piraino <piraino.chris at gmail.com>
2122
Daniel Nichter <nil at codenode.com>
2223
Daniël van Eeden <git at myname.nl>
2324
Dave Protasowski <dprotaso at gmail.com>
@@ -34,6 +35,7 @@ INADA Naoki <songofacandy at gmail.com>
3435
Jacek Szwec <szwec.jacek at gmail.com>
3536
James Harr <james.harr at gmail.com>
3637
Jian Zhen <zhenjl at gmail.com>
38+
John Shahid <jvshahid at gmail.com>
3739
Joshua Prunier <joshua.prunier at gmail.com>
3840
Julien Lefevre <julien.lefevr at gmail.com>
3941
Julien Schmidt <go-sql-driver at julienschmidt.com>

driver_test.go

+82
Original file line numberDiff line numberDiff line change
@@ -1773,6 +1773,88 @@ func TestCustomDial(t *testing.T) {
17731773
}
17741774
}
17751775

1776+
// Ensure mariadb's ER_CONNECTION_KILLED will cause the query to be restarted
1777+
func TestConnectionLost(t *testing.T) {
1778+
if !available {
1779+
t.Skipf("MySQL server not running on %s", netAddr)
1780+
}
1781+
1782+
var proxyConn net.Conn
1783+
1784+
killCh := make(chan struct{})
1785+
1786+
// our custom dial function which justs wraps net.Dial here
1787+
RegisterDial("mydial", func(addr string) (net.Conn, error) {
1788+
conn, err := net.Dial(prot, addr)
1789+
if err != nil {
1790+
return nil, err
1791+
}
1792+
1793+
var clientConn net.Conn
1794+
proxyConn, clientConn = net.Pipe()
1795+
go io.Copy(conn, proxyConn)
1796+
1797+
bytesCh := make(chan []byte)
1798+
go func() {
1799+
for {
1800+
bs := make([]byte, 1024)
1801+
n, err := conn.Read(bs)
1802+
if err == io.EOF {
1803+
return
1804+
}
1805+
if err != nil {
1806+
panic(err)
1807+
}
1808+
bytesCh <- bs[:n]
1809+
}
1810+
}()
1811+
go func() {
1812+
for {
1813+
select {
1814+
case bs := <-bytesCh:
1815+
_, err := proxyConn.Write(bs)
1816+
if err == io.ErrClosedPipe {
1817+
return
1818+
}
1819+
if err != nil {
1820+
panic(err)
1821+
}
1822+
case <-killCh:
1823+
go func() {
1824+
proxyConn.Write([]byte{
1825+
0x08, // packet size
1826+
0x00,
1827+
0x00,
1828+
0x00, // sequence 0
1829+
0xFF, // err_packet
1830+
0x87, // ER_CONNECTION_KILLED error
1831+
0x07,
1832+
0x00, // sql_state_marker
1833+
})
1834+
}()
1835+
}
1836+
}
1837+
}()
1838+
return clientConn, err
1839+
})
1840+
1841+
db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s&strict=true", user, pass, addr, dbname))
1842+
if err != nil {
1843+
t.Fatalf("error connecting: %s", err.Error())
1844+
}
1845+
defer db.Close()
1846+
1847+
if _, err = db.Exec("DO 1"); err != nil {
1848+
t.Fatalf("connection failed: %s", err.Error())
1849+
}
1850+
1851+
killCh <- struct{}{}
1852+
1853+
if _, err = db.Exec("DO 1"); err != nil {
1854+
t.Fatalf("connection failed: %s", err.Error())
1855+
}
1856+
}
1857+
17761858
func TestSQLInjection(t *testing.T) {
17771859
createTest := func(arg string) func(dbt *DBTest) {
17781860
return func(dbt *DBTest) {

0 commit comments

Comments
 (0)