@@ -21,6 +21,7 @@ import (
21
21
"database/sql"
22
22
"database/sql/driver"
23
23
"net"
24
+ pkgnet "net"
24
25
"sync"
25
26
)
26
27
@@ -32,19 +33,33 @@ type MySQLDriver struct{}
32
33
// Custom dial functions must be registered with RegisterDial
33
34
type DialFunc func (addr string ) (net.Conn , error )
34
35
36
+ // DialContextFunc is a function which can be used to establish the network connection using the provided context.
37
+ // Custom dial functions must be registered with RegisterDialContext
38
+ type DialContextFunc func (ctx context.Context , addr string ) (net.Conn , error )
39
+
35
40
var (
36
41
dialsLock sync.RWMutex
37
- dials map [string ]DialFunc
42
+ dials map [string ]DialContextFunc
38
43
)
39
44
40
45
// RegisterDial registers a custom dial function. It can then be used by the
41
46
// network address mynet(addr), where mynet is the registered new network.
42
47
// addr is passed as a parameter to the dial function.
43
48
func RegisterDial (net string , dial DialFunc ) {
49
+ dialContext := DialContextFunc (func (ctx context.Context , addr string ) (pkgnet.Conn , error ) {
50
+ return dial (addr )
51
+ })
52
+ RegisterDialContext (net , dialContext )
53
+ }
54
+
55
+ // RegisterDialContext registers a custom dial function. It can then be used by the
56
+ // network address mynet(addr), where mynet is the registered new network.
57
+ // addr is passed as a parameter to the dial function.
58
+ func RegisterDialContext (net string , dial DialContextFunc ) {
44
59
dialsLock .Lock ()
45
60
defer dialsLock .Unlock ()
46
61
if dials == nil {
47
- dials = make (map [string ]DialFunc )
62
+ dials = make (map [string ]DialContextFunc )
48
63
}
49
64
dials [net ] = dial
50
65
}
0 commit comments