Skip to content

Commit 72edd19

Browse files
nemithtz70s
authored andcommitted
move tls and pubkey object creation to Config.normalize() (go-sql-driver#958)
This is still less than ideal since we cannot directly pass in tls.Config into Config and have it be used, but it is sill backwards compatable. In the future this should be revisited to be able to use a custome tls.Config passed directly in without string parsing/registering.
1 parent 4d8a64e commit 72edd19

File tree

4 files changed

+76
-27
lines changed

4 files changed

+76
-27
lines changed

AUTHORS

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,12 @@ Zhenye Xie <xiezhenye at gmail.com>
9090

9191
Barracuda Networks, Inc.
9292
Counting Ltd.
93+
Facebook Inc.
9394
GitHub Inc.
9495
Google Inc.
9596
InfoSum Ltd.
9697
Keybase Inc.
98+
Multiplay Ltd.
9799
Percona LLC
98100
Pivotal Inc.
99101
Stripe Inc.
100-
Multiplay Ltd.

dsn.go

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -113,17 +113,35 @@ func (cfg *Config) normalize() error {
113113
default:
114114
return errors.New("default addr for network '" + cfg.Net + "' unknown")
115115
}
116-
117116
} else if cfg.Net == "tcp" {
118117
cfg.Addr = ensureHavePort(cfg.Addr)
119118
}
120119

121-
if cfg.tls != nil {
122-
if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {
123-
host, _, err := net.SplitHostPort(cfg.Addr)
124-
if err == nil {
125-
cfg.tls.ServerName = host
126-
}
120+
switch cfg.TLSConfig {
121+
case "false", "":
122+
// don't set anything
123+
case "true":
124+
cfg.tls = &tls.Config{}
125+
case "skip-verify", "preferred":
126+
cfg.tls = &tls.Config{InsecureSkipVerify: true}
127+
default:
128+
cfg.tls = getTLSConfigClone(cfg.TLSConfig)
129+
if cfg.tls == nil {
130+
return errors.New("invalid value / unknown config name: " + cfg.TLSConfig)
131+
}
132+
}
133+
134+
if cfg.tls != nil && cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {
135+
host, _, err := net.SplitHostPort(cfg.Addr)
136+
if err == nil {
137+
cfg.tls.ServerName = host
138+
}
139+
}
140+
141+
if cfg.ServerPubKey != "" {
142+
cfg.pubKey = getServerPubKey(cfg.ServerPubKey)
143+
if cfg.pubKey == nil {
144+
return errors.New("invalid value / unknown server pub key name: " + cfg.ServerPubKey)
127145
}
128146
}
129147

@@ -552,13 +570,7 @@ func parseDSNParams(cfg *Config, params string) (err error) {
552570
if err != nil {
553571
return fmt.Errorf("invalid value for server pub key name: %v", err)
554572
}
555-
556-
if pubKey := getServerPubKey(name); pubKey != nil {
557-
cfg.ServerPubKey = name
558-
cfg.pubKey = pubKey
559-
} else {
560-
return errors.New("invalid value / unknown server pub key name: " + name)
561-
}
573+
cfg.ServerPubKey = name
562574

563575
// Strict mode
564576
case "strict":
@@ -577,25 +589,17 @@ func parseDSNParams(cfg *Config, params string) (err error) {
577589
if isBool {
578590
if boolValue {
579591
cfg.TLSConfig = "true"
580-
cfg.tls = &tls.Config{}
581592
} else {
582593
cfg.TLSConfig = "false"
583594
}
584595
} else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" {
585596
cfg.TLSConfig = vl
586-
cfg.tls = &tls.Config{InsecureSkipVerify: true}
587597
} else {
588598
name, err := url.QueryUnescape(value)
589599
if err != nil {
590600
return fmt.Errorf("invalid value for TLS config name: %v", err)
591601
}
592-
593-
if tlsConfig := getTLSConfigClone(name); tlsConfig != nil {
594-
cfg.TLSConfig = name
595-
cfg.tls = tlsConfig
596-
} else {
597-
return errors.New("invalid value / unknown config name: " + name)
598-
}
602+
cfg.TLSConfig = name
599603
}
600604

601605
// I/O write Timeout

dsn_test.go

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ var testDSNs = []struct {
3939
"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify",
4040
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "skip-verify"},
4141
}, {
42-
"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216",
43-
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, ClientFoundRows: true, MaxAllowedPacket: 16777216},
42+
"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216&tls=false&allowCleartextPasswords=true&parseTime=true&rejectReadOnly=true",
43+
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, TLSConfig: "false", AllowCleartextPasswords: true, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, ClientFoundRows: true, MaxAllowedPacket: 16777216, ParseTime: true, RejectReadOnly: true},
4444
}, {
4545
"user:password@/dbname?allowNativePasswords=false&maxAllowedPacket=0",
4646
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false},
@@ -358,6 +358,50 @@ func TestCloneConfig(t *testing.T) {
358358
}
359359
}
360360

361+
func TestNormalizeTLSConfig(t *testing.T) {
362+
tt := []struct {
363+
tlsConfig string
364+
want *tls.Config
365+
}{
366+
{"", nil},
367+
{"false", nil},
368+
{"true", &tls.Config{ServerName: "myserver"}},
369+
{"skip-verify", &tls.Config{InsecureSkipVerify: true}},
370+
{"preferred", &tls.Config{InsecureSkipVerify: true}},
371+
{"test_tls_config", &tls.Config{ServerName: "myServerName"}},
372+
}
373+
374+
RegisterTLSConfig("test_tls_config", &tls.Config{ServerName: "myServerName"})
375+
defer func() { DeregisterTLSConfig("test_tls_config") }()
376+
377+
for _, tc := range tt {
378+
t.Run(tc.tlsConfig, func(t *testing.T) {
379+
cfg := &Config{
380+
Addr: "myserver:3306",
381+
TLSConfig: tc.tlsConfig,
382+
}
383+
384+
cfg.normalize()
385+
386+
if cfg.tls == nil {
387+
if tc.want != nil {
388+
t.Fatal("wanted a tls config but got nil instead")
389+
}
390+
return
391+
}
392+
393+
if cfg.tls.ServerName != tc.want.ServerName {
394+
t.Errorf("tls.ServerName doesn't match (want: '%s', got: '%s')",
395+
tc.want.ServerName, cfg.tls.ServerName)
396+
}
397+
if cfg.tls.InsecureSkipVerify != tc.want.InsecureSkipVerify {
398+
t.Errorf("tls.InsecureSkipVerify doesn't match (want: %T, got :%T)",
399+
tc.want.InsecureSkipVerify, cfg.tls.InsecureSkipVerify)
400+
}
401+
})
402+
}
403+
}
404+
361405
func BenchmarkParseDSN(b *testing.B) {
362406
b.ReportAllocs()
363407

utils.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ var (
5656
// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
5757
//
5858
func RegisterTLSConfig(key string, config *tls.Config) error {
59-
if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" {
59+
if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" || strings.ToLower(key) == "preferred" {
6060
return fmt.Errorf("key '%s' is reserved", key)
6161
}
6262

0 commit comments

Comments
 (0)