diff --git a/server/auth.go b/server/auth.go index 0019e8ac2..e2852a385 100644 --- a/server/auth.go +++ b/server/auth.go @@ -13,7 +13,10 @@ import ( "github.com/pingcap/errors" ) -var ErrAccessDenied = errors.New("access denied") +var ( + ErrAccessDenied = errors.New("access denied") + ErrAccessDeniedNoPassword = fmt.Errorf("%w without password", ErrAccessDenied) +) func (c *Conn) compareAuthData(authPluginName string, clientAuthData []byte) error { switch authPluginName { @@ -62,6 +65,14 @@ func (c *Conn) acquirePassword() error { return nil } +func errAccessDenied(password string) error { + if password == "" { + return ErrAccessDeniedNoPassword + } + + return ErrAccessDenied +} + func scrambleValidation(cached, nonce, scramble []byte) bool { // SHA256(SHA256(SHA256(STORED_PASSWORD)), NONCE) crypt := sha256.New() @@ -83,10 +94,10 @@ func scrambleValidation(cached, nonce, scramble []byte) bool { } func (c *Conn) compareNativePasswordAuthData(clientAuthData []byte, password string) error { - if bytes.Equal(CalcPassword(c.salt, []byte(c.password)), clientAuthData) { + if bytes.Equal(CalcPassword(c.salt, []byte(password)), clientAuthData) { return nil } - return ErrAccessDenied + return errAccessDenied(password) } func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, password string) error { @@ -109,7 +120,7 @@ func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, password str if bytes.Equal(clientAuthData, []byte(password)) { return nil } - return ErrAccessDenied + return errAccessDenied(password) } else { // client should send encrypted password // decrypt @@ -126,7 +137,7 @@ func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, password str if bytes.Equal(plain, dbytes) { return nil } - return ErrAccessDenied + return errAccessDenied(password) } } @@ -153,7 +164,8 @@ func (c *Conn) compareCacheSha2PasswordAuthData(clientAuthData []byte) error { // 'fast' auth: write "More data" packet (first byte == 0x01) with the second byte = 0x03 return c.writeAuthMoreDataFastAuth() } - return ErrAccessDenied + + return errAccessDenied(c.password) } // other type of credential provider, we use the cache cached, ok := c.serverConf.cacheShaPassword.Load(fmt.Sprintf("%s@%s", c.user, c.Conn.LocalAddr())) @@ -163,7 +175,8 @@ func (c *Conn) compareCacheSha2PasswordAuthData(clientAuthData []byte) error { // 'fast' auth: write "More data" packet (first byte == 0x01) with the second byte = 0x03 return c.writeAuthMoreDataFastAuth() } - return ErrAccessDenied + + return errAccessDenied(c.password) } // cache miss, do full auth if err := c.writeAuthMoreDataFullAuth(); err != nil { diff --git a/server/auth_switch_response.go b/server/auth_switch_response.go index 44fea6abf..209598909 100644 --- a/server/auth_switch_response.go +++ b/server/auth_switch_response.go @@ -25,7 +25,7 @@ func (c *Conn) handleAuthSwitchResponse() error { return err } if !bytes.Equal(CalcPassword(c.salt, []byte(c.password)), authData) { - return ErrAccessDenied + return errAccessDenied(c.password) } return nil @@ -82,7 +82,7 @@ func (c *Conn) handleCachingSha2PasswordFullAuth(authData []byte) error { if bytes.Equal(authData, []byte(c.password)) { return nil } - return ErrAccessDenied + return errAccessDenied(c.password) } else { // client either request for the public key or send the encrypted password if len(authData) == 1 && authData[0] == 0x02 { @@ -111,7 +111,7 @@ func (c *Conn) handleCachingSha2PasswordFullAuth(authData []byte) error { if bytes.Equal(plain, dbytes) { return nil } - return ErrAccessDenied + return errAccessDenied(c.password) } } diff --git a/server/conn.go b/server/conn.go index 3fcbc7d3f..19bdae8ba 100644 --- a/server/conn.go +++ b/server/conn.go @@ -1,6 +1,7 @@ package server import ( + "errors" "net" "sync/atomic" @@ -105,8 +106,12 @@ func (c *Conn) handshake() error { } if err := c.readHandshakeResponse(); err != nil { - if err == ErrAccessDenied { - err = NewDefaultError(ER_ACCESS_DENIED_ERROR, c.user, c.LocalAddr().String(), "Yes") + if errors.Is(err, ErrAccessDenied) { + usingPasswd := ER_YES + if errors.Is(err, ErrAccessDeniedNoPassword) { + usingPasswd = ER_NO + } + err = NewDefaultError(ER_ACCESS_DENIED_ERROR, c.user, c.RemoteAddr().String(), usingPasswd) } _ = c.writeError(err) return err diff --git a/server/handshake_resp.go b/server/handshake_resp.go index 584226747..f2777bb7e 100644 --- a/server/handshake_resp.go +++ b/server/handshake_resp.go @@ -148,7 +148,7 @@ func (c *Conn) readAuthData(data []byte, pos int) (auth []byte, authLen int, new } if isNULL { // no auth length and no auth data, just \NUL, considered invalid auth data, and reject connection as MySQL does - return nil, 0, 0, NewDefaultError(ER_ACCESS_DENIED_ERROR, c.LocalAddr().String(), c.user, "Yes") + return nil, 0, 0, NewDefaultError(ER_ACCESS_DENIED_ERROR, c.RemoteAddr().String(), c.user, ER_NO) } auth = authData authLen = readBytes