diff --git a/AUTHORS b/AUTHORS index fbe4ec442..1c37f1ee8 100644 --- a/AUTHORS +++ b/AUTHORS @@ -73,6 +73,7 @@ Soroush Pour Stan Putrya Stanley Gunawan Thomas Wodarek +Tom Jenkinson Xiangyu Hu Xiaobing Jiang Xiuming Chen diff --git a/driver.go b/driver.go index ba1297825..eeb83df01 100644 --- a/driver.go +++ b/driver.go @@ -77,6 +77,10 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr) } if err != nil { + if nerr, ok := err.(net.Error); ok && nerr.Temporary() { + errLog.Print("net.Error from Dial()': ", nerr.Error()) + return nil, driver.ErrBadConn + } return nil, err } diff --git a/driver_test.go b/driver_test.go index f2bf344e5..cec4b5867 100644 --- a/driver_test.go +++ b/driver_test.go @@ -85,6 +85,23 @@ type DBTest struct { db *sql.DB } +type netErrorMock struct { + temporary bool + timeout bool +} + +func (e netErrorMock) Temporary() bool { + return e.temporary +} + +func (e netErrorMock) Timeout() bool { + return e.timeout +} + +func (e netErrorMock) Error() string { + return fmt.Sprintf("mock net error. Temporary: %v, Timeout %v", e.temporary, e.timeout) +} + func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { if !available { t.Skipf("MySQL server not running on %s", netAddr) @@ -1801,6 +1818,38 @@ func TestConcurrent(t *testing.T) { }) } +func testDialError(t *testing.T, dialErr error, expectErr error) { + RegisterDial("mydial", func(addr string) (net.Conn, error) { + return nil, dialErr + }) + + db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + _, err = db.Exec("DO 1") + if err != expectErr { + t.Fatalf("was expecting %s. Got: %s", dialErr, err) + } +} + +func TestDialUnknownError(t *testing.T) { + testErr := fmt.Errorf("test") + testDialError(t, testErr, testErr) +} + +func TestDialNonRetryableNetErr(t *testing.T) { + testErr := netErrorMock{} + testDialError(t, testErr, testErr) +} + +func TestDialTemporaryNetErr(t *testing.T) { + testErr := netErrorMock{temporary: true} + testDialError(t, testErr, driver.ErrBadConn) +} + // Tests custom dial functions func TestCustomDial(t *testing.T) { if !available {