17
17
package mysql
18
18
19
19
import (
20
+ "context"
20
21
"database/sql"
21
22
"database/sql/driver"
22
23
"net"
@@ -29,139 +30,54 @@ type MySQLDriver struct{}
29
30
30
31
// DialFunc is a function which can be used to establish the network connection.
31
32
// Custom dial functions must be registered with RegisterDial
33
+ //
34
+ // Deprecated: users should register a DialContextFunc instead
32
35
type DialFunc func (addr string ) (net.Conn , error )
33
36
37
+ // DialContextFunc is a function which can be used to establish the network connection.
38
+ // Custom dial functions must be registered with RegisterDialContext
39
+ type DialContextFunc func (ctx context.Context , addr string ) (net.Conn , error )
40
+
34
41
var (
35
42
dialsLock sync.RWMutex
36
- dials map [string ]DialFunc
43
+ dials map [string ]DialContextFunc
37
44
)
38
45
39
- // RegisterDial registers a custom dial function. It can then be used by the
46
+ // RegisterDialContext registers a custom dial function. It can then be used by the
40
47
// network address mynet(addr), where mynet is the registered new network.
41
- // addr is passed as a parameter to the dial function.
42
- func RegisterDial (net string , dial DialFunc ) {
48
+ // The current context for the connection and its address is passed to the dial function.
49
+ func RegisterDialContext (net string , dial DialContextFunc ) {
43
50
dialsLock .Lock ()
44
51
defer dialsLock .Unlock ()
45
52
if dials == nil {
46
- dials = make (map [string ]DialFunc )
53
+ dials = make (map [string ]DialContextFunc )
47
54
}
48
55
dials [net ] = dial
49
56
}
50
57
58
+ // RegisterDial registers a custom dial function. It can then be used by the
59
+ // network address mynet(addr), where mynet is the registered new network.
60
+ // addr is passed as a parameter to the dial function.
61
+ //
62
+ // Deprecated: users should call RegisterDialContext instead
63
+ func RegisterDial (network string , dial DialFunc ) {
64
+ RegisterDialContext (network , func (_ context.Context , addr string ) (net.Conn , error ) {
65
+ return dial (addr )
66
+ })
67
+ }
68
+
51
69
// Open new Connection.
52
70
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
53
71
// the DSN string is formatted
54
72
func (d MySQLDriver ) Open (dsn string ) (driver.Conn , error ) {
55
- var err error
56
-
57
- // New mysqlConn
58
- mc := & mysqlConn {
59
- maxAllowedPacket : maxPacketSize ,
60
- maxWriteSize : maxPacketSize - 1 ,
61
- closech : make (chan struct {}),
62
- }
63
- mc .cfg , err = ParseDSN (dsn )
64
- if err != nil {
65
- return nil , err
66
- }
67
- mc .parseTime = mc .cfg .ParseTime
68
-
69
- // Connect to Server
70
- dialsLock .RLock ()
71
- dial , ok := dials [mc .cfg .Net ]
72
- dialsLock .RUnlock ()
73
- if ok {
74
- mc .netConn , err = dial (mc .cfg .Addr )
75
- } else {
76
- nd := net.Dialer {Timeout : mc .cfg .Timeout }
77
- mc .netConn , err = nd .Dial (mc .cfg .Net , mc .cfg .Addr )
78
- }
79
- if err != nil {
80
- if nerr , ok := err .(net.Error ); ok && nerr .Temporary () {
81
- errLog .Print ("net.Error from Dial()': " , nerr .Error ())
82
- return nil , driver .ErrBadConn
83
- }
84
- return nil , err
85
- }
86
-
87
- // Enable TCP Keepalives on TCP connections
88
- if tc , ok := mc .netConn .(* net.TCPConn ); ok {
89
- if err := tc .SetKeepAlive (true ); err != nil {
90
- // Don't send COM_QUIT before handshake.
91
- mc .netConn .Close ()
92
- mc .netConn = nil
93
- return nil , err
94
- }
95
- }
96
-
97
- // Call startWatcher for context support (From Go 1.8)
98
- mc .startWatcher ()
99
-
100
- mc .buf = newBuffer (mc .netConn )
101
-
102
- // Set I/O timeouts
103
- mc .buf .timeout = mc .cfg .ReadTimeout
104
- mc .writeTimeout = mc .cfg .WriteTimeout
105
-
106
- // Reading Handshake Initialization Packet
107
- authData , plugin , err := mc .readHandshakePacket ()
73
+ cfg , err := ParseDSN (dsn )
108
74
if err != nil {
109
- mc .cleanup ()
110
75
return nil , err
111
76
}
112
- if plugin == "" {
113
- plugin = defaultAuthPlugin
77
+ c := & connector {
78
+ cfg : cfg ,
114
79
}
115
-
116
- // Send Client Authentication Packet
117
- authResp , err := mc .auth (authData , plugin )
118
- if err != nil {
119
- // try the default auth plugin, if using the requested plugin failed
120
- errLog .Print ("could not use requested auth plugin '" + plugin + "': " , err .Error ())
121
- plugin = defaultAuthPlugin
122
- authResp , err = mc .auth (authData , plugin )
123
- if err != nil {
124
- mc .cleanup ()
125
- return nil , err
126
- }
127
- }
128
- if err = mc .writeHandshakeResponsePacket (authResp , plugin ); err != nil {
129
- mc .cleanup ()
130
- return nil , err
131
- }
132
-
133
- // Handle response to auth packet, switch methods if possible
134
- if err = mc .handleAuthResult (authData , plugin ); err != nil {
135
- // Authentication failed and MySQL has already closed the connection
136
- // (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
137
- // Do not send COM_QUIT, just cleanup and return the error.
138
- mc .cleanup ()
139
- return nil , err
140
- }
141
-
142
- if mc .cfg .MaxAllowedPacket > 0 {
143
- mc .maxAllowedPacket = mc .cfg .MaxAllowedPacket
144
- } else {
145
- // Get max allowed packet size
146
- maxap , err := mc .getSystemVar ("max_allowed_packet" )
147
- if err != nil {
148
- mc .Close ()
149
- return nil , err
150
- }
151
- mc .maxAllowedPacket = stringToInt (maxap ) - 1
152
- }
153
- if mc .maxAllowedPacket < maxPacketSize {
154
- mc .maxWriteSize = mc .maxAllowedPacket
155
- }
156
-
157
- // Handle DSN Params
158
- err = mc .handleParams ()
159
- if err != nil {
160
- mc .Close ()
161
- return nil , err
162
- }
163
-
164
- return mc , nil
80
+ return c .Connect (context .Background ())
165
81
}
166
82
167
83
func init () {
0 commit comments