Skip to content

Commit 33d6e68

Browse files
author
KJ Tsanaktsidis
committed
Allow registering custom CredentialProviders for per-conn passwords
When using a temporary credential system for MySQL, for example IAM database authenticaiton on AWS or the Database secret backend for Hashicorp Vault, it may not be the case that the same username and password be used for opening every connection in a *sql.DB. This PR adds funcionality whereby the caller can, instead of specifying cfg.User and cfg.Passwd (in the DSN as user:pass@...), specify a CredentialProvider= arguemnt which refers to a callback registered with RegisterCredentialProvider. When a new connection is to be opened, if the CredentialProvider callback is specified, that is called to obtain a username/password pair rather than using the values from the DSN.
1 parent b66d043 commit 33d6e68

File tree

9 files changed

+247
-101
lines changed

9 files changed

+247
-101
lines changed

AUTHORS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,4 @@ Multiplay Ltd.
103103
Percona LLC
104104
Pivotal Inc.
105105
Stripe Inc.
106+
Zendesk Inc.

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,16 @@ SELECT u.id FROM users as u
209209

210210
will return `u.id` instead of just `id` if `columnsWithAlias=true`.
211211

212+
#### `credentialProvider`
213+
214+
```
215+
Type: string
216+
Valid Values: <name>
217+
Default: ""
218+
```
219+
220+
If set, this must refer to a credential provider name registerd with `RegisterCredentialProvider`. When this is set, the username and password in the DSN will be ignored; instead, each time a conneciton is to be opened, the named credential provider function will be called to obtain a username/password to connect with. This is useful when using, for example, IAM database auth in Amazon AWS, where "passwords" are actually temporary tokens that expire.
221+
212222
##### `interpolateParams`
213223

214224
```

auth.go

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,16 @@ import (
1515
"crypto/sha256"
1616
"crypto/x509"
1717
"encoding/pem"
18+
"fmt"
1819
"sync"
1920
)
2021

2122
// server pub keys registry
2223
var (
23-
serverPubKeyLock sync.RWMutex
24-
serverPubKeyRegistry map[string]*rsa.PublicKey
24+
serverPubKeyLock sync.RWMutex
25+
serverPubKeyRegistry map[string]*rsa.PublicKey
26+
credentialProviderLock sync.RWMutex
27+
credentialProviderRetistry map[string]CredentialProviderFunc
2528
)
2629

2730
// RegisterServerPubKey registers a server RSA public key which can be used to
@@ -81,6 +84,44 @@ func getServerPubKey(name string) (pubKey *rsa.PublicKey) {
8184
return
8285
}
8386

87+
// CredentialProviderFunc is a function which can be used to fetch a username/password
88+
// pair for use when opening a new MySQL connection. The first return value is the username
89+
// and the second the password.
90+
type CredentialProviderFunc func() (string, string, error)
91+
92+
// RegisterCredentialProvider registers a function to be called on every connection open to
93+
// get the username and password to call
94+
func RegisterCredentialProvider(name string, providerFunc CredentialProviderFunc) {
95+
credentialProviderLock.Lock()
96+
if credentialProviderRetistry == nil {
97+
credentialProviderRetistry = make(map[string]CredentialProviderFunc)
98+
}
99+
credentialProviderRetistry[name] = providerFunc
100+
credentialProviderLock.Unlock()
101+
}
102+
103+
// DeregisterCredentialProvider removes a function registered with RegisterCredentialProvider
104+
func DeregisterCredentialProvider(name string) {
105+
credentialProviderLock.Lock()
106+
if credentialProviderRetistry != nil {
107+
delete(credentialProviderRetistry, name)
108+
}
109+
credentialProviderLock.Unlock()
110+
}
111+
112+
func getCredentialsFromConfig(cfg *Config) (string, string, error) {
113+
if cfg.CredentialProvider != "" {
114+
credentialProviderLock.RLock()
115+
defer credentialProviderLock.RUnlock()
116+
cpFunc, ok := credentialProviderRetistry[cfg.CredentialProvider]
117+
if !ok {
118+
return "", "", fmt.Errorf("credential provider %s not registered", cfg.CredentialProvider)
119+
}
120+
return cpFunc()
121+
}
122+
return cfg.User, cfg.Passwd, nil
123+
}
124+
84125
// Hash password using pre 4.1 (old password) method
85126
// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c
86127
type myRnd struct {
@@ -237,10 +278,10 @@ func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) erro
237278
return mc.writeAuthSwitchPacket(enc)
238279
}
239280

240-
func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
281+
func (mc *mysqlConn) auth(authData []byte, plugin string, password string) ([]byte, error) {
241282
switch plugin {
242283
case "caching_sha2_password":
243-
authResp := scrambleSHA256Password(authData, mc.cfg.Passwd)
284+
authResp := scrambleSHA256Password(authData, password)
244285
return authResp, nil
245286

246287
case "mysql_old_password":
@@ -250,7 +291,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
250291
// Note: there are edge cases where this should work but doesn't;
251292
// this is currently "wontfix":
252293
// https://github.com/go-sql-driver/mysql/issues/184
253-
authResp := append(scrambleOldPassword(authData[:8], mc.cfg.Passwd), 0)
294+
authResp := append(scrambleOldPassword(authData[:8], password), 0)
254295
return authResp, nil
255296

256297
case "mysql_clear_password":
@@ -259,24 +300,24 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
259300
}
260301
// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
261302
// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
262-
return append([]byte(mc.cfg.Passwd), 0), nil
303+
return append([]byte(password), 0), nil
263304

264305
case "mysql_native_password":
265306
if !mc.cfg.AllowNativePasswords {
266307
return nil, ErrNativePassword
267308
}
268309
// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
269310
// Native password authentication only need and will need 20-byte challenge.
270-
authResp := scramblePassword(authData[:20], mc.cfg.Passwd)
311+
authResp := scramblePassword(authData[:20], password)
271312
return authResp, nil
272313

273314
case "sha256_password":
274-
if len(mc.cfg.Passwd) == 0 {
315+
if len(password) == 0 {
275316
return []byte{0}, nil
276317
}
277318
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
278319
// write cleartext auth packet
279-
return append([]byte(mc.cfg.Passwd), 0), nil
320+
return append([]byte(password), 0), nil
280321
}
281322

282323
pubKey := mc.cfg.pubKey
@@ -286,7 +327,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
286327
}
287328

288329
// encrypted password
289-
enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey)
330+
enc, err := encryptPassword(password, authData, pubKey)
290331
return enc, err
291332

292333
default:
@@ -295,7 +336,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
295336
}
296337
}
297338

298-
func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
339+
func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string, password string) error {
299340
// Read Result Packet
300341
authData, newPlugin, err := mc.readAuthResult()
301342
if err != nil {
@@ -315,7 +356,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
315356

316357
plugin = newPlugin
317358

318-
authResp, err := mc.auth(authData, plugin)
359+
authResp, err := mc.auth(authData, plugin, password)
319360
if err != nil {
320361
return err
321362
}
@@ -352,7 +393,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
352393
case cachingSha2PasswordPerformFullAuthentication:
353394
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
354395
// write cleartext auth packet
355-
err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0))
396+
err = mc.writeAuthSwitchPacket(append([]byte(password), 0))
356397
if err != nil {
357398
return err
358399
}

0 commit comments

Comments
 (0)