diff --git a/AUTHORS b/AUTHORS index 14e8398fd..2f3a8d68f 100644 --- a/AUTHORS +++ b/AUTHORS @@ -29,6 +29,7 @@ Egor Smolyakov Evan Shaw Frederick Mayle Gustavo Kristic +Hajime Nakagami Hanno Braun Henri Yandell Hirotaka Yamamoto diff --git a/const.go b/const.go index 4a19ca523..1503f9e62 100644 --- a/const.go +++ b/const.go @@ -164,3 +164,9 @@ const ( statusInTransReadonly statusSessionStateChanged ) + +const ( + cachingSha2PasswordRequestPublicKey = 2 + cachingSha2PasswordFastAuthSuccess = 3 + cachingSha2PasswordPerformFullAuthentication = 4 +) diff --git a/driver.go b/driver.go index 27cf5ad4e..f77b917a8 100644 --- a/driver.go +++ b/driver.go @@ -107,20 +107,20 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc.writeTimeout = mc.cfg.WriteTimeout // Reading Handshake Initialization Packet - cipher, err := mc.readInitPacket() + cipher, pluginName, err := mc.readInitPacket() if err != nil { mc.cleanup() return nil, err } // Send Client Authentication Packet - if err = mc.writeAuthPacket(cipher); err != nil { + if err = mc.writeAuthPacket(cipher, pluginName); err != nil { mc.cleanup() return nil, err } // Handle response to auth packet, switch methods if possible - if err = handleAuthResult(mc, cipher); err != nil { + if err = handleAuthResult(mc, cipher, pluginName); err != nil { // Authentication failed and MySQL has already closed the connection // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). // Do not send COM_QUIT, just cleanup and return the error. @@ -153,7 +153,27 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { return mc, nil } -func handleAuthResult(mc *mysqlConn, oldCipher []byte) error { +func handleAuthResult(mc *mysqlConn, oldCipher []byte, pluginName string) error { + + // handle caching_sha2_password + if pluginName == "caching_sha2_password" { + auth, err := mc.readCachingSha2PasswordAuthResult() + if err != nil { + return err + } + if auth == cachingSha2PasswordPerformFullAuthentication { + if mc.cfg.tls != nil || mc.cfg.Net == "unix" { + if err = mc.writeClearAuthPacket(); err != nil { + return err + } + } else { + if err = mc.writePublicKeyAuthPacket(oldCipher); err != nil { + return err + } + } + } + } + // Read Result Packet cipher, err := mc.readResultOK() if err == nil { diff --git a/driver_test.go b/driver_test.go index ad93a37c3..8904b6587 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1842,7 +1842,7 @@ func TestSQLInjection(t *testing.T) { dsns := []string{ dsn, - dsn + "&sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'", + dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'", } for _, testdsn := range dsns { runTests(t, testdsn, createTest("1 OR 1=1")) @@ -1872,7 +1872,7 @@ func TestInsertRetrieveEscapedData(t *testing.T) { dsns := []string{ dsn, - dsn + "&sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'", + dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'", } for _, testdsn := range dsns { runTests(t, testdsn, testData) diff --git a/packets.go b/packets.go index afc3fcc46..6775d2860 100644 --- a/packets.go +++ b/packets.go @@ -10,9 +10,14 @@ package mysql import ( "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/sha1" "crypto/tls" + "crypto/x509" "database/sql/driver" "encoding/binary" + "encoding/pem" "errors" "fmt" "io" @@ -154,24 +159,24 @@ func (mc *mysqlConn) writePacket(data []byte) error { // Handshake Initialization Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake -func (mc *mysqlConn) readInitPacket() ([]byte, error) { +func (mc *mysqlConn) readInitPacket() ([]byte, string, error) { data, err := mc.readPacket() if err != nil { // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since // in connection initialization we don't risk retrying non-idempotent actions. if err == ErrInvalidConn { - return nil, driver.ErrBadConn + return nil, "", driver.ErrBadConn } - return nil, err + return nil, "", err } if data[0] == iERR { - return nil, mc.handleErrorPacket(data) + return nil, "", mc.handleErrorPacket(data) } // protocol version [1 byte] if data[0] < minProtocolVersion { - return nil, fmt.Errorf( + return nil, "", fmt.Errorf( "unsupported protocol version %d. Version %d or higher is required", data[0], minProtocolVersion, @@ -191,13 +196,14 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { // capability flags (lower 2 bytes) [2 bytes] mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) if mc.flags&clientProtocol41 == 0 { - return nil, ErrOldProtocol + return nil, "", ErrOldProtocol } if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { - return nil, ErrNoTLS + return nil, "", ErrNoTLS } pos += 2 + pluginName := "" if len(data) > pos { // character set [1 byte] // status flags [2 bytes] @@ -219,6 +225,8 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { // The official Python library uses the fixed length 12 // which seems to work but technically could have a hidden bug. cipher = append(cipher, data[pos:pos+12]...) + pos += 13 + pluginName = string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)]) // TODO: Verify string termination // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2) @@ -232,18 +240,22 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { // make a memory safe copy of the cipher slice var b [20]byte copy(b[:], cipher) - return b[:], nil + return b[:], pluginName, nil } // make a memory safe copy of the cipher slice var b [8]byte copy(b[:], cipher) - return b[:], nil + return b[:], pluginName, nil } // Client Authentication Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { +func (mc *mysqlConn) writeAuthPacket(cipher []byte, pluginName string) error { + if pluginName != "mysql_native_password" && pluginName != "caching_sha2_password" { + return fmt.Errorf("unknown authentication plugin name '%s'", pluginName) + } + // Adjust client flags based on server support clientFlags := clientProtocol41 | clientSecureConn | @@ -268,7 +280,13 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { } // User Password - scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) + var scrambleBuff []byte + switch pluginName { + case "mysql_native_password": + scrambleBuff = scramblePassword(cipher, []byte(mc.cfg.Passwd)) + case "caching_sha2_password": + scrambleBuff = scrambleCachingSha2Password(cipher, []byte(mc.cfg.Passwd)) + } pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1 @@ -350,7 +368,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { } // Assume native client during response - pos += copy(data[pos:], "mysql_native_password") + pos += copy(data[pos:], pluginName) data[pos] = 0x00 // Send Auth packet @@ -422,6 +440,38 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { return mc.writePacket(data) } +// Caching sha2 authentication. Public key request and send encrypted password +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse +func (mc *mysqlConn) writePublicKeyAuthPacket(cipher []byte) error { + // request public key + data := mc.buf.takeSmallBuffer(4 + 1) + data[4] = cachingSha2PasswordRequestPublicKey + mc.writePacket(data) + + data, err := mc.readPacket() + if err != nil { + return err + } + + block, _ := pem.Decode(data[1:]) + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return err + } + + plain := make([]byte, len(mc.cfg.Passwd)+1) + copy(plain, mc.cfg.Passwd) + for i := range plain { + j := i % len(cipher) + plain[i] ^= cipher[j] + } + sha1 := sha1.New() + enc, _ := rsa.EncryptOAEP(sha1, rand.Reader, pub.(*rsa.PublicKey), plain, nil) + data = mc.buf.takeSmallBuffer(4 + len(enc)) + copy(data[4:], enc) + return mc.writePacket(data) +} + /****************************************************************************** * Command Packets * ******************************************************************************/ @@ -535,6 +585,16 @@ func (mc *mysqlConn) readResultOK() ([]byte, error) { return nil, err } +func (mc *mysqlConn) readCachingSha2PasswordAuthResult() (int, error) { + data, err := mc.readPacket() + if err == nil { + if data[0] != 1 { + return 0, ErrMalformPkt + } + } + return int(data[1]), err +} + // Result Set Header Packet // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { diff --git a/utils.go b/utils.go index f986de2ab..9d1530b3d 100644 --- a/utils.go +++ b/utils.go @@ -10,6 +10,7 @@ package mysql import ( "crypto/sha1" + "crypto/sha256" "crypto/tls" "database/sql/driver" "encoding/binary" @@ -211,6 +212,34 @@ func scrambleOldPassword(scramble, password []byte) []byte { return out[:] } +// Encrypt password using 8.0 default method +func scrambleCachingSha2Password(scramble, password []byte) []byte { + if len(password) == 0 { + return nil + } + + // XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble)) + + crypt := sha256.New() + crypt.Write(password) + message1 := crypt.Sum(nil) + + crypt.Reset() + crypt.Write(message1) + message1Hash := crypt.Sum(nil) + + crypt.Reset() + crypt.Write(message1Hash) + crypt.Write(scramble) + message2 := crypt.Sum(nil) + + for i := range message1 { + message1[i] ^= message2[i] + } + + return message1 +} + /****************************************************************************** * Time related utils * ******************************************************************************/ diff --git a/utils_test.go b/utils_test.go index 0041892db..a599c55f3 100644 --- a/utils_test.go +++ b/utils_test.go @@ -112,6 +112,24 @@ func TestOldPass(t *testing.T) { } } +func TestCachingSha2Pass(t *testing.T) { + scramble := []byte{10, 47, 74, 111, 75, 73, 34, 48, 88, 76, 114, 74, 37, 13, 3, 80, 82, 2, 23, 21} + vectors := []struct { + pass string + out string + }{ + {"secret", "f490e76f66d9d86665ce54d98c78d0acfe2fb0b08b423da807144873d30b312c"}, + {"secret2", "abc3934a012cf342e876071c8ee202de51785b430258a7a0138bc79c4d800bc6"}, + } + for _, tuple := range vectors { + ours := scrambleCachingSha2Password(scramble, []byte(tuple.pass)) + if tuple.out != fmt.Sprintf("%x", ours) { + t.Errorf("Failed caching sha2 password %q", tuple.pass) + } + } + +} + func TestFormatBinaryDateTime(t *testing.T) { rawDate := [11]byte{} binary.LittleEndian.PutUint16(rawDate[:2], 1978) // years