Skip to content

Commit 91ad4fb

Browse files
authored
Specify a custom dial function per config (#1527)
Specify a custom dial function per config instead of using RegisterDialContext.
1 parent 00dc21a commit 91ad4fb

File tree

2 files changed

+20
-13
lines changed

2 files changed

+20
-13
lines changed

Diff for: connector.go

+18-13
Original file line numberDiff line numberDiff line change
@@ -87,20 +87,25 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
8787
mc.parseTime = mc.cfg.ParseTime
8888

8989
// Connect to Server
90-
dialsLock.RLock()
91-
dial, ok := dials[mc.cfg.Net]
92-
dialsLock.RUnlock()
93-
if ok {
94-
dctx := ctx
95-
if mc.cfg.Timeout > 0 {
96-
var cancel context.CancelFunc
97-
dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout)
98-
defer cancel()
99-
}
100-
mc.netConn, err = dial(dctx, mc.cfg.Addr)
90+
dctx := ctx
91+
if mc.cfg.Timeout > 0 {
92+
var cancel context.CancelFunc
93+
dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout)
94+
defer cancel()
95+
}
96+
97+
if c.cfg.DialFunc != nil {
98+
mc.netConn, err = c.cfg.DialFunc(dctx, mc.cfg.Net, mc.cfg.Addr)
10199
} else {
102-
nd := net.Dialer{Timeout: mc.cfg.Timeout}
103-
mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)
100+
dialsLock.RLock()
101+
dial, ok := dials[mc.cfg.Net]
102+
dialsLock.RUnlock()
103+
if ok {
104+
mc.netConn, err = dial(dctx, mc.cfg.Addr)
105+
} else {
106+
nd := net.Dialer{}
107+
mc.netConn, err = nd.DialContext(dctx, mc.cfg.Net, mc.cfg.Addr)
108+
}
104109
}
105110
if err != nil {
106111
return nil, err

Diff for: dsn.go

+2
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ type Config struct {
5555
ReadTimeout time.Duration // I/O read timeout
5656
WriteTimeout time.Duration // I/O write timeout
5757
Logger Logger // Logger
58+
// DialFunc specifies the dial function for creating connections
59+
DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
5860

5961
// boolean fields
6062

0 commit comments

Comments
 (0)