Skip to content

Commit 8dc06d8

Browse files
committed
Default TLS ServerName to the host in the DSN.
A TLS configuration must either have a ServerName or specify InsecureSkipVerify. In most cases, the ServerName value will match the host part of the address in the DSN. This change updates the DSN parser to default the ServerName to the host value provided unless InsecureSkipVerify is specified.
1 parent 9543750 commit 8dc06d8

File tree

3 files changed

+50
-0
lines changed

3 files changed

+50
-0
lines changed

AUTHORS

+1
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,4 @@ Xiuming Chen <cc at cxm.cc>
3434

3535
Barracuda Networks, Inc.
3636
Google Inc.
37+
Stripe Inc.

utils.go

+8
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"errors"
1717
"fmt"
1818
"io"
19+
"net"
1920
"net/url"
2021
"strings"
2122
"time"
@@ -244,6 +245,13 @@ func parseDSNParams(cfg *config, params string) (err error) {
244245
if strings.ToLower(value) == "skip-verify" {
245246
cfg.tls = &tls.Config{InsecureSkipVerify: true}
246247
} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
248+
if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify {
249+
host, _, err := net.SplitHostPort(cfg.addr)
250+
if err == nil {
251+
tlsConfig.ServerName = host
252+
}
253+
}
254+
247255
cfg.tls = tlsConfig
248256
} else {
249257
return fmt.Errorf("Invalid value / unknown config name: %s", value)

utils_test.go

+41
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ package mysql
1010

1111
import (
1212
"bytes"
13+
"crypto/tls"
1314
"encoding/binary"
1415
"fmt"
1516
"testing"
@@ -74,6 +75,46 @@ func TestDSNParserInvalid(t *testing.T) {
7475
}
7576
}
7677

78+
func TestDSNWithCustomTLS(t *testing.T) {
79+
baseDSN := "user:password@tcp(localhost:5555)/dbname?tls="
80+
tlsCfg := tls.Config{}
81+
82+
RegisterTLSConfig("utils_test", &tlsCfg)
83+
84+
// Custom TLS is missing
85+
tst := baseDSN + "invalid_tls"
86+
cfg, err := parseDSN(tst)
87+
if err == nil {
88+
t.Errorf("Invalid custom TLS in DSN (%s) but did not error. Got config: %#v", tst, cfg)
89+
}
90+
91+
tst = baseDSN + "utils_test"
92+
93+
// Custom TLS with a server name
94+
name := "foohost"
95+
tlsCfg.ServerName = name
96+
cfg, err = parseDSN(tst)
97+
98+
if err != nil {
99+
t.Error(err.Error())
100+
} else if cfg.tls.ServerName != name {
101+
t.Errorf("Did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, tst)
102+
}
103+
104+
// Custom TLS without a server name
105+
name = "localhost"
106+
tlsCfg.ServerName = ""
107+
cfg, err = parseDSN(tst)
108+
109+
if err != nil {
110+
t.Error(err.Error())
111+
} else if cfg.tls.ServerName != name {
112+
t.Errorf("Did not get the correct ServerName (%s) parsing DSN (%s).", name, tst)
113+
}
114+
115+
DeregisterTLSConfig("utils_test")
116+
}
117+
77118
func BenchmarkParseDSN(b *testing.B) {
78119
b.ReportAllocs()
79120

0 commit comments

Comments
 (0)