diff --git a/AUTHORS b/AUTHORS index ad5989800..c144c46ab 100644 --- a/AUTHORS +++ b/AUTHORS @@ -103,3 +103,4 @@ Multiplay Ltd. Percona LLC Pivotal Inc. Stripe Inc. +Zendesk Inc. diff --git a/README.md b/README.md index 2d15ffda3..5a3489031 100644 --- a/README.md +++ b/README.md @@ -367,6 +367,31 @@ Examples: * [`time_zone=%27Europe%2FParis%27`](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html): `SET time_zone='Europe/Paris'` * [`tx_isolation=%27REPEATABLE-READ%27`](https://dev.mysql.com/doc/refman/5.5/en/server-system-variables.html#sysvar_tx_isolation): `SET tx_isolation='REPEATABLE-READ'` +#### Non-DSN parameters + +Some parameters (those that have types too complex to fit into a string) are not supported as part of a DSN string, but can only be specified by using the Connector interface. To use these parameters, set your database client up like so: + +```go +dbConfig := mysql.Config { + Addr: "localhost:3306", + // ... other parameters ... +} +connector, err := mysql.NewConnector(dbConfig) +if err != nil { + panic(error) +} +db := sql.OpenDB(connector) +``` + +##### `CredentialProvider` + +``` +Type: CredentialProviderFunc +Default: nil +``` + +If set, this must refer to a credential provider function of type `CredentialProviderFunc`. When this is set, the `User` and `Passwd` fields in the config will be ignored; instead, each time a connection is to be opened, the credential provider function will be called to obtain a username/password to connect with. This is useful when using, for example, IAM database auth in Amazon AWS, where "passwords" are actually temporary tokens that expire. + #### Examples ``` diff --git a/auth.go b/auth.go index fec7040d4..37606d657 100644 --- a/auth.go +++ b/auth.go @@ -81,6 +81,11 @@ func getServerPubKey(name string) (pubKey *rsa.PublicKey) { return } +// CredentialProviderFunc is a function which can be used to fetch a username/password +// pair for use when opening a new MySQL connection. The first return value is the username +// and the second the password. +type CredentialProviderFunc func() (user string, password string, error error) + // Hash password using pre 4.1 (old password) method // https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c type myRnd struct { @@ -237,10 +242,10 @@ func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) erro return mc.writeAuthSwitchPacket(enc) } -func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { +func (mc *mysqlConn) auth(authData []byte, plugin string, password string) ([]byte, error) { switch plugin { case "caching_sha2_password": - authResp := scrambleSHA256Password(authData, mc.cfg.Passwd) + authResp := scrambleSHA256Password(authData, password) return authResp, nil case "mysql_old_password": @@ -250,7 +255,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { // Note: there are edge cases where this should work but doesn't; // this is currently "wontfix": // https://github.com/go-sql-driver/mysql/issues/184 - authResp := append(scrambleOldPassword(authData[:8], mc.cfg.Passwd), 0) + authResp := append(scrambleOldPassword(authData[:8], password), 0) return authResp, nil case "mysql_clear_password": @@ -259,7 +264,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { } // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html - return append([]byte(mc.cfg.Passwd), 0), nil + return append([]byte(password), 0), nil case "mysql_native_password": if !mc.cfg.AllowNativePasswords { @@ -267,16 +272,16 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { } // https://dev.mysql.com/doc/internals/en/secure-password-authentication.html // Native password authentication only need and will need 20-byte challenge. - authResp := scramblePassword(authData[:20], mc.cfg.Passwd) + authResp := scramblePassword(authData[:20], password) return authResp, nil case "sha256_password": - if len(mc.cfg.Passwd) == 0 { + if len(password) == 0 { return []byte{0}, nil } if mc.cfg.tls != nil || mc.cfg.Net == "unix" { // write cleartext auth packet - return append([]byte(mc.cfg.Passwd), 0), nil + return append([]byte(password), 0), nil } pubKey := mc.cfg.pubKey @@ -286,7 +291,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { } // encrypted password - enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey) + enc, err := encryptPassword(password, authData, pubKey) return enc, err default: @@ -295,7 +300,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { } } -func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { +func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string, password string) error { // Read Result Packet authData, newPlugin, err := mc.readAuthResult() if err != nil { @@ -315,7 +320,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { plugin = newPlugin - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, password) if err != nil { return err } @@ -352,7 +357,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { case cachingSha2PasswordPerformFullAuthentication: if mc.cfg.tls != nil || mc.cfg.Net == "unix" { // write cleartext auth packet - err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0)) + err = mc.writeAuthSwitchPacket(append([]byte(password), 0)) if err != nil { return err } diff --git a/auth_test.go b/auth_test.go index 1920ef39f..c1d08a454 100644 --- a/auth_test.go +++ b/auth_test.go @@ -85,11 +85,11 @@ func TestAuthFastCachingSHA256PasswordCached(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -115,7 +115,7 @@ func TestAuthFastCachingSHA256PasswordCached(t *testing.T) { conn.maxReads = 1 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } } @@ -130,11 +130,11 @@ func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -157,7 +157,7 @@ func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) { conn.maxReads = 1 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } } @@ -172,11 +172,11 @@ func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -208,7 +208,7 @@ func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) { conn.maxReads = 3 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -228,11 +228,11 @@ func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -261,7 +261,7 @@ func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) { conn.maxReads = 2 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -280,11 +280,11 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -317,7 +317,7 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) { conn.maxReads = 3 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -336,7 +336,7 @@ func TestAuthFastCleartextPasswordNotAllowed(t *testing.T) { plugin := "mysql_clear_password" // Send Client Authentication Packet - _, err := mc.auth(authData, plugin) + _, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != ErrCleartextPassword { t.Errorf("expected ErrCleartextPassword, got %v", err) } @@ -353,11 +353,11 @@ func TestAuthFastCleartextPassword(t *testing.T) { plugin := "mysql_clear_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -380,7 +380,7 @@ func TestAuthFastCleartextPassword(t *testing.T) { conn.maxReads = 1 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } } @@ -396,11 +396,11 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) { plugin := "mysql_clear_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -423,7 +423,7 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) { conn.maxReads = 1 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } } @@ -439,7 +439,7 @@ func TestAuthFastNativePasswordNotAllowed(t *testing.T) { plugin := "mysql_native_password" // Send Client Authentication Packet - _, err := mc.auth(authData, plugin) + _, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != ErrNativePassword { t.Errorf("expected ErrNativePassword, got %v", err) } @@ -455,11 +455,11 @@ func TestAuthFastNativePassword(t *testing.T) { plugin := "mysql_native_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -483,7 +483,7 @@ func TestAuthFastNativePassword(t *testing.T) { conn.maxReads = 1 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } } @@ -498,11 +498,11 @@ func TestAuthFastNativePasswordEmpty(t *testing.T) { plugin := "mysql_native_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -525,7 +525,7 @@ func TestAuthFastNativePasswordEmpty(t *testing.T) { conn.maxReads = 1 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } } @@ -540,11 +540,11 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) { plugin := "sha256_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -569,7 +569,7 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) { conn.maxReads = 2 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -588,11 +588,11 @@ func TestAuthFastSHA256PasswordRSA(t *testing.T) { plugin := "sha256_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -617,7 +617,7 @@ func TestAuthFastSHA256PasswordRSA(t *testing.T) { conn.maxReads = 2 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -637,11 +637,11 @@ func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) { plugin := "sha256_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -651,7 +651,7 @@ func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) { conn.maxReads = 1 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } } @@ -670,7 +670,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { plugin := "sha256_password" // send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } @@ -678,7 +678,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { // unset TLS config to prevent the actual establishment of a TLS wrapper mc.cfg.tls = nil - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -699,7 +699,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { conn.maxReads = 1 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -728,7 +728,7 @@ func TestAuthSwitchCachingSHA256PasswordCached(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -761,7 +761,7 @@ func TestAuthSwitchCachingSHA256PasswordEmpty(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -797,7 +797,7 @@ func TestAuthSwitchCachingSHA256PasswordFullRSA(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -842,7 +842,7 @@ func TestAuthSwitchCachingSHA256PasswordFullRSAWithKey(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -885,7 +885,7 @@ func TestAuthSwitchCachingSHA256PasswordFullSecure(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -912,7 +912,7 @@ func TestAuthSwitchCleartextPasswordNotAllowed(t *testing.T) { authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - err := mc.handleAuthResult(authData, plugin) + err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd) if err != ErrCleartextPassword { t.Errorf("expected ErrCleartextPassword, got %v", err) } @@ -935,7 +935,7 @@ func TestAuthSwitchCleartextPassword(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -962,7 +962,7 @@ func TestAuthSwitchCleartextPasswordEmpty(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -984,7 +984,7 @@ func TestAuthSwitchNativePasswordNotAllowed(t *testing.T) { authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, 31} plugin := "caching_sha2_password" - err := mc.handleAuthResult(authData, plugin) + err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd) if err != ErrNativePassword { t.Errorf("expected ErrNativePassword, got %v", err) } @@ -1009,7 +1009,7 @@ func TestAuthSwitchNativePassword(t *testing.T) { 48, 31, 89, 39, 55, 31} plugin := "caching_sha2_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1039,7 +1039,7 @@ func TestAuthSwitchNativePasswordEmpty(t *testing.T) { 48, 31, 89, 39, 55, 31} plugin := "caching_sha2_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1059,7 +1059,7 @@ func TestAuthSwitchOldPasswordNotAllowed(t *testing.T) { authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107} plugin := "mysql_native_password" - err := mc.handleAuthResult(authData, plugin) + err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd) if err != ErrOldPassword { t.Errorf("expected ErrOldPassword, got %v", err) } @@ -1075,7 +1075,7 @@ func TestOldAuthSwitchNotAllowed(t *testing.T) { authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107} plugin := "mysql_native_password" - err := mc.handleAuthResult(authData, plugin) + err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd) if err != ErrOldPassword { t.Errorf("expected ErrOldPassword, got %v", err) } @@ -1099,7 +1099,7 @@ func TestAuthSwitchOldPassword(t *testing.T) { 84, 96, 101, 92, 123, 121, 107} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1126,7 +1126,7 @@ func TestOldAuthSwitch(t *testing.T) { 84, 96, 101, 92, 123, 121, 107} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1153,7 +1153,7 @@ func TestAuthSwitchOldPasswordEmpty(t *testing.T) { 84, 96, 101, 92, 123, 121, 107} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1180,7 +1180,7 @@ func TestOldAuthSwitchPasswordEmpty(t *testing.T) { 84, 96, 101, 92, 123, 121, 107} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1209,7 +1209,7 @@ func TestAuthSwitchSHA256PasswordEmpty(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1244,7 +1244,7 @@ func TestAuthSwitchSHA256PasswordRSA(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1280,7 +1280,7 @@ func TestAuthSwitchSHA256PasswordRSAWithKey(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1316,7 +1316,7 @@ func TestAuthSwitchSHA256PasswordSecure(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } diff --git a/connector.go b/connector.go index d567b4e4f..9a5088157 100644 --- a/connector.go +++ b/connector.go @@ -88,25 +88,32 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { plugin = defaultAuthPlugin } + user, password, err := c.cfg.getCredentials() + if err != nil { + mc.cleanup() + return nil, err + } + // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, password) if err != nil { // try the default auth plugin, if using the requested plugin failed errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) plugin = defaultAuthPlugin - authResp, err = mc.auth(authData, plugin) + authResp, err = mc.auth(authData, plugin, password) if err != nil { mc.cleanup() return nil, err } } - if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil { + + if err = mc.writeHandshakeResponsePacket(authResp, plugin, user); err != nil { mc.cleanup() return nil, err } // Handle response to auth packet, switch methods if possible - if err = mc.handleAuthResult(authData, plugin); err != nil { + if err = mc.handleAuthResult(authData, plugin, password); err != nil { // Authentication failed and MySQL has already closed the connection // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). // Do not send COM_QUIT, just cleanup and return the error. diff --git a/driver_test.go b/driver_test.go index ace083dfc..5b5fc7c2e 100644 --- a/driver_test.go +++ b/driver_test.go @@ -125,36 +125,47 @@ func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBT } func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { + cfg, err := ParseDSN(dsn) + if err != nil { + t.Fatalf("error formatting DSN") + } + runTestsWithConfig(t, cfg, tests...) +} + +func runTestsWithConfig(t *testing.T, cfg *Config, tests ...func(dbt *DBTest)) { if !available { t.Skipf("MySQL server not running on %s", netAddr) } - db, err := sql.Open("mysql", dsn) + connector, err := NewConnector(cfg) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } + db := sql.OpenDB(connector) defer db.Close() db.Exec("DROP TABLE IF EXISTS test") - dsn2 := dsn + "&interpolateParams=true" + cfg2 := cfg.Clone() + cfg2.InterpolateParams = true var db2 *sql.DB - if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation { - db2, err = sql.Open("mysql", dsn2) - if err != nil { - t.Fatalf("error connecting: %s", err.Error()) - } + connector2, err := NewConnector(cfg2) + if err != errInvalidDSNUnsafeCollation { + db2 = sql.OpenDB(connector2) defer db2.Close() + } else if err != nil { + t.Fatalf("error connecting: %s", err.Error()) } - dsn3 := dsn + "&multiStatements=true" + cfg3 := cfg.Clone() + cfg3.MultiStatements = true var db3 *sql.DB - if _, err := ParseDSN(dsn3); err != errInvalidDSNUnsafeCollation { - db3, err = sql.Open("mysql", dsn3) - if err != nil { - t.Fatalf("error connecting: %s", err.Error()) - } + connector3, err := NewConnector(cfg2) + if err != errInvalidDSNUnsafeCollation { + db3 = sql.OpenDB(connector3) defer db3.Close() + } else if err != nil { + t.Fatalf("error connecting: %s", err.Error()) } dbt := &DBTest{t, db} @@ -3163,3 +3174,80 @@ func TestConnectorTimeoutsWatchCancel(t *testing.T) { t.Errorf("connection not closed") } } + +func TestCredentialProviderFunc(t *testing.T) { + // Our test provider func should return a valid password, then an invalid one, then a valid one + // to test that it really is having an effect. + shouldFailCreds := false + shouldFailError := false + cfg := &Config{ + Addr: addr, + Net: prot, + DBName: dbname, + Collation: defaultCollation, + AllowNativePasswords: true, + CredentialProvider: func() (string, string, error) { + if shouldFailCreds { + return "fail", "fail", nil + } + if shouldFailError { + return "", "", fmt.Errorf("credential_error") + } + return user, pass, nil + }, + } + runTestsWithConfig(t, cfg, func(dbt *DBTest) { + ctx := context.Background() + c1, err := dbt.db.Conn(ctx) + if err != nil { + dbt.Fatalf("error opening conn: %s", err) + } + defer c1.Close() + + rows, err := c1.QueryContext(ctx, "SELECT USER()") + if err != nil { + dbt.Fatalf("error running SELECT USER(): %s", err) + } + connUserAndHost := "" + for rows.Next() { + err := rows.Scan(&connUserAndHost) + if err != nil { + dbt.Fatalf("error running query: %s", err) + } + } + parts := strings.Split(connUserAndHost, "@") + connUser := strings.Join(parts[:len(parts)-1], "@") + if connUser != user { + dbt.Errorf("USER() and credentials don't match: %s != %s", connUser, user) + } + + // open one that should fail (wrong creds) + shouldFailCreds = true + _, err = dbt.db.Conn(ctx) + shouldFailCreds = false + if err == nil { + dbt.Errorf("expected second open to fail") + } + + // open one that should fail (with an error) + shouldFailError = true + _, err = dbt.db.Conn(ctx) + shouldFailError = false + if err == nil { + dbt.Errorf("expected third open to fail") + } + if !strings.Contains(err.Error(), "credential_error") { + dbt.Errorf("expected third open to fail with credential_error") + } + + c4, err := dbt.db.Conn(ctx) + if err != nil { + dbt.Fatalf("error opening conn: %s", err) + } + defer c4.Close() + err = c4.PingContext(ctx) + if err != nil { + dbt.Errorf("error running PingContext: %s", err) + } + }) +} diff --git a/dsn.go b/dsn.go index 1d9b4ab0a..00fc8ca9a 100644 --- a/dsn.go +++ b/dsn.go @@ -34,22 +34,23 @@ var ( // If a new Config is created instead of being parsed from a DSN string, // the NewConfig function should be used, which sets default values. type Config struct { - User string // Username - Passwd string // Password (requires User) - Net string // Network type - Addr string // Network address (requires Net) - DBName string // Database name - Params map[string]string // Connection parameters - Collation string // Connection collation - Loc *time.Location // Location for time.Time values - MaxAllowedPacket int // Max packet size allowed - ServerPubKey string // Server public key name - pubKey *rsa.PublicKey // Server public key - TLSConfig string // TLS configuration name - tls *tls.Config // TLS configuration - Timeout time.Duration // Dial timeout - ReadTimeout time.Duration // I/O read timeout - WriteTimeout time.Duration // I/O write timeout + User string // Username + Passwd string // Password (requires User) + CredentialProvider CredentialProviderFunc // Credential provider function + Net string // Network type + Addr string // Network address (requires Net) + DBName string // Database name + Params map[string]string // Connection parameters + Collation string // Connection collation + Loc *time.Location // Location for time.Time values + MaxAllowedPacket int // Max packet size allowed + ServerPubKey string // Server public key name + pubKey *rsa.PublicKey // Server public key + TLSConfig string // TLS configuration name + tls *tls.Config // TLS configuration + Timeout time.Duration // Dial timeout + ReadTimeout time.Duration // I/O read timeout + WriteTimeout time.Duration // I/O write timeout AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE AllowCleartextPasswords bool // Allows the cleartext client side plugin @@ -628,6 +629,13 @@ func parseDSNParams(cfg *Config, params string) (err error) { return } +func (cfg *Config) getCredentials() (user string, password string, err error) { + if cfg.CredentialProvider != nil { + return cfg.CredentialProvider() + } + return cfg.User, cfg.Passwd, nil +} + func ensureHavePort(addr string) string { if _, _, err := net.SplitHostPort(addr); err != nil { return net.JoinHostPort(addr, "3306") diff --git a/dsn_test.go b/dsn_test.go index 50dc2932c..2f5ab658f 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -71,8 +71,7 @@ var testDSNs = []struct { }, { "tcp(de:ad:be:ef::ca:fe)/dbname", &Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, -}, -} +}} func TestDSNParser(t *testing.T) { for i, tst := range testDSNs { diff --git a/packets.go b/packets.go index 30b3352c2..18b0d3731 100644 --- a/packets.go +++ b/packets.go @@ -276,7 +276,7 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro // Client Authentication Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error { +func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string, user string) error { // Adjust client flags based on server support clientFlags := clientProtocol41 | clientSecureConn | @@ -310,7 +310,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string clientFlags |= clientPluginAuthLenEncClientData } - pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 + pktLen := 4 + 4 + 1 + 23 + len(user) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 // To specify a db name if n := len(mc.cfg.DBName); n > 0 { @@ -373,8 +373,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string } // User [null terminated string] - if len(mc.cfg.User) > 0 { - pos += copy(data[pos:], mc.cfg.User) + if len(user) > 0 { + pos += copy(data[pos:], user) } data[pos] = 0x00 pos++