Skip to content

Commit 145f684

Browse files
authored
Merge pull request #673 from atomicules/tls-in-driver
Allow TLS connections in the driver
2 parents 5a98cc1 + b490dc8 commit 145f684

File tree

3 files changed

+167
-32
lines changed

3 files changed

+167
-32
lines changed

client/tls.go

+21-9
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,28 @@ func NewClientTLSConfig(caPem, certPem, keyPem []byte, insecureSkipVerify bool,
1313
panic("failed to add ca PEM")
1414
}
1515

16-
cert, err := tls.X509KeyPair(certPem, keyPem)
17-
if err != nil {
18-
panic(err)
19-
}
16+
var config *tls.Config
2017

21-
config := &tls.Config{
22-
Certificates: []tls.Certificate{cert},
23-
RootCAs: pool,
24-
InsecureSkipVerify: insecureSkipVerify,
25-
ServerName: serverName,
18+
// Allow cert and key to be optional
19+
// Send through `make([]byte, 0)` for "nil"
20+
if string(certPem) != "" && string(keyPem) != "" {
21+
cert, err := tls.X509KeyPair(certPem, keyPem)
22+
if err != nil {
23+
panic(err)
24+
}
25+
config = &tls.Config{
26+
RootCAs: pool,
27+
Certificates: []tls.Certificate{cert},
28+
InsecureSkipVerify: insecureSkipVerify,
29+
ServerName: serverName,
30+
}
31+
} else {
32+
config = &tls.Config{
33+
RootCAs: pool,
34+
InsecureSkipVerify: insecureSkipVerify,
35+
ServerName: serverName,
36+
}
2637
}
38+
2739
return config
2840
}

driver/driver.go

+119-23
Original file line numberDiff line numberDiff line change
@@ -3,51 +3,120 @@
33
package driver
44

55
import (
6+
"crypto/tls"
67
"database/sql"
78
sqldriver "database/sql/driver"
89
"fmt"
910
"io"
10-
"strings"
11+
"net/url"
12+
"regexp"
13+
"sync"
1114

1215
"github.com/go-mysql-org/go-mysql/client"
1316
"github.com/go-mysql-org/go-mysql/mysql"
1417
"github.com/pingcap/errors"
1518
"github.com/siddontang/go/hack"
1619
)
1720

21+
var customTLSMutex sync.Mutex
22+
23+
// Map of dsn address (makes more sense than full dsn?) to tls Config
24+
var customTLSConfigMap = make(map[string]*tls.Config)
25+
1826
type driver struct {
1927
}
2028

21-
// Open: DSN user:password@addr[?db]
22-
func (d driver) Open(dsn string) (sqldriver.Conn, error) {
23-
lastIndex := strings.LastIndex(dsn, "@")
24-
seps := []string{dsn[:lastIndex], dsn[lastIndex+1:]}
25-
if len(seps) != 2 {
26-
return nil, errors.Errorf("invalid dsn, must user:password@addr[?db]")
29+
type connInfo struct {
30+
standardDSN bool
31+
addr string
32+
user string
33+
password string
34+
db string
35+
params url.Values
36+
}
37+
38+
// ParseDSN takes a DSN string and splits it up into struct containing addr,
39+
// user, password and db.
40+
// It returns an error if unable to parse.
41+
// The struct also contains a boolean indicating if the DSN is in legacy or
42+
// standard form.
43+
//
44+
// Legacy form uses a `?` is used as the path separator: user:password@addr[?db]
45+
// Standard form uses a `/`: user:password@addr/db?param=value
46+
//
47+
// Optional parameters are supported in the standard DSN form
48+
func parseDSN(dsn string) (connInfo, error) {
49+
var matchErr error
50+
ci := connInfo{}
51+
52+
// If a "/" occurs after "@" and then no more "@" or "/" occur after that
53+
ci.standardDSN, matchErr = regexp.MatchString("@[^@]+/[^@/]+", dsn)
54+
if matchErr != nil {
55+
return ci, errors.Errorf("invalid dsn, must be user:password@addr[/db[?param=X]]")
56+
}
57+
58+
// Add a prefix so we can parse with url.Parse
59+
dsn = "mysql://" + dsn
60+
parsedDSN, parseErr := url.Parse(dsn)
61+
if parseErr != nil {
62+
return ci, errors.Errorf("invalid dsn, must be user:password@addr[/db[?param=X]]")
2763
}
2864

29-
var user string
30-
var password string
31-
var addr string
32-
var db string
65+
ci.addr = parsedDSN.Host
66+
ci.user = parsedDSN.User.Username()
67+
// We ignore the second argument as that is just a flag for existence of a password
68+
// If not set we get empty string anyway
69+
ci.password, _ = parsedDSN.User.Password()
3370

34-
if ss := strings.Split(seps[0], ":"); len(ss) == 2 {
35-
user, password = ss[0], ss[1]
36-
} else if len(ss) == 1 {
37-
user = ss[0]
71+
if ci.standardDSN {
72+
ci.db = parsedDSN.Path[1:]
73+
ci.params = parsedDSN.Query()
3874
} else {
39-
return nil, errors.Errorf("invalid dsn, must user:password@addr[?db]")
75+
ci.db = parsedDSN.RawQuery
76+
// This is the equivalent to a "nil" list of parameters
77+
ci.params = url.Values{}
4078
}
4179

42-
if ss := strings.Split(seps[1], "?"); len(ss) == 2 {
43-
addr, db = ss[0], ss[1]
44-
} else if len(ss) == 1 {
45-
addr = ss[0]
46-
} else {
47-
return nil, errors.Errorf("invalid dsn, must user:password@addr[?db]")
80+
return ci, nil
81+
}
82+
83+
// Open takes a supplied DSN string and opens a connection
84+
// See ParseDSN for more information on the form of the DSN
85+
func (d driver) Open(dsn string) (sqldriver.Conn, error) {
86+
var c *client.Conn
87+
88+
ci, err := parseDSN(dsn)
89+
90+
if err != nil {
91+
return nil, err
4892
}
4993

50-
c, err := client.Connect(addr, user, password, db)
94+
if ci.standardDSN {
95+
if ci.params["ssl"] != nil {
96+
tlsConfigName := ci.params.Get("ssl")
97+
switch tlsConfigName {
98+
case "true":
99+
// This actually does insecureSkipVerify
100+
// But not even sure if it makes sense to handle false? According to
101+
// client_test.go it doesn't - it'd result in an error
102+
c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db, func(c *client.Conn) { c.UseSSL(true) })
103+
case "custom":
104+
// I was too concerned about mimicking what go-sql-driver/mysql does which will
105+
// allow any name for a custom tls profile and maps the query parameter value to
106+
// that TLSConfig variable... there is no need to be that clever.
107+
// Instead of doing that, let's store required custom TLSConfigs in a map that
108+
// uses the DSN address as the key
109+
c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db, func(c *client.Conn) { c.SetTLSConfig(customTLSConfigMap[ci.addr]) })
110+
default:
111+
return nil, errors.Errorf("Supported options are ssl=true or ssl=custom")
112+
}
113+
} else {
114+
c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db)
115+
}
116+
} else {
117+
// No more processing here. Let's only support url parameters with the newer style DSN
118+
c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db)
119+
}
51120
if err != nil {
52121
return nil, err
53122
}
@@ -229,3 +298,30 @@ func (r *rows) Next(dest []sqldriver.Value) error {
229298
func init() {
230299
sql.Register("mysql", driver{})
231300
}
301+
302+
// SetCustomTLSConfig sets a custom TLSConfig for the address (host:port) of the supplied DSN.
303+
// It requires a full import of the driver (not by side-effects only).
304+
// Example of supplying a custom CA, no client cert, no key, validating the
305+
// certificate, and supplying a serverName for the validation:
306+
//
307+
// driver.SetCustomTLSConfig(CaPem, make([]byte, 0), make([]byte, 0), false, "my.domain.name")
308+
//
309+
func SetCustomTLSConfig(dsn string, caPem []byte, certPem []byte, keyPem []byte, insecureSkipVerify bool, serverName string) error {
310+
// Extract addr from dsn
311+
parsed, err := url.Parse(dsn)
312+
if err != nil {
313+
return errors.Errorf("Unable to parse DSN. Need to extract address to use as key for storing custom TLS config")
314+
}
315+
addr := parsed.Host
316+
317+
// I thought about using serverName instead of addr below, but decided against that as
318+
// having multiple CA certs for one hostname is likely when you have services running on
319+
// different ports.
320+
321+
customTLSMutex.Lock()
322+
// Basic pass-through function so we can just import the driver
323+
customTLSConfigMap[addr] = client.NewClientTLSConfig(caPem, certPem, keyPem, insecureSkipVerify, serverName)
324+
customTLSMutex.Unlock()
325+
326+
return nil
327+
}

driver/driver_test.go

+27
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package driver
33
import (
44
"flag"
55
"fmt"
6+
"net/url"
7+
"reflect"
68
"testing"
79

810
"github.com/jmoiron/sqlx"
@@ -78,3 +80,28 @@ func (s *testDriverSuite) TestTransaction(c *C) {
7880
err = tx.Commit()
7981
c.Assert(err, IsNil)
8082
}
83+
84+
func TestParseDSN(t *testing.T) {
85+
// List of DSNs to test and expected results
86+
// Use different numbered domains to more readily see what has failed - since we
87+
// test in a loop we get the same line number on error
88+
testDSNs := map[string]connInfo{
89+
"user:password@localhost?db": connInfo{standardDSN: false, addr: "localhost", user: "user", password: "password", db: "db", params: url.Values{}},
90+
"[email protected]?db": connInfo{standardDSN: false, addr: "1.domain.com", user: "user", password: "", db: "db", params: url.Values{}},
91+
"user:[email protected]/db": connInfo{standardDSN: true, addr: "2.domain.com", user: "user", password: "password", db: "db", params: url.Values{}},
92+
"user:[email protected]/db?ssl=true": connInfo{standardDSN: true, addr: "3.domain.com", user: "user", password: "password", db: "db", params: url.Values{"ssl": []string{"true"}}},
93+
"user:[email protected]/db?ssl=custom": connInfo{standardDSN: true, addr: "4.domain.com", user: "user", password: "password", db: "db", params: url.Values{"ssl": []string{"custom"}}},
94+
"user:[email protected]/db?unused=param": connInfo{standardDSN: true, addr: "5.domain.com", user: "user", password: "password", db: "db", params: url.Values{"unused": []string{"param"}}},
95+
}
96+
97+
for supplied, expected := range testDSNs {
98+
actual, err := parseDSN(supplied)
99+
if err != nil {
100+
t.Errorf("TestParseDSN failed. Got error: %s", err)
101+
}
102+
// Compare that with expected
103+
if !reflect.DeepEqual(actual, expected) {
104+
t.Errorf("TestParseDSN failed.\nExpected:\n%#v\nGot:\n%#v", expected, actual)
105+
}
106+
}
107+
}

0 commit comments

Comments
 (0)