Skip to content

Commit f557730

Browse files
nakagamijulienschmidt
authored andcommitted
Support caching_sha2_password (#794)
1 parent 3287d94 commit f557730

File tree

7 files changed

+152
-18
lines changed

7 files changed

+152
-18
lines changed

AUTHORS

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ Egor Smolyakov <egorsmkv at gmail.com>
2929
Evan Shaw <evan at vendhq.com>
3030
Frederick Mayle <frederickmayle at gmail.com>
3131
Gustavo Kristic <gkristic at gmail.com>
32+
Hajime Nakagami <nakagami at gmail.com>
3233
Hanno Braun <mail at hannobraun.com>
3334
Henri Yandell <flamefew at gmail.com>
3435
Hirotaka Yamamoto <ymmt2005 at gmail.com>

const.go

+6
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,9 @@ const (
164164
statusInTransReadonly
165165
statusSessionStateChanged
166166
)
167+
168+
const (
169+
cachingSha2PasswordRequestPublicKey = 2
170+
cachingSha2PasswordFastAuthSuccess = 3
171+
cachingSha2PasswordPerformFullAuthentication = 4
172+
)

driver.go

+24-4
Original file line numberDiff line numberDiff line change
@@ -107,20 +107,20 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
107107
mc.writeTimeout = mc.cfg.WriteTimeout
108108

109109
// Reading Handshake Initialization Packet
110-
cipher, err := mc.readInitPacket()
110+
cipher, pluginName, err := mc.readInitPacket()
111111
if err != nil {
112112
mc.cleanup()
113113
return nil, err
114114
}
115115

116116
// Send Client Authentication Packet
117-
if err = mc.writeAuthPacket(cipher); err != nil {
117+
if err = mc.writeAuthPacket(cipher, pluginName); err != nil {
118118
mc.cleanup()
119119
return nil, err
120120
}
121121

122122
// Handle response to auth packet, switch methods if possible
123-
if err = handleAuthResult(mc, cipher); err != nil {
123+
if err = handleAuthResult(mc, cipher, pluginName); err != nil {
124124
// Authentication failed and MySQL has already closed the connection
125125
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
126126
// Do not send COM_QUIT, just cleanup and return the error.
@@ -153,7 +153,27 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
153153
return mc, nil
154154
}
155155

156-
func handleAuthResult(mc *mysqlConn, oldCipher []byte) error {
156+
func handleAuthResult(mc *mysqlConn, oldCipher []byte, pluginName string) error {
157+
158+
// handle caching_sha2_password
159+
if pluginName == "caching_sha2_password" {
160+
auth, err := mc.readCachingSha2PasswordAuthResult()
161+
if err != nil {
162+
return err
163+
}
164+
if auth == cachingSha2PasswordPerformFullAuthentication {
165+
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
166+
if err = mc.writeClearAuthPacket(); err != nil {
167+
return err
168+
}
169+
} else {
170+
if err = mc.writePublicKeyAuthPacket(oldCipher); err != nil {
171+
return err
172+
}
173+
}
174+
}
175+
}
176+
157177
// Read Result Packet
158178
cipher, err := mc.readResultOK()
159179
if err == nil {

driver_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -1842,7 +1842,7 @@ func TestSQLInjection(t *testing.T) {
18421842

18431843
dsns := []string{
18441844
dsn,
1845-
dsn + "&sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'",
1845+
dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'",
18461846
}
18471847
for _, testdsn := range dsns {
18481848
runTests(t, testdsn, createTest("1 OR 1=1"))
@@ -1872,7 +1872,7 @@ func TestInsertRetrieveEscapedData(t *testing.T) {
18721872

18731873
dsns := []string{
18741874
dsn,
1875-
dsn + "&sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'",
1875+
dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'",
18761876
}
18771877
for _, testdsn := range dsns {
18781878
runTests(t, testdsn, testData)

packets.go

+72-12
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,14 @@ package mysql
1010

1111
import (
1212
"bytes"
13+
"crypto/rand"
14+
"crypto/rsa"
15+
"crypto/sha1"
1316
"crypto/tls"
17+
"crypto/x509"
1418
"database/sql/driver"
1519
"encoding/binary"
20+
"encoding/pem"
1621
"errors"
1722
"fmt"
1823
"io"
@@ -154,24 +159,24 @@ func (mc *mysqlConn) writePacket(data []byte) error {
154159

155160
// Handshake Initialization Packet
156161
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
157-
func (mc *mysqlConn) readInitPacket() ([]byte, error) {
162+
func (mc *mysqlConn) readInitPacket() ([]byte, string, error) {
158163
data, err := mc.readPacket()
159164
if err != nil {
160165
// for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
161166
// in connection initialization we don't risk retrying non-idempotent actions.
162167
if err == ErrInvalidConn {
163-
return nil, driver.ErrBadConn
168+
return nil, "", driver.ErrBadConn
164169
}
165-
return nil, err
170+
return nil, "", err
166171
}
167172

168173
if data[0] == iERR {
169-
return nil, mc.handleErrorPacket(data)
174+
return nil, "", mc.handleErrorPacket(data)
170175
}
171176

172177
// protocol version [1 byte]
173178
if data[0] < minProtocolVersion {
174-
return nil, fmt.Errorf(
179+
return nil, "", fmt.Errorf(
175180
"unsupported protocol version %d. Version %d or higher is required",
176181
data[0],
177182
minProtocolVersion,
@@ -191,13 +196,14 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
191196
// capability flags (lower 2 bytes) [2 bytes]
192197
mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
193198
if mc.flags&clientProtocol41 == 0 {
194-
return nil, ErrOldProtocol
199+
return nil, "", ErrOldProtocol
195200
}
196201
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
197-
return nil, ErrNoTLS
202+
return nil, "", ErrNoTLS
198203
}
199204
pos += 2
200205

206+
pluginName := ""
201207
if len(data) > pos {
202208
// character set [1 byte]
203209
// status flags [2 bytes]
@@ -219,6 +225,8 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
219225
// The official Python library uses the fixed length 12
220226
// which seems to work but technically could have a hidden bug.
221227
cipher = append(cipher, data[pos:pos+12]...)
228+
pos += 13
229+
pluginName = string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)])
222230

223231
// TODO: Verify string termination
224232
// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
@@ -232,18 +240,22 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
232240
// make a memory safe copy of the cipher slice
233241
var b [20]byte
234242
copy(b[:], cipher)
235-
return b[:], nil
243+
return b[:], pluginName, nil
236244
}
237245

238246
// make a memory safe copy of the cipher slice
239247
var b [8]byte
240248
copy(b[:], cipher)
241-
return b[:], nil
249+
return b[:], pluginName, nil
242250
}
243251

244252
// Client Authentication Packet
245253
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
246-
func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
254+
func (mc *mysqlConn) writeAuthPacket(cipher []byte, pluginName string) error {
255+
if pluginName != "mysql_native_password" && pluginName != "caching_sha2_password" {
256+
return fmt.Errorf("unknown authentication plugin name '%s'", pluginName)
257+
}
258+
247259
// Adjust client flags based on server support
248260
clientFlags := clientProtocol41 |
249261
clientSecureConn |
@@ -268,7 +280,13 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
268280
}
269281

270282
// User Password
271-
scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd))
283+
var scrambleBuff []byte
284+
switch pluginName {
285+
case "mysql_native_password":
286+
scrambleBuff = scramblePassword(cipher, []byte(mc.cfg.Passwd))
287+
case "caching_sha2_password":
288+
scrambleBuff = scrambleCachingSha2Password(cipher, []byte(mc.cfg.Passwd))
289+
}
272290

273291
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1
274292

@@ -350,7 +368,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
350368
}
351369

352370
// Assume native client during response
353-
pos += copy(data[pos:], "mysql_native_password")
371+
pos += copy(data[pos:], pluginName)
354372
data[pos] = 0x00
355373

356374
// Send Auth packet
@@ -422,6 +440,38 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error {
422440
return mc.writePacket(data)
423441
}
424442

443+
// Caching sha2 authentication. Public key request and send encrypted password
444+
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
445+
func (mc *mysqlConn) writePublicKeyAuthPacket(cipher []byte) error {
446+
// request public key
447+
data := mc.buf.takeSmallBuffer(4 + 1)
448+
data[4] = cachingSha2PasswordRequestPublicKey
449+
mc.writePacket(data)
450+
451+
data, err := mc.readPacket()
452+
if err != nil {
453+
return err
454+
}
455+
456+
block, _ := pem.Decode(data[1:])
457+
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
458+
if err != nil {
459+
return err
460+
}
461+
462+
plain := make([]byte, len(mc.cfg.Passwd)+1)
463+
copy(plain, mc.cfg.Passwd)
464+
for i := range plain {
465+
j := i % len(cipher)
466+
plain[i] ^= cipher[j]
467+
}
468+
sha1 := sha1.New()
469+
enc, _ := rsa.EncryptOAEP(sha1, rand.Reader, pub.(*rsa.PublicKey), plain, nil)
470+
data = mc.buf.takeSmallBuffer(4 + len(enc))
471+
copy(data[4:], enc)
472+
return mc.writePacket(data)
473+
}
474+
425475
/******************************************************************************
426476
* Command Packets *
427477
******************************************************************************/
@@ -535,6 +585,16 @@ func (mc *mysqlConn) readResultOK() ([]byte, error) {
535585
return nil, err
536586
}
537587

588+
func (mc *mysqlConn) readCachingSha2PasswordAuthResult() (int, error) {
589+
data, err := mc.readPacket()
590+
if err == nil {
591+
if data[0] != 1 {
592+
return 0, ErrMalformPkt
593+
}
594+
}
595+
return int(data[1]), err
596+
}
597+
538598
// Result Set Header Packet
539599
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
540600
func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {

utils.go

+29
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ package mysql
1010

1111
import (
1212
"crypto/sha1"
13+
"crypto/sha256"
1314
"crypto/tls"
1415
"database/sql/driver"
1516
"encoding/binary"
@@ -211,6 +212,34 @@ func scrambleOldPassword(scramble, password []byte) []byte {
211212
return out[:]
212213
}
213214

215+
// Encrypt password using 8.0 default method
216+
func scrambleCachingSha2Password(scramble, password []byte) []byte {
217+
if len(password) == 0 {
218+
return nil
219+
}
220+
221+
// XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble))
222+
223+
crypt := sha256.New()
224+
crypt.Write(password)
225+
message1 := crypt.Sum(nil)
226+
227+
crypt.Reset()
228+
crypt.Write(message1)
229+
message1Hash := crypt.Sum(nil)
230+
231+
crypt.Reset()
232+
crypt.Write(message1Hash)
233+
crypt.Write(scramble)
234+
message2 := crypt.Sum(nil)
235+
236+
for i := range message1 {
237+
message1[i] ^= message2[i]
238+
}
239+
240+
return message1
241+
}
242+
214243
/******************************************************************************
215244
* Time related utils *
216245
******************************************************************************/

utils_test.go

+18
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,24 @@ func TestOldPass(t *testing.T) {
112112
}
113113
}
114114

115+
func TestCachingSha2Pass(t *testing.T) {
116+
scramble := []byte{10, 47, 74, 111, 75, 73, 34, 48, 88, 76, 114, 74, 37, 13, 3, 80, 82, 2, 23, 21}
117+
vectors := []struct {
118+
pass string
119+
out string
120+
}{
121+
{"secret", "f490e76f66d9d86665ce54d98c78d0acfe2fb0b08b423da807144873d30b312c"},
122+
{"secret2", "abc3934a012cf342e876071c8ee202de51785b430258a7a0138bc79c4d800bc6"},
123+
}
124+
for _, tuple := range vectors {
125+
ours := scrambleCachingSha2Password(scramble, []byte(tuple.pass))
126+
if tuple.out != fmt.Sprintf("%x", ours) {
127+
t.Errorf("Failed caching sha2 password %q", tuple.pass)
128+
}
129+
}
130+
131+
}
132+
115133
func TestFormatBinaryDateTime(t *testing.T) {
116134
rawDate := [11]byte{}
117135
binary.LittleEndian.PutUint16(rawDate[:2], 1978) // years

0 commit comments

Comments
 (0)