Skip to content

Allow TLS connections in the driver #673

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
May 5, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions client/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,28 @@ func NewClientTLSConfig(caPem, certPem, keyPem []byte, insecureSkipVerify bool,
panic("failed to add ca PEM")
}

cert, err := tls.X509KeyPair(certPem, keyPem)
if err != nil {
panic(err)
}
var config *tls.Config

config := &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: pool,
InsecureSkipVerify: insecureSkipVerify,
ServerName: serverName,
// Allow cert and key to be optional
// Send through `make([]byte, 0)` for "nil"
if string(certPem) != "" && string(keyPem) != "" {
cert, err := tls.X509KeyPair(certPem, keyPem)
if err != nil {
panic(err)
}
config = &tls.Config{
RootCAs: pool,
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: insecureSkipVerify,
ServerName: serverName,
}
} else {
config = &tls.Config{
RootCAs: pool,
InsecureSkipVerify: insecureSkipVerify,
ServerName: serverName,
}
}

return config
}
75 changes: 70 additions & 5 deletions driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
package driver

import (
"crypto/tls"
"database/sql"
sqldriver "database/sql/driver"
"fmt"
"io"
"net/url"
"strings"

"github.com/go-mysql-org/go-mysql/client"
Expand All @@ -15,6 +17,9 @@ import (
"github.com/siddontang/go/hack"
)

// Map of dsn address (makes more sense than full dsn?) to tls Config
var customTLSConfigMap = make(map[string]*tls.Config)

type driver struct {
}

Expand All @@ -23,31 +28,74 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) {
lastIndex := strings.LastIndex(dsn, "@")
seps := []string{dsn[:lastIndex], dsn[lastIndex+1:]}
if len(seps) != 2 {
return nil, errors.Errorf("invalid dsn, must user:password@addr[?db]")
return nil, errors.Errorf("invalid dsn, must user:password@addr[[?db[&param=X]]")
}

var user string
var password string
var addr string
var db string
var err error
var c *client.Conn

if ss := strings.Split(seps[0], ":"); len(ss) == 2 {
user, password = ss[0], ss[1]
} else if len(ss) == 1 {
user = ss[0]
} else {
return nil, errors.Errorf("invalid dsn, must user:password@addr[?db]")
return nil, errors.Errorf("invalid dsn, must user:password@addr[[?db[&param=X]]")
}

params := make(map[string]string)
if ss := strings.Split(seps[1], "?"); len(ss) == 2 {
addr, db = ss[0], ss[1]
// If the dsn used a `/` for the path separator this would be easier to parse
// with `url.Parse` and we could use `.Path` to get the db and then use
// `Query()` to get the parameters and values.
// But for consistency with the current way of doing things...
addr = ss[0]
dbAndParams := ss[1]
if ss := strings.Split(dbAndParams, "&"); len(ss) == 1 {
db = ss[0]
} else {
// We have to assume the first is the db
// Then need to handle possible multiple parameters / query strings
for i, p := range ss {
if i == 0 {
db = p
} else {
// Build key value pairs
kv := strings.Split(p, "=")
params[kv[0]] = kv[1]
}
}
}
} else if len(ss) == 1 {
addr = ss[0]
} else {
return nil, errors.Errorf("invalid dsn, must user:password@addr[?db]")
return nil, errors.Errorf("invalid dsn, must user:password@addr[[?db[&param=X]]")
}

c, err := client.Connect(addr, user, password, db)
tlsConfigName, tls := params["ssl"]
if tls {
switch tlsConfigName {
case "true":
// This actually does insecureSkipVerify
// But not even sure if it makes sense to handle false? According to
// client_test.go it doesn't - it'd result in an error
c, err = client.Connect(addr, user, password, db, func(c *client.Conn) { c.UseSSL(true) })
case "custom":
// I was too concerned about mimicking what go-sql-driver/mysql does which will
// allow any name for a custom tls profile and maps the query parameter value to
// that TLSConfig variable... there is no need to be that clever.
// Instead of doing that, let's store required custom TLSConfigs in a map that
// uses the DSN address as the key
c, err = client.Connect(addr, user, password, db, func(c *client.Conn) { c.SetTLSConfig(customTLSConfigMap[addr]) })
default:
return nil, errors.Errorf("Supported options are ssl=true or ssl=custom")
}
} else {
c, err = client.Connect(addr, user, password, db)
}
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -229,3 +277,20 @@ func (r *rows) Next(dest []sqldriver.Value) error {
func init() {
sql.Register("mysql", driver{})
}

func SetCustomTLSConfig(dsn string, caPem []byte, certPem []byte, keyPem []byte, insecureSkipVerify bool, serverName string) {
// Extract addr from dsn
// We can hopefully extend the use of url.Parse if we switch the DSN style
parsed, err := url.Parse(dsn)
if err != nil {
errors.Errorf("Unable to parse DSN. Need to extract address to use as key for storing custom TLS config")
}
addr := parsed.Host

// I thought about using serverName instead of addr below, but decided against that as
// having multiple CA certs for one hostname is likely when you have services running on
// different ports.

// Basic pass-through function so we can just import the driver
customTLSConfigMap[addr] = client.NewClientTLSConfig(caPem, certPem, keyPem, insecureSkipVerify, serverName)
}