diff --git a/driver.go b/driver.go index ba1297825..a881894aa 100644 --- a/driver.go +++ b/driver.go @@ -123,7 +123,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { } if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil { mc.cleanup() - return nil, err + return nil, mc.markBadConn(err) } // Handle response to auth packet, switch methods if possible diff --git a/driver_test.go b/driver_test.go index f2bf344e5..3bc27f653 100644 --- a/driver_test.go +++ b/driver_test.go @@ -2860,3 +2860,19 @@ func TestValuerWithValueReceiverGivenNilValue(t *testing.T) { // This test will panic on the INSERT if ConvertValue() does not check for typed nil before calling Value() }) } + +func TestWriteHandshakeResponseErr(t *testing.T) { + oldWritter := connWritter + defer func() { + connWritter = oldWritter + }() + connWritter = func(conn net.Conn, data []byte) (int, error) { + return 0, fmt.Errorf("network error") + } + + md := MySQLDriver{} + _, err := md.Open(dsn) + if err != driver.ErrBadConn { + t.Fatalf("error is not driver.ErrBadConn: %v", err) + } +} diff --git a/packets.go b/packets.go index 170aaa02b..276b1845c 100644 --- a/packets.go +++ b/packets.go @@ -17,9 +17,15 @@ import ( "fmt" "io" "math" + "net" "time" ) +// connWritter write data with net.Conn, for test mocking +var connWritter = func(conn net.Conn, data []byte) (int, error) { + return conn.Write(data) +} + // Packets documentation: // http://dev.mysql.com/doc/internals/en/client-server-protocol.html @@ -118,7 +124,7 @@ func (mc *mysqlConn) writePacket(data []byte) error { } } - n, err := mc.netConn.Write(data[:4+size]) + n, err := connWritter(mc.netConn, data[:4+size]) if err == nil && n == 4+size { mc.sequence++ if size != maxPacketSize {