@@ -10,9 +10,14 @@ package mysql
10
10
11
11
import (
12
12
"bytes"
13
+ "crypto/rand"
14
+ "crypto/rsa"
15
+ "crypto/sha1"
13
16
"crypto/tls"
17
+ "crypto/x509"
14
18
"database/sql/driver"
15
19
"encoding/binary"
20
+ "encoding/pem"
16
21
"errors"
17
22
"fmt"
18
23
"io"
@@ -154,24 +159,24 @@ func (mc *mysqlConn) writePacket(data []byte) error {
154
159
155
160
// Handshake Initialization Packet
156
161
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
157
- func (mc * mysqlConn ) readInitPacket () ([]byte , error ) {
162
+ func (mc * mysqlConn ) readInitPacket () ([]byte , string , error ) {
158
163
data , err := mc .readPacket ()
159
164
if err != nil {
160
165
// for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
161
166
// in connection initialization we don't risk retrying non-idempotent actions.
162
167
if err == ErrInvalidConn {
163
- return nil , driver .ErrBadConn
168
+ return nil , "" , driver .ErrBadConn
164
169
}
165
- return nil , err
170
+ return nil , "" , err
166
171
}
167
172
168
173
if data [0 ] == iERR {
169
- return nil , mc .handleErrorPacket (data )
174
+ return nil , "" , mc .handleErrorPacket (data )
170
175
}
171
176
172
177
// protocol version [1 byte]
173
178
if data [0 ] < minProtocolVersion {
174
- return nil , fmt .Errorf (
179
+ return nil , "" , fmt .Errorf (
175
180
"unsupported protocol version %d. Version %d or higher is required" ,
176
181
data [0 ],
177
182
minProtocolVersion ,
@@ -191,13 +196,14 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
191
196
// capability flags (lower 2 bytes) [2 bytes]
192
197
mc .flags = clientFlag (binary .LittleEndian .Uint16 (data [pos : pos + 2 ]))
193
198
if mc .flags & clientProtocol41 == 0 {
194
- return nil , ErrOldProtocol
199
+ return nil , "" , ErrOldProtocol
195
200
}
196
201
if mc .flags & clientSSL == 0 && mc .cfg .tls != nil {
197
- return nil , ErrNoTLS
202
+ return nil , "" , ErrNoTLS
198
203
}
199
204
pos += 2
200
205
206
+ pluginName := ""
201
207
if len (data ) > pos {
202
208
// character set [1 byte]
203
209
// status flags [2 bytes]
@@ -219,6 +225,8 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
219
225
// The official Python library uses the fixed length 12
220
226
// which seems to work but technically could have a hidden bug.
221
227
cipher = append (cipher , data [pos :pos + 12 ]... )
228
+ pos += 13
229
+ pluginName = string (data [pos : pos + bytes .IndexByte (data [pos :], 0x00 )])
222
230
223
231
// TODO: Verify string termination
224
232
// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
@@ -232,18 +240,22 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
232
240
// make a memory safe copy of the cipher slice
233
241
var b [20 ]byte
234
242
copy (b [:], cipher )
235
- return b [:], nil
243
+ return b [:], pluginName , nil
236
244
}
237
245
238
246
// make a memory safe copy of the cipher slice
239
247
var b [8 ]byte
240
248
copy (b [:], cipher )
241
- return b [:], nil
249
+ return b [:], pluginName , nil
242
250
}
243
251
244
252
// Client Authentication Packet
245
253
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
246
- func (mc * mysqlConn ) writeAuthPacket (cipher []byte ) error {
254
+ func (mc * mysqlConn ) writeAuthPacket (cipher []byte , pluginName string ) error {
255
+ if pluginName != "mysql_native_password" && pluginName != "caching_sha2_password" {
256
+ return fmt .Errorf ("unknown authentication plugin name '%s'" , pluginName )
257
+ }
258
+
247
259
// Adjust client flags based on server support
248
260
clientFlags := clientProtocol41 |
249
261
clientSecureConn |
@@ -268,7 +280,13 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
268
280
}
269
281
270
282
// User Password
271
- scrambleBuff := scramblePassword (cipher , []byte (mc .cfg .Passwd ))
283
+ var scrambleBuff []byte
284
+ switch pluginName {
285
+ case "mysql_native_password" :
286
+ scrambleBuff = scramblePassword (cipher , []byte (mc .cfg .Passwd ))
287
+ case "caching_sha2_password" :
288
+ scrambleBuff = scrambleCachingSha2Password (cipher , []byte (mc .cfg .Passwd ))
289
+ }
272
290
273
291
pktLen := 4 + 4 + 1 + 23 + len (mc .cfg .User ) + 1 + 1 + len (scrambleBuff ) + 21 + 1
274
292
@@ -350,7 +368,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
350
368
}
351
369
352
370
// Assume native client during response
353
- pos += copy (data [pos :], "mysql_native_password" )
371
+ pos += copy (data [pos :], pluginName )
354
372
data [pos ] = 0x00
355
373
356
374
// Send Auth packet
@@ -422,6 +440,38 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error {
422
440
return mc .writePacket (data )
423
441
}
424
442
443
+ // Caching sha2 authentication. Public key request and send encrypted password
444
+ // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
445
+ func (mc * mysqlConn ) writePublicKeyAuthPacket (cipher []byte ) error {
446
+ // request public key
447
+ data := mc .buf .takeSmallBuffer (4 + 1 )
448
+ data [4 ] = cachingSha2PasswordRequestPublicKey
449
+ mc .writePacket (data )
450
+
451
+ data , err := mc .readPacket ()
452
+ if err != nil {
453
+ return err
454
+ }
455
+
456
+ block , _ := pem .Decode (data [1 :])
457
+ pub , err := x509 .ParsePKIXPublicKey (block .Bytes )
458
+ if err != nil {
459
+ return err
460
+ }
461
+
462
+ plain := make ([]byte , len (mc .cfg .Passwd )+ 1 )
463
+ copy (plain , mc .cfg .Passwd )
464
+ for i := range plain {
465
+ j := i % len (cipher )
466
+ plain [i ] ^= cipher [j ]
467
+ }
468
+ sha1 := sha1 .New ()
469
+ enc , _ := rsa .EncryptOAEP (sha1 , rand .Reader , pub .(* rsa.PublicKey ), plain , nil )
470
+ data = mc .buf .takeSmallBuffer (4 + len (enc ))
471
+ copy (data [4 :], enc )
472
+ return mc .writePacket (data )
473
+ }
474
+
425
475
/******************************************************************************
426
476
* Command Packets *
427
477
******************************************************************************/
@@ -535,6 +585,16 @@ func (mc *mysqlConn) readResultOK() ([]byte, error) {
535
585
return nil , err
536
586
}
537
587
588
+ func (mc * mysqlConn ) readCachingSha2PasswordAuthResult () (int , error ) {
589
+ data , err := mc .readPacket ()
590
+ if err == nil {
591
+ if data [0 ] != 1 {
592
+ return 0 , ErrMalformPkt
593
+ }
594
+ }
595
+ return int (data [1 ]), err
596
+ }
597
+
538
598
// Result Set Header Packet
539
599
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
540
600
func (mc * mysqlConn ) readResultSetHeaderPacket () (int , error ) {
0 commit comments