Skip to content

Commit 10f45a3

Browse files
committed
Specify a custom dial function per config
1 parent 3147497 commit 10f45a3

File tree

2 files changed

+35
-24
lines changed

2 files changed

+35
-24
lines changed

connector.go

+17-7
Original file line numberDiff line numberDiff line change
@@ -87,20 +87,30 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
8787
mc.parseTime = mc.cfg.ParseTime
8888

8989
// Connect to Server
90-
dialsLock.RLock()
91-
dial, ok := dials[mc.cfg.Net]
92-
dialsLock.RUnlock()
93-
if ok {
90+
if c.cfg.DialFunc != nil {
9491
dctx := ctx
9592
if mc.cfg.Timeout > 0 {
9693
var cancel context.CancelFunc
9794
dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout)
9895
defer cancel()
9996
}
100-
mc.netConn, err = dial(dctx, mc.cfg.Addr)
97+
mc.netConn, err = c.cfg.DialFunc(dctx, mc.cfg.Net, mc.cfg.Addr)
10198
} else {
102-
nd := net.Dialer{Timeout: mc.cfg.Timeout}
103-
mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)
99+
dialsLock.RLock()
100+
dial, ok := dials[mc.cfg.Net]
101+
dialsLock.RUnlock()
102+
if ok {
103+
dctx := ctx
104+
if mc.cfg.Timeout > 0 {
105+
var cancel context.CancelFunc
106+
dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout)
107+
defer cancel()
108+
}
109+
mc.netConn, err = dial(dctx, mc.cfg.Addr)
110+
} else {
111+
nd := net.Dialer{Timeout: mc.cfg.Timeout}
112+
mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)
113+
}
104114
}
105115

106116
if err != nil {

dsn.go

+18-17
Original file line numberDiff line numberDiff line change
@@ -37,23 +37,24 @@ var (
3737
type Config struct {
3838
// non boolean fields
3939

40-
User string // Username
41-
Passwd string // Password (requires User)
42-
Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp")
43-
Addr string // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix")
44-
DBName string // Database name
45-
Params map[string]string // Connection parameters
46-
ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs
47-
Collation string // Connection collation
48-
Loc *time.Location // Location for time.Time values
49-
MaxAllowedPacket int // Max packet size allowed
50-
ServerPubKey string // Server public key name
51-
TLSConfig string // TLS configuration name
52-
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
53-
Timeout time.Duration // Dial timeout
54-
ReadTimeout time.Duration // I/O read timeout
55-
WriteTimeout time.Duration // I/O write timeout
56-
Logger Logger // Logger
40+
User string // Username
41+
Passwd string // Password (requires User)
42+
Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp")
43+
Addr string // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix")
44+
DBName string // Database name
45+
Params map[string]string // Connection parameters
46+
ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs
47+
Collation string // Connection collation
48+
Loc *time.Location // Location for time.Time values
49+
MaxAllowedPacket int // Max packet size allowed
50+
ServerPubKey string // Server public key name
51+
TLSConfig string // TLS configuration name
52+
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
53+
Timeout time.Duration // Dial timeout
54+
ReadTimeout time.Duration // I/O read timeout
55+
WriteTimeout time.Duration // I/O write timeout
56+
Logger Logger // Logger
57+
DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) // Specifies the dial function for creating connections
5758

5859
// boolean fields
5960

0 commit comments

Comments
 (0)