From c5207bcd4b85a4eea71712b7f2a8921ecd8e6204 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Thu, 30 Nov 2017 23:16:18 +0900 Subject: [PATCH] Fix tls=true didn't work with host without port. Fixes #717 --- dsn.go | 20 +++++++++----------- dsn_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 11 deletions(-) diff --git a/dsn.go b/dsn.go index f5ea0d470..47eab6945 100644 --- a/dsn.go +++ b/dsn.go @@ -94,6 +94,15 @@ func (cfg *Config) normalize() error { cfg.Addr = ensureHavePort(cfg.Addr) } + if cfg.tls != nil { + if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify { + host, _, err := net.SplitHostPort(cfg.Addr) + if err == nil { + cfg.tls.ServerName = host + } + } + } + return nil } @@ -521,10 +530,6 @@ func parseDSNParams(cfg *Config, params string) (err error) { if boolValue { cfg.TLSConfig = "true" cfg.tls = &tls.Config{} - host, _, err := net.SplitHostPort(cfg.Addr) - if err == nil { - cfg.tls.ServerName = host - } } else { cfg.TLSConfig = "false" } @@ -538,13 +543,6 @@ func parseDSNParams(cfg *Config, params string) (err error) { } if tlsConfig := getTLSConfigClone(name); tlsConfig != nil { - if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify { - host, _, err := net.SplitHostPort(cfg.Addr) - if err == nil { - tlsConfig.ServerName = host - } - } - cfg.TLSConfig = name cfg.tls = tlsConfig } else { diff --git a/dsn_test.go b/dsn_test.go index 07b223f6b..7507d1201 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -177,6 +177,34 @@ func TestDSNWithCustomTLS(t *testing.T) { DeregisterTLSConfig("utils_test") } +func TestDSNTLSConfig(t *testing.T) { + expectedServerName := "example.com" + dsn := "tcp(example.com:1234)/?tls=true" + + cfg, err := ParseDSN(dsn) + if err != nil { + t.Error(err.Error()) + } + if cfg.tls == nil { + t.Error("cfg.tls should not be nil") + } + if cfg.tls.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName) + } + + dsn = "tcp(example.com)/?tls=true" + cfg, err = ParseDSN(dsn) + if err != nil { + t.Error(err.Error()) + } + if cfg.tls == nil { + t.Error("cfg.tls should not be nil") + } + if cfg.tls.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.tls.ServerName) + } +} + func TestDSNWithCustomTLSQueryEscape(t *testing.T) { const configKey = "&%!:" dsn := "User:password@tcp(localhost:5555)/dbname?tls=" + url.QueryEscape(configKey)