Skip to content

Commit 89cc76d

Browse files
committed
implement RegisterDialContext.
1 parent af9889e commit 89cc76d

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

connector.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,11 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
3333
mc.parseTime = mc.cfg.ParseTime
3434

3535
// Connect to Server
36-
// TODO: needs RegisterDialContext
3736
dialsLock.RLock()
3837
dial, ok := dials[mc.cfg.Net]
3938
dialsLock.RUnlock()
4039
if ok {
41-
mc.netConn, err = dial(mc.cfg.Addr)
40+
mc.netConn, err = dial(ctx, mc.cfg.Addr)
4241
} else {
4342
nd := net.Dialer{Timeout: mc.cfg.Timeout}
4443
mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)

driver.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"database/sql"
2222
"database/sql/driver"
2323
"net"
24+
pkgnet "net"
2425
"sync"
2526
)
2627

@@ -32,19 +33,33 @@ type MySQLDriver struct{}
3233
// Custom dial functions must be registered with RegisterDial
3334
type DialFunc func(addr string) (net.Conn, error)
3435

36+
// DialContextFunc is a function which can be used to establish the network connection using the provided context.
37+
// Custom dial functions must be registered with RegisterDialContext
38+
type DialContextFunc func(ctx context.Context, addr string) (net.Conn, error)
39+
3540
var (
3641
dialsLock sync.RWMutex
37-
dials map[string]DialFunc
42+
dials map[string]DialContextFunc
3843
)
3944

4045
// RegisterDial registers a custom dial function. It can then be used by the
4146
// network address mynet(addr), where mynet is the registered new network.
4247
// addr is passed as a parameter to the dial function.
4348
func RegisterDial(net string, dial DialFunc) {
49+
dialContext := DialContextFunc(func(ctx context.Context, addr string) (pkgnet.Conn, error) {
50+
return dial(addr)
51+
})
52+
RegisterDialContext(net, dialContext)
53+
}
54+
55+
// RegisterDialContext registers a custom dial function. It can then be used by the
56+
// network address mynet(addr), where mynet is the registered new network.
57+
// addr is passed as a parameter to the dial function.
58+
func RegisterDialContext(net string, dial DialContextFunc) {
4459
dialsLock.Lock()
4560
defer dialsLock.Unlock()
4661
if dials == nil {
47-
dials = make(map[string]DialFunc)
62+
dials = make(map[string]DialContextFunc)
4863
}
4964
dials[net] = dial
5065
}

0 commit comments

Comments
 (0)