Skip to content

Commit 88a5a80

Browse files
committed
Add NewCredential func
encapsulate the creation of credentials to simplify downstream implementation publicise the password and plugin so downstream implementations are able to access them
1 parent 4e6aea2 commit 88a5a80

File tree

4 files changed

+50
-40
lines changed

4 files changed

+50
-40
lines changed

server/auth.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ var (
1919
)
2020

2121
func (c *Conn) compareAuthData(authPluginName string, clientAuthData []byte) error {
22-
if authPluginName != c.credential.authPluginName {
23-
err := c.writeAuthSwitchRequest(c.credential.authPluginName)
22+
if authPluginName != c.credential.AuthPluginName {
23+
err := c.writeAuthSwitchRequest(c.credential.AuthPluginName)
2424
if err != nil {
2525
return err
2626
}
@@ -60,7 +60,7 @@ func (c *Conn) compareAuthData(authPluginName string, clientAuthData []byte) err
6060
}
6161

6262
func (c *Conn) acquirePassword() error {
63-
if c.credential.password != "" {
63+
if c.credential.Password != "" {
6464
return nil
6565
}
6666
credential, found, err := c.credentialProvider.GetCredential(c.user)
@@ -75,7 +75,7 @@ func (c *Conn) acquirePassword() error {
7575
}
7676

7777
func errAccessDenied(credential Credential) error {
78-
if credential.password == "" {
78+
if credential.Password == "" {
7979
return ErrAccessDeniedNoPassword
8080
}
8181

@@ -103,7 +103,7 @@ func scrambleValidation(cached, nonce, scramble []byte) bool {
103103
}
104104

105105
func (c *Conn) compareNativePasswordAuthData(clientAuthData []byte, credential Credential) error {
106-
password, err := mysql.DecodePasswordHex(c.credential.password)
106+
password, err := mysql.DecodePasswordHex(c.credential.Password)
107107
if err != nil {
108108
return errAccessDenied(credential)
109109
}
@@ -116,7 +116,7 @@ func (c *Conn) compareNativePasswordAuthData(clientAuthData []byte, credential C
116116
func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, credential Credential) error {
117117
// Empty passwords are not hashed, but sent as empty string
118118
if len(clientAuthData) == 0 {
119-
if credential.password == "" {
119+
if credential.Password == "" {
120120
return nil
121121
}
122122
return ErrAccessDenied
@@ -142,7 +142,7 @@ func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, credential C
142142
clientAuthData = clientAuthData[:l-1]
143143
}
144144
}
145-
check, err := mysql.Check256HashingPassword([]byte(credential.password), string(clientAuthData))
145+
check, err := mysql.Check256HashingPassword([]byte(credential.Password), string(clientAuthData))
146146
if err != nil {
147147
return err
148148
}
@@ -155,7 +155,7 @@ func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, credential C
155155
func (c *Conn) compareCacheSha2PasswordAuthData(clientAuthData []byte) error {
156156
// Empty passwords are not hashed, but sent as empty string
157157
if len(clientAuthData) == 0 {
158-
if c.credential.password == "" {
158+
if c.credential.Password == "" {
159159
return nil
160160
}
161161
return ErrAccessDenied

server/auth_switch_response.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ func (c *Conn) handleCachingSha2PasswordFullAuth(authData []byte) error {
7171
}
7272

7373
func (c *Conn) checkSha2CacheCredentials(clientAuthData []byte, credential Credential) error {
74-
match, err := auth.CheckHashingPassword([]byte(credential.password), string(clientAuthData), mysql.AUTH_CACHING_SHA2_PASSWORD)
74+
match, err := auth.CheckHashingPassword([]byte(credential.Password), string(clientAuthData), mysql.AUTH_CACHING_SHA2_PASSWORD)
7575
if match && err == nil {
7676
return nil
7777
}

server/credential_provider.go

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,38 @@ func NewInMemoryProvider(defaultAuthMethod ...string) *InMemoryProvider {
3434
}
3535

3636
type Credential struct {
37-
password string
38-
authPluginName string
37+
Password string
38+
AuthPluginName string
39+
}
40+
41+
func NewCredential(password string, authPluginName string) (Credential, error) {
42+
c := Credential{
43+
AuthPluginName: authPluginName,
44+
}
45+
46+
if password == "" {
47+
c.Password = ""
48+
return c, nil
49+
}
50+
51+
switch c.AuthPluginName {
52+
case mysql.AUTH_NATIVE_PASSWORD:
53+
c.Password = mysql.EncodePasswordHex(mysql.NativePasswordHash([]byte(password)))
54+
55+
case mysql.AUTH_CACHING_SHA2_PASSWORD:
56+
c.Password = auth.NewHashPassword(password, mysql.AUTH_CACHING_SHA2_PASSWORD)
57+
58+
case mysql.AUTH_SHA256_PASSWORD:
59+
hash, err := mysql.NewSha256PasswordHash(password)
60+
if err != nil {
61+
return c, err
62+
}
63+
c.Password = hash
64+
65+
default:
66+
return c, errors.Errorf("unknown authentication plugin name '%s'", c.AuthPluginName)
67+
}
68+
return c, nil
3969
}
4070

4171
// implements an in memory credential provider
@@ -61,37 +91,17 @@ func (m *InMemoryProvider) GetCredential(username string) (credential Credential
6191
return c, true, nil
6292
}
6393

64-
func (m *InMemoryProvider) AddUser(username, password string, authPluginName ...string) error {
65-
c := Credential{
66-
authPluginName: m.defaultAuthMethod,
67-
}
68-
if len(authPluginName) > 0 {
69-
c.authPluginName = authPluginName[0]
94+
func (m *InMemoryProvider) AddUser(username, password string, optionalAuthPluginName ...string) error {
95+
authPluginName := m.defaultAuthMethod
96+
if len(optionalAuthPluginName) > 0 {
97+
authPluginName = optionalAuthPluginName[0]
7098
}
7199

72-
if password == "" {
73-
c.password = ""
74-
m.userPool.Store(username, c)
75-
return nil
100+
c, err := NewCredential(password, authPluginName)
101+
if err != nil {
102+
return err
76103
}
77104

78-
switch c.authPluginName {
79-
case mysql.AUTH_NATIVE_PASSWORD:
80-
c.password = mysql.EncodePasswordHex(mysql.NativePasswordHash([]byte(password)))
81-
82-
case mysql.AUTH_CACHING_SHA2_PASSWORD:
83-
c.password = auth.NewHashPassword(password, mysql.AUTH_CACHING_SHA2_PASSWORD)
84-
85-
case mysql.AUTH_SHA256_PASSWORD:
86-
hash, err := mysql.NewSha256PasswordHash(password)
87-
if err != nil {
88-
return err
89-
}
90-
c.password = hash
91-
92-
default:
93-
return errors.Errorf("unknown authentication plugin name '%s'", c.authPluginName)
94-
}
95105
m.userPool.Store(username, c)
96106
return nil
97107
}

server/handshake_resp.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,8 @@ func (c *Conn) handleAuthMatch() (bool, error) {
204204
return false, err
205205
}
206206

207-
if c.authPluginName != c.credential.authPluginName {
208-
if err := c.writeAuthSwitchRequest(c.credential.authPluginName); err != nil {
207+
if c.authPluginName != c.credential.AuthPluginName {
208+
if err := c.writeAuthSwitchRequest(c.credential.AuthPluginName); err != nil {
209209
return false, err
210210
}
211211
// handle AuthSwitchResponse

0 commit comments

Comments
 (0)