Skip to content

move tls/pubkey object creation to Config.normalize() #958

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,12 @@ Zhenye Xie <xiezhenye at gmail.com>

Barracuda Networks, Inc.
Counting Ltd.
Facebook Inc.
GitHub Inc.
Google Inc.
InfoSum Ltd.
Keybase Inc.
Multiplay Ltd.
Percona LLC
Pivotal Inc.
Stripe Inc.
Multiplay Ltd.
50 changes: 27 additions & 23 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,17 +113,35 @@ func (cfg *Config) normalize() error {
default:
return errors.New("default addr for network '" + cfg.Net + "' unknown")
}

} else if cfg.Net == "tcp" {
cfg.Addr = ensureHavePort(cfg.Addr)
}

if cfg.tls != nil {
if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {
host, _, err := net.SplitHostPort(cfg.Addr)
if err == nil {
cfg.tls.ServerName = host
}
switch cfg.TLSConfig {
case "false", "":
// don't set anything
case "true":
cfg.tls = &tls.Config{}
case "skip-verify", "preferred":
cfg.tls = &tls.Config{InsecureSkipVerify: true}
default:
cfg.tls = getTLSConfigClone(cfg.TLSConfig)
if cfg.tls == nil {
return errors.New("invalid value / unknown config name: " + cfg.TLSConfig)
}
}

if cfg.tls != nil && cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {
host, _, err := net.SplitHostPort(cfg.Addr)
if err == nil {
cfg.tls.ServerName = host
}
}

if cfg.ServerPubKey != "" {
cfg.pubKey = getServerPubKey(cfg.ServerPubKey)
if cfg.pubKey == nil {
return errors.New("invalid value / unknown server pub key name: " + cfg.ServerPubKey)
}
}

Expand Down Expand Up @@ -552,13 +570,7 @@ func parseDSNParams(cfg *Config, params string) (err error) {
if err != nil {
return fmt.Errorf("invalid value for server pub key name: %v", err)
}

if pubKey := getServerPubKey(name); pubKey != nil {
cfg.ServerPubKey = name
cfg.pubKey = pubKey
} else {
return errors.New("invalid value / unknown server pub key name: " + name)
}
cfg.ServerPubKey = name

// Strict mode
case "strict":
Expand All @@ -577,25 +589,17 @@ func parseDSNParams(cfg *Config, params string) (err error) {
if isBool {
if boolValue {
cfg.TLSConfig = "true"
cfg.tls = &tls.Config{}
} else {
cfg.TLSConfig = "false"
}
} else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" {
cfg.TLSConfig = vl
cfg.tls = &tls.Config{InsecureSkipVerify: true}
} else {
name, err := url.QueryUnescape(value)
if err != nil {
return fmt.Errorf("invalid value for TLS config name: %v", err)
}

if tlsConfig := getTLSConfigClone(name); tlsConfig != nil {
cfg.TLSConfig = name
cfg.tls = tlsConfig
} else {
return errors.New("invalid value / unknown config name: " + name)
}
cfg.TLSConfig = name
}

// I/O write Timeout
Expand Down
48 changes: 46 additions & 2 deletions dsn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ var testDSNs = []struct {
"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "skip-verify"},
}, {
"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, ClientFoundRows: true, MaxAllowedPacket: 16777216},
"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216&tls=false&allowCleartextPasswords=true&parseTime=true&rejectReadOnly=true",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, TLSConfig: "false", AllowCleartextPasswords: true, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, ClientFoundRows: true, MaxAllowedPacket: 16777216, ParseTime: true, RejectReadOnly: true},
}, {
"user:password@/dbname?allowNativePasswords=false&maxAllowedPacket=0",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false},
Expand Down Expand Up @@ -358,6 +358,50 @@ func TestCloneConfig(t *testing.T) {
}
}

func TestNormalizeTLSConfig(t *testing.T) {
tt := []struct {
tlsConfig string
want *tls.Config
}{
{"", nil},
{"false", nil},
{"true", &tls.Config{ServerName: "myserver"}},
{"skip-verify", &tls.Config{InsecureSkipVerify: true}},
{"preferred", &tls.Config{InsecureSkipVerify: true}},
{"test_tls_config", &tls.Config{ServerName: "myServerName"}},
}

RegisterTLSConfig("test_tls_config", &tls.Config{ServerName: "myServerName"})
defer func() { DeregisterTLSConfig("test_tls_config") }()

for _, tc := range tt {
t.Run(tc.tlsConfig, func(t *testing.T) {
cfg := &Config{
Addr: "myserver:3306",
TLSConfig: tc.tlsConfig,
}

cfg.normalize()

if cfg.tls == nil {
if tc.want != nil {
t.Fatal("wanted a tls config but got nil instead")
}
return
}

if cfg.tls.ServerName != tc.want.ServerName {
t.Errorf("tls.ServerName doesn't match (want: '%s', got: '%s')",
tc.want.ServerName, cfg.tls.ServerName)
}
if cfg.tls.InsecureSkipVerify != tc.want.InsecureSkipVerify {
t.Errorf("tls.InsecureSkipVerify doesn't match (want: %T, got :%T)",
tc.want.InsecureSkipVerify, cfg.tls.InsecureSkipVerify)
}
})
}
}

func BenchmarkParseDSN(b *testing.B) {
b.ReportAllocs()

Expand Down
2 changes: 1 addition & 1 deletion utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ var (
// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
//
func RegisterTLSConfig(key string, config *tls.Config) error {
if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" {
if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" || strings.ToLower(key) == "preferred" {
return fmt.Errorf("key '%s' is reserved", key)
}

Expand Down