diff --git a/README.md b/README.md index 8cd6ac38c..b2f5d1739 100644 --- a/README.md +++ b/README.md @@ -113,7 +113,7 @@ Possible Parameters are: * `parseTime`: `parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string` * `strict`: Enable strict mode. MySQL warnings are treated as errors. * `timeout`: **Driver** side connection timeout. The value must be a string of decimal numbers, each with optional fraction and a unit suffix ( *"ms"*, *"s"*, *"m"*, *"h"* ), such as *"30s"*, *"0.5m"* or *"1m30s"*. To set a server side timeout, use the parameter [`wait_timeout`](http://dev.mysql.com/doc/refman/5.6/en/server-system-variables.html#sysvar_wait_timeout). - * `tls`: `true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side) + * `tls`: `true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side). Use a custom value registered with [`mysql.RegisterTLSConfig`](http://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig). All other parameters are interpreted as system variables: * `autocommit`: *"SET autocommit=`value`"* diff --git a/driver_test.go b/driver_test.go index 7be65a476..246e83ee5 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1,6 +1,7 @@ package mysql import ( + "crypto/tls" "database/sql" "fmt" "io" @@ -840,7 +841,7 @@ func TestStrict(t *testing.T) { } func TestTLS(t *testing.T) { - runTests(t, dsn+"&tls=skip-verify", func(dbt *DBTest) { + tlsTest := func(dbt *DBTest) { if err := dbt.db.Ping(); err != nil { if err == errNoTLS { dbt.Skip("Server does not support TLS") @@ -861,7 +862,15 @@ func TestTLS(t *testing.T) { dbt.Fatal("No Cipher") } } + } + + runTests(t, dsn+"&tls=skip-verify", tlsTest) + + // Verify that registering / using a custom cfg works + RegisterTLSConfig("custom-skip-verify", &tls.Config{ + InsecureSkipVerify: true, }) + runTests(t, dsn+"&tls=custom-skip-verify", tlsTest) } // Special cases diff --git a/utils.go b/utils.go index 097ecd0aa..e40fcb2f1 100644 --- a/utils.go +++ b/utils.go @@ -77,6 +77,46 @@ func (nt NullTime) Value() (driver.Value, error) { return nt.Time, nil } +var tlsConfigMap map[string]*tls.Config + +// Registers a custom tls.Config to be used with sql.Open. +// Use the key as a value in the DSN where tls=value. +// +// rootCertPool := x509.NewCertPool() +// pem, err := ioutil.ReadFile("/path/ca-cert.pem") +// if err != nil { +// log.Fatal(err) +// } +// if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { +// log.Fatal("Failed to append PEM.") +// } +// clientCert := make([]tls.Certificate, 0, 1) +// certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem") +// if err != nil { +// log.Fatal(err) +// } +// clientCert = append(clientCert, certs) +// mysql.RegisterTLSConfig("custom", &tls.Config{ +// RootCAs: rootCertPool, +// Certificates: clientCert, +// }) +// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom") +// +func RegisterTLSConfig(key string, config *tls.Config) { + if tlsConfigMap == nil { + tlsConfigMap = make(map[string]*tls.Config) + } + tlsConfigMap[key] = config +} + +// Removes tls.Config associated with key. +func DeregisterTLSConfig(key string) { + if tlsConfigMap == nil { + return + } + delete(tlsConfigMap, key) +} + // Logger var ( errLog *log.Logger @@ -152,6 +192,9 @@ func parseDSN(dsn string) (cfg *config, err error) { cfg.tls = &tls.Config{} } else if strings.ToLower(value) == "skip-verify" { cfg.tls = &tls.Config{InsecureSkipVerify: true} + // TODO: Check for Boolean false + } else if tlsConfig, ok := tlsConfigMap[value]; ok { + cfg.tls = tlsConfig } default: