@@ -1773,6 +1773,88 @@ func TestCustomDial(t *testing.T) {
1773
1773
}
1774
1774
}
1775
1775
1776
+ // Ensure mariadb's ER_CONNECTION_KILLED will cause the query to be restarted
1777
+ func TestConnectionLost (t * testing.T ) {
1778
+ if ! available {
1779
+ t .Skipf ("MySQL server not running on %s" , netAddr )
1780
+ }
1781
+
1782
+ var proxyConn net.Conn
1783
+
1784
+ killCh := make (chan struct {})
1785
+
1786
+ // our custom dial function which justs wraps net.Dial here
1787
+ RegisterDial ("mydial" , func (addr string ) (net.Conn , error ) {
1788
+ conn , err := net .Dial (prot , addr )
1789
+ if err != nil {
1790
+ return nil , err
1791
+ }
1792
+
1793
+ var clientConn net.Conn
1794
+ proxyConn , clientConn = net .Pipe ()
1795
+ go io .Copy (conn , proxyConn )
1796
+
1797
+ bytesCh := make (chan []byte )
1798
+ go func () {
1799
+ for {
1800
+ bs := make ([]byte , 1024 )
1801
+ n , err := conn .Read (bs )
1802
+ if err == io .EOF {
1803
+ return
1804
+ }
1805
+ if err != nil {
1806
+ panic (err )
1807
+ }
1808
+ bytesCh <- bs [:n ]
1809
+ }
1810
+ }()
1811
+ go func () {
1812
+ for {
1813
+ select {
1814
+ case bs := <- bytesCh :
1815
+ _ , err := proxyConn .Write (bs )
1816
+ if err == io .ErrClosedPipe {
1817
+ return
1818
+ }
1819
+ if err != nil {
1820
+ panic (err )
1821
+ }
1822
+ case <- killCh :
1823
+ go func () {
1824
+ proxyConn .Write ([]byte {
1825
+ 0x08 , // packet size
1826
+ 0x00 ,
1827
+ 0x00 ,
1828
+ 0x00 , // sequence 0
1829
+ 0xFF , // err_packet
1830
+ 0x87 , // ER_CONNECTION_KILLED error
1831
+ 0x07 ,
1832
+ 0x00 , // sql_state_marker
1833
+ })
1834
+ }()
1835
+ }
1836
+ }
1837
+ }()
1838
+ return clientConn , err
1839
+ })
1840
+
1841
+ db , err := sql .Open ("mysql" , fmt .Sprintf ("%s:%s@mydial(%s)/%s?timeout=30s&strict=true" , user , pass , addr , dbname ))
1842
+ if err != nil {
1843
+ t .Fatalf ("error connecting: %s" , err .Error ())
1844
+ }
1845
+ defer db .Close ()
1846
+
1847
+ if _ , err = db .Exec ("DO 1" ); err != nil {
1848
+ t .Fatalf ("connection failed: %s" , err .Error ())
1849
+ }
1850
+
1851
+ killCh <- struct {}{}
1852
+
1853
+ if _ , err = db .Exec ("DO 1" ); err != nil {
1854
+ t .Fatalf ("connection failed: %s" , err .Error ())
1855
+ }
1856
+ }
1857
+
1776
1858
func TestSQLInjection (t * testing.T ) {
1777
1859
createTest := func (arg string ) func (dbt * DBTest ) {
1778
1860
return func (dbt * DBTest ) {
0 commit comments