diff --git a/client/conn.go b/client/conn.go index 7f422f4df..19318ddd8 100644 --- a/client/conn.go +++ b/client/conn.go @@ -1,6 +1,7 @@ package client import ( + "context" "crypto/tls" "fmt" "net" @@ -55,10 +56,23 @@ func getNetProto(addr string) string { func Connect(addr string, user string, password string, dbName string, options ...func(*Conn)) (*Conn, error) { proto := getNetProto(addr) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + dialer := &net.Dialer{} + + return ConnectWithDialer(ctx, proto, addr, user, password, dbName, dialer.DialContext, options...) +} + +// Dialer connects to the address on the named network using the provided context. +type Dialer func(ctx context.Context, network, address string) (net.Conn, error) + +// Connect to a MySQL server using the given Dialer. +func ConnectWithDialer(ctx context.Context, network string, addr string, user string, password string, dbName string, dialer Dialer, options ...func(*Conn)) (*Conn, error) { c := new(Conn) var err error - conn, err := net.DialTimeout(proto, addr, 10*time.Second) + conn, err := dialer(ctx, network, addr) if err != nil { return nil, errors.Trace(err) } @@ -72,9 +86,9 @@ func Connect(addr string, user string, password string, dbName string, options . c.user = user c.password = password c.db = dbName - c.proto = proto + c.proto = network - //use default charset here, utf-8 + // use default charset here, utf-8 c.charset = DEFAULT_CHARSET // Apply configuration functions.