Skip to content

Commit 1df26e2

Browse files
author
安佳玮
committed
in function MySQLDriver.Open: replace errBadConnNoWrite with driver.ErrBadConn for resend while 'bad connection' happenning
1 parent 361f66e commit 1df26e2

File tree

3 files changed

+24
-2
lines changed

3 files changed

+24
-2
lines changed

driver.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
123123
}
124124
if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil {
125125
mc.cleanup()
126-
return nil, err
126+
return nil, mc.markBadConn(err)
127127
}
128128

129129
// Handle response to auth packet, switch methods if possible

driver_test.go

+16
Original file line numberDiff line numberDiff line change
@@ -2860,3 +2860,19 @@ func TestValuerWithValueReceiverGivenNilValue(t *testing.T) {
28602860
// This test will panic on the INSERT if ConvertValue() does not check for typed nil before calling Value()
28612861
})
28622862
}
2863+
2864+
func TestWriteHandshakeResponseErr(t *testing.T) {
2865+
oldWritter := connWritter
2866+
defer func() {
2867+
connWritter = oldWritter
2868+
}()
2869+
connWritter = func(conn net.Conn, data []byte) (int, error) {
2870+
return 0, fmt.Errorf("network error")
2871+
}
2872+
2873+
md := MySQLDriver{}
2874+
_, err := md.Open(dsn)
2875+
if err != driver.ErrBadConn {
2876+
t.Fatalf("error is not driver.ErrBadConn: %v", err)
2877+
}
2878+
}

packets.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,15 @@ import (
1717
"fmt"
1818
"io"
1919
"math"
20+
"net"
2021
"time"
2122
)
2223

24+
// connWritter write data with net.Conn, for test mocking
25+
var connWritter = func(conn net.Conn, data []byte) (int, error) {
26+
return conn.Write(data)
27+
}
28+
2329
// Packets documentation:
2430
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html
2531

@@ -118,7 +124,7 @@ func (mc *mysqlConn) writePacket(data []byte) error {
118124
}
119125
}
120126

121-
n, err := mc.netConn.Write(data[:4+size])
127+
n, err := connWritter(mc.netConn, data[:4+size])
122128
if err == nil && n == 4+size {
123129
mc.sequence++
124130
if size != maxPacketSize {

0 commit comments

Comments
 (0)