Skip to content

Commit fd197cd

Browse files
tjenkinsonmethane
authored andcommitted
Return ErrBadConn for temporary Dial error (#867)
When `Dial()` returned error and it's `Timeout() == true`, return ErrBadConn to database/sql retry new connection.
1 parent 64cea2f commit fd197cd

File tree

3 files changed

+54
-0
lines changed

3 files changed

+54
-0
lines changed

AUTHORS

+1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ Soroush Pour <me at soroushjp.com>
7474
Stan Putrya <root.vagner at gmail.com>
7575
Stanley Gunawan <gunawan.stanley at gmail.com>
7676
Thomas Wodarek <wodarekwebpage at gmail.com>
77+
Tom Jenkinson <tom at tjenkinson.me>
7778
Xiangyu Hu <xiangyu.hu at outlook.com>
7879
Xiaobing Jiang <s7v7nislands at gmail.com>
7980
Xiuming Chen <cc at cxm.cc>

driver.go

+4
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
7777
mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr)
7878
}
7979
if err != nil {
80+
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
81+
errLog.Print("net.Error from Dial()': ", nerr.Error())
82+
return nil, driver.ErrBadConn
83+
}
8084
return nil, err
8185
}
8286

driver_test.go

+49
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,23 @@ type DBTest struct {
8585
db *sql.DB
8686
}
8787

88+
type netErrorMock struct {
89+
temporary bool
90+
timeout bool
91+
}
92+
93+
func (e netErrorMock) Temporary() bool {
94+
return e.temporary
95+
}
96+
97+
func (e netErrorMock) Timeout() bool {
98+
return e.timeout
99+
}
100+
101+
func (e netErrorMock) Error() string {
102+
return fmt.Sprintf("mock net error. Temporary: %v, Timeout %v", e.temporary, e.timeout)
103+
}
104+
88105
func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
89106
if !available {
90107
t.Skipf("MySQL server not running on %s", netAddr)
@@ -1801,6 +1818,38 @@ func TestConcurrent(t *testing.T) {
18011818
})
18021819
}
18031820

1821+
func testDialError(t *testing.T, dialErr error, expectErr error) {
1822+
RegisterDial("mydial", func(addr string) (net.Conn, error) {
1823+
return nil, dialErr
1824+
})
1825+
1826+
db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
1827+
if err != nil {
1828+
t.Fatalf("error connecting: %s", err.Error())
1829+
}
1830+
defer db.Close()
1831+
1832+
_, err = db.Exec("DO 1")
1833+
if err != expectErr {
1834+
t.Fatalf("was expecting %s. Got: %s", dialErr, err)
1835+
}
1836+
}
1837+
1838+
func TestDialUnknownError(t *testing.T) {
1839+
testErr := fmt.Errorf("test")
1840+
testDialError(t, testErr, testErr)
1841+
}
1842+
1843+
func TestDialNonRetryableNetErr(t *testing.T) {
1844+
testErr := netErrorMock{}
1845+
testDialError(t, testErr, testErr)
1846+
}
1847+
1848+
func TestDialTemporaryNetErr(t *testing.T) {
1849+
testErr := netErrorMock{temporary: true}
1850+
testDialError(t, testErr, driver.ErrBadConn)
1851+
}
1852+
18041853
// Tests custom dial functions
18051854
func TestCustomDial(t *testing.T) {
18061855
if !available {

0 commit comments

Comments
 (0)