@@ -85,6 +85,23 @@ type DBTest struct {
85
85
db * sql.DB
86
86
}
87
87
88
+ type netErrorMock struct {
89
+ temporary bool
90
+ timeout bool
91
+ }
92
+
93
+ func (e netErrorMock ) Temporary () bool {
94
+ return e .temporary
95
+ }
96
+
97
+ func (e netErrorMock ) Timeout () bool {
98
+ return e .timeout
99
+ }
100
+
101
+ func (e netErrorMock ) Error () string {
102
+ return fmt .Sprintf ("mock net error. Temporary: %v, Timeout %v" , e .temporary , e .timeout )
103
+ }
104
+
88
105
func runTestsWithMultiStatement (t * testing.T , dsn string , tests ... func (dbt * DBTest )) {
89
106
if ! available {
90
107
t .Skipf ("MySQL server not running on %s" , netAddr )
@@ -1801,6 +1818,38 @@ func TestConcurrent(t *testing.T) {
1801
1818
})
1802
1819
}
1803
1820
1821
+ func testDialError (t * testing.T , dialErr error , expectErr error ) {
1822
+ RegisterDial ("mydial" , func (addr string ) (net.Conn , error ) {
1823
+ return nil , dialErr
1824
+ })
1825
+
1826
+ db , err := sql .Open ("mysql" , fmt .Sprintf ("%s:%s@mydial(%s)/%s?timeout=30s" , user , pass , addr , dbname ))
1827
+ if err != nil {
1828
+ t .Fatalf ("error connecting: %s" , err .Error ())
1829
+ }
1830
+ defer db .Close ()
1831
+
1832
+ _ , err = db .Exec ("DO 1" )
1833
+ if err != expectErr {
1834
+ t .Fatalf ("was expecting %s. Got: %s" , dialErr , err )
1835
+ }
1836
+ }
1837
+
1838
+ func TestDialUnknownError (t * testing.T ) {
1839
+ testErr := fmt .Errorf ("test" )
1840
+ testDialError (t , testErr , testErr )
1841
+ }
1842
+
1843
+ func TestDialNonRetryableNetErr (t * testing.T ) {
1844
+ testErr := netErrorMock {}
1845
+ testDialError (t , testErr , testErr )
1846
+ }
1847
+
1848
+ func TestDialTemporaryNetErr (t * testing.T ) {
1849
+ testErr := netErrorMock {temporary : true }
1850
+ testDialError (t , testErr , driver .ErrBadConn )
1851
+ }
1852
+
1804
1853
// Tests custom dial functions
1805
1854
func TestCustomDial (t * testing.T ) {
1806
1855
if ! available {
0 commit comments