Skip to content

Commit ad744fc

Browse files
committed
A way to register preferred TLS config
Signed-off-by: lance6716 <[email protected]>
1 parent fa1e4ed commit ad744fc

File tree

4 files changed

+95
-7
lines changed

4 files changed

+95
-7
lines changed

dsn.go

+6-2
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ type Config struct {
4747
pubKey *rsa.PublicKey // Server public key
4848
TLSConfig string // TLS configuration name
4949
tls *tls.Config // TLS configuration
50+
tlsIsPreferred bool // TLS is preferred, not a must
5051
Timeout time.Duration // Dial timeout
5152
ReadTimeout time.Duration // I/O read timeout
5253
WriteTimeout time.Duration // I/O write timeout
@@ -124,10 +125,13 @@ func (cfg *Config) normalize() error {
124125
// don't set anything
125126
case "true":
126127
cfg.tls = &tls.Config{}
127-
case "skip-verify", "preferred":
128+
case "skip-verify":
128129
cfg.tls = &tls.Config{InsecureSkipVerify: true}
130+
case "preferred":
131+
cfg.tls = &tls.Config{InsecureSkipVerify: true}
132+
cfg.tlsIsPreferred = true
129133
default:
130-
cfg.tls = getTLSConfigClone(cfg.TLSConfig)
134+
cfg.tls, cfg.tlsIsPreferred = getTLSConfigCloneAndPreferred(cfg.TLSConfig)
131135
if cfg.tls == nil {
132136
return errors.New("invalid value / unknown config name: " + cfg.TLSConfig)
133137
}

dsn_test.go

+60
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,66 @@ func TestDSNWithCustomTLS(t *testing.T) {
221221
}
222222
}
223223

224+
func TestRegisterPreferredTLSConfig(t *testing.T) {
225+
tlsMust := "tls-must"
226+
tlsPreferred := "tls-preferred"
227+
tlsCfg := tls.Config{ServerName: tlsMust}
228+
tlsCfg2 := tls.Config{ServerName: tlsPreferred}
229+
230+
err := RegisterTLSConfig(tlsMust, &tlsCfg)
231+
if err != nil {
232+
t.Error(err.Error())
233+
}
234+
defer DeregisterTLSConfig(tlsMust)
235+
err = RegisterPreferredTLSConfig(tlsPreferred, &tlsCfg2)
236+
if err != nil {
237+
t.Error(err.Error())
238+
}
239+
defer DeregisterTLSConfig(tlsPreferred)
240+
241+
checkCfgAndPreferred := func(key, expectedName string, expectedPreferred bool) {
242+
cfg, preferred := getTLSConfigCloneAndPreferred(key)
243+
if cfg.ServerName != expectedName {
244+
t.Errorf("did not get the correct TLS ServerName (%s): %s.", expectedName, cfg.ServerName)
245+
}
246+
if expectedPreferred && !preferred {
247+
t.Errorf("this should not be a preferred TLS config: %s", key)
248+
}
249+
if !expectedPreferred && preferred {
250+
t.Errorf("this should be a preferred TLS config: %s", key)
251+
}
252+
}
253+
254+
checkCfgAndPreferred(tlsMust, tlsMust, false)
255+
checkCfgAndPreferred(tlsPreferred, tlsPreferred, true)
256+
257+
// test RegisterTLSConfig overwrites existing config and preferred
258+
tlsAnother := "tls-another"
259+
tlsCfg3 := tls.Config{ServerName: tlsAnother}
260+
// register the name tlsPreferred in non-preferred mode!
261+
err = RegisterTLSConfig(tlsPreferred, &tlsCfg3)
262+
if err != nil {
263+
t.Error(err.Error())
264+
}
265+
checkCfgAndPreferred(tlsPreferred, tlsAnother, false)
266+
267+
// overwrite it back to preferred
268+
err = RegisterPreferredTLSConfig(tlsPreferred, &tlsCfg3)
269+
if err != nil {
270+
t.Error(err.Error())
271+
}
272+
checkCfgAndPreferred(tlsPreferred, tlsAnother, true)
273+
274+
DeregisterTLSConfig(tlsPreferred)
275+
cfg, preferred := getTLSConfigCloneAndPreferred(tlsPreferred)
276+
if cfg != nil {
277+
t.Error("found TLS config after deregister")
278+
}
279+
if preferred {
280+
t.Error("preferred should be false after deregister")
281+
}
282+
}
283+
224284
func TestDSNTLSConfig(t *testing.T) {
225285
expectedServerName := "example.com"
226286
dsn := "tcp(example.com:1234)/?tls=true"

packets.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
223223
return nil, "", ErrOldProtocol
224224
}
225225
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
226-
if mc.cfg.TLSConfig == "preferred" {
226+
if mc.cfg.tlsIsPreferred {
227227
mc.cfg.tls = nil
228228
} else {
229229
return nil, "", ErrNoTLS

utils.go

+28-4
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@ import (
2525

2626
// Registry for custom tls.Configs
2727
var (
28-
tlsConfigLock sync.RWMutex
29-
tlsConfigRegistry map[string]*tls.Config
28+
tlsConfigLock sync.RWMutex
29+
tlsConfigRegistry map[string]*tls.Config
30+
tlsConfigPreferred map[string]struct{}
3031
)
3132

3233
// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
@@ -55,6 +56,17 @@ var (
5556
// })
5657
// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
5758
func RegisterTLSConfig(key string, config *tls.Config) error {
59+
return registerTLSConfig(key, config, false)
60+
}
61+
62+
// RegisterPreferredTLSConfig is like a RegisterTLSConfig, but when the MySQL
63+
// server does not support TLS the driver will try to connect without TLS.
64+
// It can be used as a customized "preferred" TLS configuration in the DSN.
65+
func RegisterPreferredTLSConfig(key string, config *tls.Config) error {
66+
return registerTLSConfig(key, config, true)
67+
}
68+
69+
func registerTLSConfig(key string, config *tls.Config, preferred bool) error {
5870
if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" || strings.ToLower(key) == "preferred" {
5971
return fmt.Errorf("key '%s' is reserved", key)
6072
}
@@ -63,8 +75,16 @@ func RegisterTLSConfig(key string, config *tls.Config) error {
6375
if tlsConfigRegistry == nil {
6476
tlsConfigRegistry = make(map[string]*tls.Config)
6577
}
66-
6778
tlsConfigRegistry[key] = config
79+
80+
if preferred {
81+
if tlsConfigPreferred == nil {
82+
tlsConfigPreferred = make(map[string]struct{})
83+
}
84+
tlsConfigPreferred[key] = struct{}{}
85+
} else {
86+
delete(tlsConfigPreferred, key)
87+
}
6888
tlsConfigLock.Unlock()
6989
return nil
7090
}
@@ -75,14 +95,18 @@ func DeregisterTLSConfig(key string) {
7595
if tlsConfigRegistry != nil {
7696
delete(tlsConfigRegistry, key)
7797
}
98+
if tlsConfigPreferred != nil {
99+
delete(tlsConfigPreferred, key)
100+
}
78101
tlsConfigLock.Unlock()
79102
}
80103

81-
func getTLSConfigClone(key string) (config *tls.Config) {
104+
func getTLSConfigCloneAndPreferred(key string) (config *tls.Config, preferred bool) {
82105
tlsConfigLock.RLock()
83106
if v, ok := tlsConfigRegistry[key]; ok {
84107
config = v.Clone()
85108
}
109+
_, preferred = tlsConfigPreferred[key]
86110
tlsConfigLock.RUnlock()
87111
return
88112
}

0 commit comments

Comments
 (0)