Skip to content

Commit ba73da2

Browse files
authored
Merge pull request #589 from skoef/errAccessDenied
Improved access denied error messages
2 parents 735aad5 + 6913157 commit ba73da2

File tree

4 files changed

+31
-13
lines changed

4 files changed

+31
-13
lines changed

server/auth.go

+20-7
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ import (
1313
"github.com/pingcap/errors"
1414
)
1515

16-
var ErrAccessDenied = errors.New("access denied")
16+
var (
17+
ErrAccessDenied = errors.New("access denied")
18+
ErrAccessDeniedNoPassword = fmt.Errorf("%w without password", ErrAccessDenied)
19+
)
1720

1821
func (c *Conn) compareAuthData(authPluginName string, clientAuthData []byte) error {
1922
switch authPluginName {
@@ -62,6 +65,14 @@ func (c *Conn) acquirePassword() error {
6265
return nil
6366
}
6467

68+
func errAccessDenied(password string) error {
69+
if password == "" {
70+
return ErrAccessDeniedNoPassword
71+
}
72+
73+
return ErrAccessDenied
74+
}
75+
6576
func scrambleValidation(cached, nonce, scramble []byte) bool {
6677
// SHA256(SHA256(SHA256(STORED_PASSWORD)), NONCE)
6778
crypt := sha256.New()
@@ -83,10 +94,10 @@ func scrambleValidation(cached, nonce, scramble []byte) bool {
8394
}
8495

8596
func (c *Conn) compareNativePasswordAuthData(clientAuthData []byte, password string) error {
86-
if bytes.Equal(CalcPassword(c.salt, []byte(c.password)), clientAuthData) {
97+
if bytes.Equal(CalcPassword(c.salt, []byte(password)), clientAuthData) {
8798
return nil
8899
}
89-
return ErrAccessDenied
100+
return errAccessDenied(password)
90101
}
91102

92103
func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, password string) error {
@@ -109,7 +120,7 @@ func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, password str
109120
if bytes.Equal(clientAuthData, []byte(password)) {
110121
return nil
111122
}
112-
return ErrAccessDenied
123+
return errAccessDenied(password)
113124
} else {
114125
// client should send encrypted password
115126
// decrypt
@@ -126,7 +137,7 @@ func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, password str
126137
if bytes.Equal(plain, dbytes) {
127138
return nil
128139
}
129-
return ErrAccessDenied
140+
return errAccessDenied(password)
130141
}
131142
}
132143

@@ -153,7 +164,8 @@ func (c *Conn) compareCacheSha2PasswordAuthData(clientAuthData []byte) error {
153164
// 'fast' auth: write "More data" packet (first byte == 0x01) with the second byte = 0x03
154165
return c.writeAuthMoreDataFastAuth()
155166
}
156-
return ErrAccessDenied
167+
168+
return errAccessDenied(c.password)
157169
}
158170
// other type of credential provider, we use the cache
159171
cached, ok := c.serverConf.cacheShaPassword.Load(fmt.Sprintf("%s@%s", c.user, c.Conn.LocalAddr()))
@@ -163,7 +175,8 @@ func (c *Conn) compareCacheSha2PasswordAuthData(clientAuthData []byte) error {
163175
// 'fast' auth: write "More data" packet (first byte == 0x01) with the second byte = 0x03
164176
return c.writeAuthMoreDataFastAuth()
165177
}
166-
return ErrAccessDenied
178+
179+
return errAccessDenied(c.password)
167180
}
168181
// cache miss, do full auth
169182
if err := c.writeAuthMoreDataFullAuth(); err != nil {

server/auth_switch_response.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func (c *Conn) handleAuthSwitchResponse() error {
2525
return err
2626
}
2727
if !bytes.Equal(CalcPassword(c.salt, []byte(c.password)), authData) {
28-
return ErrAccessDenied
28+
return errAccessDenied(c.password)
2929
}
3030
return nil
3131

@@ -82,7 +82,7 @@ func (c *Conn) handleCachingSha2PasswordFullAuth(authData []byte) error {
8282
if bytes.Equal(authData, []byte(c.password)) {
8383
return nil
8484
}
85-
return ErrAccessDenied
85+
return errAccessDenied(c.password)
8686
} else {
8787
// client either request for the public key or send the encrypted password
8888
if len(authData) == 1 && authData[0] == 0x02 {
@@ -111,7 +111,7 @@ func (c *Conn) handleCachingSha2PasswordFullAuth(authData []byte) error {
111111
if bytes.Equal(plain, dbytes) {
112112
return nil
113113
}
114-
return ErrAccessDenied
114+
return errAccessDenied(c.password)
115115
}
116116
}
117117

server/conn.go

+7-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package server
22

33
import (
4+
"errors"
45
"net"
56
"sync/atomic"
67

@@ -105,8 +106,12 @@ func (c *Conn) handshake() error {
105106
}
106107

107108
if err := c.readHandshakeResponse(); err != nil {
108-
if err == ErrAccessDenied {
109-
err = NewDefaultError(ER_ACCESS_DENIED_ERROR, c.user, c.LocalAddr().String(), "Yes")
109+
if errors.Is(err, ErrAccessDenied) {
110+
usingPasswd := ER_YES
111+
if errors.Is(err, ErrAccessDeniedNoPassword) {
112+
usingPasswd = ER_NO
113+
}
114+
err = NewDefaultError(ER_ACCESS_DENIED_ERROR, c.user, c.RemoteAddr().String(), usingPasswd)
110115
}
111116
_ = c.writeError(err)
112117
return err

server/handshake_resp.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ func (c *Conn) readAuthData(data []byte, pos int) (auth []byte, authLen int, new
148148
}
149149
if isNULL {
150150
// no auth length and no auth data, just \NUL, considered invalid auth data, and reject connection as MySQL does
151-
return nil, 0, 0, NewDefaultError(ER_ACCESS_DENIED_ERROR, c.LocalAddr().String(), c.user, "Yes")
151+
return nil, 0, 0, NewDefaultError(ER_ACCESS_DENIED_ERROR, c.RemoteAddr().String(), c.user, ER_NO)
152152
}
153153
auth = authData
154154
authLen = readBytes

0 commit comments

Comments
 (0)