Skip to content

Commit 89ec2a9

Browse files
Vicent Martímethane
Vicent Martí
authored andcommitted
Support Go 1.10 Connector interface (#941)
1 parent df597a2 commit 89ec2a9

8 files changed

+415
-115
lines changed

appengine.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,14 @@
1111
package mysql
1212

1313
import (
14+
"context"
15+
1416
"google.golang.org/appengine/cloudsql"
1517
)
1618

1719
func init() {
18-
RegisterDial("cloudsql", cloudsql.Dial)
20+
RegisterDialContext("cloudsql", func(_ context.Context, instance addr) (net.Conn, error) {
21+
// XXX: the cloudsql driver still does not export a Context-aware dialer.
22+
return cloudsql.Dial(instance)
23+
})
1924
}

connector.go

+143
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
2+
//
3+
// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
4+
//
5+
// This Source Code Form is subject to the terms of the Mozilla Public
6+
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
7+
// You can obtain one at http://mozilla.org/MPL/2.0/.
8+
9+
package mysql
10+
11+
import (
12+
"context"
13+
"database/sql/driver"
14+
"net"
15+
)
16+
17+
type connector struct {
18+
cfg *Config // immutable private copy.
19+
}
20+
21+
// Connect implements driver.Connector interface.
22+
// Connect returns a connection to the database.
23+
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
24+
var err error
25+
26+
// New mysqlConn
27+
mc := &mysqlConn{
28+
maxAllowedPacket: maxPacketSize,
29+
maxWriteSize: maxPacketSize - 1,
30+
closech: make(chan struct{}),
31+
cfg: c.cfg,
32+
}
33+
mc.parseTime = mc.cfg.ParseTime
34+
35+
// Connect to Server
36+
dialsLock.RLock()
37+
dial, ok := dials[mc.cfg.Net]
38+
dialsLock.RUnlock()
39+
if ok {
40+
mc.netConn, err = dial(ctx, mc.cfg.Addr)
41+
} else {
42+
nd := net.Dialer{Timeout: mc.cfg.Timeout}
43+
mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)
44+
}
45+
46+
if err != nil {
47+
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
48+
errLog.Print("net.Error from Dial()': ", nerr.Error())
49+
return nil, driver.ErrBadConn
50+
}
51+
return nil, err
52+
}
53+
54+
// Enable TCP Keepalives on TCP connections
55+
if tc, ok := mc.netConn.(*net.TCPConn); ok {
56+
if err := tc.SetKeepAlive(true); err != nil {
57+
// Don't send COM_QUIT before handshake.
58+
mc.netConn.Close()
59+
mc.netConn = nil
60+
return nil, err
61+
}
62+
}
63+
64+
// Call startWatcher for context support (From Go 1.8)
65+
mc.startWatcher()
66+
if err := mc.watchCancel(ctx); err != nil {
67+
return nil, err
68+
}
69+
defer mc.finish()
70+
71+
mc.buf = newBuffer(mc.netConn)
72+
73+
// Set I/O timeouts
74+
mc.buf.timeout = mc.cfg.ReadTimeout
75+
mc.writeTimeout = mc.cfg.WriteTimeout
76+
77+
// Reading Handshake Initialization Packet
78+
authData, plugin, err := mc.readHandshakePacket()
79+
if err != nil {
80+
mc.cleanup()
81+
return nil, err
82+
}
83+
84+
if plugin == "" {
85+
plugin = defaultAuthPlugin
86+
}
87+
88+
// Send Client Authentication Packet
89+
authResp, err := mc.auth(authData, plugin)
90+
if err != nil {
91+
// try the default auth plugin, if using the requested plugin failed
92+
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
93+
plugin = defaultAuthPlugin
94+
authResp, err = mc.auth(authData, plugin)
95+
if err != nil {
96+
mc.cleanup()
97+
return nil, err
98+
}
99+
}
100+
if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
101+
mc.cleanup()
102+
return nil, err
103+
}
104+
105+
// Handle response to auth packet, switch methods if possible
106+
if err = mc.handleAuthResult(authData, plugin); err != nil {
107+
// Authentication failed and MySQL has already closed the connection
108+
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
109+
// Do not send COM_QUIT, just cleanup and return the error.
110+
mc.cleanup()
111+
return nil, err
112+
}
113+
114+
if mc.cfg.MaxAllowedPacket > 0 {
115+
mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
116+
} else {
117+
// Get max allowed packet size
118+
maxap, err := mc.getSystemVar("max_allowed_packet")
119+
if err != nil {
120+
mc.Close()
121+
return nil, err
122+
}
123+
mc.maxAllowedPacket = stringToInt(maxap) - 1
124+
}
125+
if mc.maxAllowedPacket < maxPacketSize {
126+
mc.maxWriteSize = mc.maxAllowedPacket
127+
}
128+
129+
// Handle DSN Params
130+
err = mc.handleParams()
131+
if err != nil {
132+
mc.Close()
133+
return nil, err
134+
}
135+
136+
return mc, nil
137+
}
138+
139+
// Driver implements driver.Connector interface.
140+
// Driver returns &MySQLDriver{}.
141+
func (c *connector) Driver() driver.Driver {
142+
return &MySQLDriver{}
143+
}

driver.go

+27-111
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package mysql
1818

1919
import (
20+
"context"
2021
"database/sql"
2122
"database/sql/driver"
2223
"net"
@@ -29,139 +30,54 @@ type MySQLDriver struct{}
2930

3031
// DialFunc is a function which can be used to establish the network connection.
3132
// Custom dial functions must be registered with RegisterDial
33+
//
34+
// Deprecated: users should register a DialContextFunc instead
3235
type DialFunc func(addr string) (net.Conn, error)
3336

37+
// DialContextFunc is a function which can be used to establish the network connection.
38+
// Custom dial functions must be registered with RegisterDialContext
39+
type DialContextFunc func(ctx context.Context, addr string) (net.Conn, error)
40+
3441
var (
3542
dialsLock sync.RWMutex
36-
dials map[string]DialFunc
43+
dials map[string]DialContextFunc
3744
)
3845

39-
// RegisterDial registers a custom dial function. It can then be used by the
46+
// RegisterDialContext registers a custom dial function. It can then be used by the
4047
// network address mynet(addr), where mynet is the registered new network.
41-
// addr is passed as a parameter to the dial function.
42-
func RegisterDial(net string, dial DialFunc) {
48+
// The current context for the connection and its address is passed to the dial function.
49+
func RegisterDialContext(net string, dial DialContextFunc) {
4350
dialsLock.Lock()
4451
defer dialsLock.Unlock()
4552
if dials == nil {
46-
dials = make(map[string]DialFunc)
53+
dials = make(map[string]DialContextFunc)
4754
}
4855
dials[net] = dial
4956
}
5057

58+
// RegisterDial registers a custom dial function. It can then be used by the
59+
// network address mynet(addr), where mynet is the registered new network.
60+
// addr is passed as a parameter to the dial function.
61+
//
62+
// Deprecated: users should call RegisterDialContext instead
63+
func RegisterDial(network string, dial DialFunc) {
64+
RegisterDialContext(network, func(_ context.Context, addr string) (net.Conn, error) {
65+
return dial(addr)
66+
})
67+
}
68+
5169
// Open new Connection.
5270
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
5371
// the DSN string is formatted
5472
func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
55-
var err error
56-
57-
// New mysqlConn
58-
mc := &mysqlConn{
59-
maxAllowedPacket: maxPacketSize,
60-
maxWriteSize: maxPacketSize - 1,
61-
closech: make(chan struct{}),
62-
}
63-
mc.cfg, err = ParseDSN(dsn)
64-
if err != nil {
65-
return nil, err
66-
}
67-
mc.parseTime = mc.cfg.ParseTime
68-
69-
// Connect to Server
70-
dialsLock.RLock()
71-
dial, ok := dials[mc.cfg.Net]
72-
dialsLock.RUnlock()
73-
if ok {
74-
mc.netConn, err = dial(mc.cfg.Addr)
75-
} else {
76-
nd := net.Dialer{Timeout: mc.cfg.Timeout}
77-
mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr)
78-
}
79-
if err != nil {
80-
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
81-
errLog.Print("net.Error from Dial()': ", nerr.Error())
82-
return nil, driver.ErrBadConn
83-
}
84-
return nil, err
85-
}
86-
87-
// Enable TCP Keepalives on TCP connections
88-
if tc, ok := mc.netConn.(*net.TCPConn); ok {
89-
if err := tc.SetKeepAlive(true); err != nil {
90-
// Don't send COM_QUIT before handshake.
91-
mc.netConn.Close()
92-
mc.netConn = nil
93-
return nil, err
94-
}
95-
}
96-
97-
// Call startWatcher for context support (From Go 1.8)
98-
mc.startWatcher()
99-
100-
mc.buf = newBuffer(mc.netConn)
101-
102-
// Set I/O timeouts
103-
mc.buf.timeout = mc.cfg.ReadTimeout
104-
mc.writeTimeout = mc.cfg.WriteTimeout
105-
106-
// Reading Handshake Initialization Packet
107-
authData, plugin, err := mc.readHandshakePacket()
73+
cfg, err := ParseDSN(dsn)
10874
if err != nil {
109-
mc.cleanup()
11075
return nil, err
11176
}
112-
if plugin == "" {
113-
plugin = defaultAuthPlugin
77+
c := &connector{
78+
cfg: cfg,
11479
}
115-
116-
// Send Client Authentication Packet
117-
authResp, err := mc.auth(authData, plugin)
118-
if err != nil {
119-
// try the default auth plugin, if using the requested plugin failed
120-
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
121-
plugin = defaultAuthPlugin
122-
authResp, err = mc.auth(authData, plugin)
123-
if err != nil {
124-
mc.cleanup()
125-
return nil, err
126-
}
127-
}
128-
if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
129-
mc.cleanup()
130-
return nil, err
131-
}
132-
133-
// Handle response to auth packet, switch methods if possible
134-
if err = mc.handleAuthResult(authData, plugin); err != nil {
135-
// Authentication failed and MySQL has already closed the connection
136-
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
137-
// Do not send COM_QUIT, just cleanup and return the error.
138-
mc.cleanup()
139-
return nil, err
140-
}
141-
142-
if mc.cfg.MaxAllowedPacket > 0 {
143-
mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
144-
} else {
145-
// Get max allowed packet size
146-
maxap, err := mc.getSystemVar("max_allowed_packet")
147-
if err != nil {
148-
mc.Close()
149-
return nil, err
150-
}
151-
mc.maxAllowedPacket = stringToInt(maxap) - 1
152-
}
153-
if mc.maxAllowedPacket < maxPacketSize {
154-
mc.maxWriteSize = mc.maxAllowedPacket
155-
}
156-
157-
// Handle DSN Params
158-
err = mc.handleParams()
159-
if err != nil {
160-
mc.Close()
161-
return nil, err
162-
}
163-
164-
return mc, nil
80+
return c.Connect(context.Background())
16581
}
16682

16783
func init() {

driver_go110.go

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
2+
//
3+
// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
4+
//
5+
// This Source Code Form is subject to the terms of the Mozilla Public
6+
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
7+
// You can obtain one at http://mozilla.org/MPL/2.0/.
8+
9+
// +build go1.10
10+
11+
package mysql
12+
13+
import (
14+
"database/sql/driver"
15+
)
16+
17+
// NewConnector returns new driver.Connector.
18+
func NewConnector(cfg *Config) (driver.Connector, error) {
19+
cfg = cfg.Clone()
20+
// normalize the contents of cfg so calls to NewConnector have the same
21+
// behavior as MySQLDriver.OpenConnector
22+
if err := cfg.normalize(); err != nil {
23+
return nil, err
24+
}
25+
return &connector{cfg: cfg}, nil
26+
}
27+
28+
// OpenConnector implements driver.DriverContext.
29+
func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) {
30+
cfg, err := ParseDSN(dsn)
31+
if err != nil {
32+
return nil, err
33+
}
34+
return &connector{
35+
cfg: cfg,
36+
}, nil
37+
}

0 commit comments

Comments
 (0)