Skip to content

Commit 9db718f

Browse files
committed
simplify
1 parent 1907645 commit 9db718f

File tree

5 files changed

+97
-111
lines changed

5 files changed

+97
-111
lines changed

auth_test.go

+14-14
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ func TestAuthFastCachingSHA256PasswordCached(t *testing.T) {
8989
if err != nil {
9090
t.Fatal(err)
9191
}
92-
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
92+
err = mc.writeHandshakeResponsePacket(authResp, plugin)
9393
if err != nil {
9494
t.Fatal(err)
9595
}
@@ -134,7 +134,7 @@ func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) {
134134
if err != nil {
135135
t.Fatal(err)
136136
}
137-
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
137+
err = mc.writeHandshakeResponsePacket(authResp, plugin)
138138
if err != nil {
139139
t.Fatal(err)
140140
}
@@ -176,7 +176,7 @@ func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) {
176176
if err != nil {
177177
t.Fatal(err)
178178
}
179-
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
179+
err = mc.writeHandshakeResponsePacket(authResp, plugin)
180180
if err != nil {
181181
t.Fatal(err)
182182
}
@@ -232,7 +232,7 @@ func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) {
232232
if err != nil {
233233
t.Fatal(err)
234234
}
235-
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
235+
err = mc.writeHandshakeResponsePacket(authResp, plugin)
236236
if err != nil {
237237
t.Fatal(err)
238238
}
@@ -284,7 +284,7 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) {
284284
if err != nil {
285285
t.Fatal(err)
286286
}
287-
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
287+
err = mc.writeHandshakeResponsePacket(authResp, plugin)
288288
if err != nil {
289289
t.Fatal(err)
290290
}
@@ -357,7 +357,7 @@ func TestAuthFastCleartextPassword(t *testing.T) {
357357
if err != nil {
358358
t.Fatal(err)
359359
}
360-
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
360+
err = mc.writeHandshakeResponsePacket(authResp, plugin)
361361
if err != nil {
362362
t.Fatal(err)
363363
}
@@ -400,7 +400,7 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) {
400400
if err != nil {
401401
t.Fatal(err)
402402
}
403-
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
403+
err = mc.writeHandshakeResponsePacket(authResp, plugin)
404404
if err != nil {
405405
t.Fatal(err)
406406
}
@@ -459,7 +459,7 @@ func TestAuthFastNativePassword(t *testing.T) {
459459
if err != nil {
460460
t.Fatal(err)
461461
}
462-
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
462+
err = mc.writeHandshakeResponsePacket(authResp, plugin)
463463
if err != nil {
464464
t.Fatal(err)
465465
}
@@ -502,7 +502,7 @@ func TestAuthFastNativePasswordEmpty(t *testing.T) {
502502
if err != nil {
503503
t.Fatal(err)
504504
}
505-
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
505+
err = mc.writeHandshakeResponsePacket(authResp, plugin)
506506
if err != nil {
507507
t.Fatal(err)
508508
}
@@ -544,7 +544,7 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) {
544544
if err != nil {
545545
t.Fatal(err)
546546
}
547-
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
547+
err = mc.writeHandshakeResponsePacket(authResp, plugin)
548548
if err != nil {
549549
t.Fatal(err)
550550
}
@@ -592,7 +592,7 @@ func TestAuthFastSHA256PasswordRSA(t *testing.T) {
592592
if err != nil {
593593
t.Fatal(err)
594594
}
595-
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
595+
err = mc.writeHandshakeResponsePacket(authResp, plugin)
596596
if err != nil {
597597
t.Fatal(err)
598598
}
@@ -641,7 +641,7 @@ func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) {
641641
if err != nil {
642642
t.Fatal(err)
643643
}
644-
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
644+
err = mc.writeHandshakeResponsePacket(authResp, plugin)
645645
if err != nil {
646646
t.Fatal(err)
647647
}
@@ -678,7 +678,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) {
678678
// unset TLS config to prevent the actual establishment of a TLS wrapper
679679
mc.cfg.TLS = nil
680680

681-
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
681+
err = mc.writeHandshakeResponsePacket(authResp, plugin)
682682
if err != nil {
683683
t.Fatal(err)
684684
}
@@ -1343,7 +1343,7 @@ func TestEd25519Auth(t *testing.T) {
13431343
if err != nil {
13441344
t.Fatal(err)
13451345
}
1346-
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
1346+
err = mc.writeHandshakeResponsePacket(authResp, plugin)
13471347
if err != nil {
13481348
t.Fatal(err)
13491349
}

connection.go

+17-17
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,22 @@ import (
2424
)
2525

2626
type mysqlConn struct {
27-
buf buffer
28-
netConn net.Conn
29-
rawConn net.Conn // underlying connection when netConn is TLS connection.
30-
result mysqlResult // managed by clearResult() and handleOkPacket().
31-
compIO *compIO
32-
cfg *Config
33-
connector *connector
34-
maxAllowedPacket int
35-
maxWriteSize int
36-
clientCapabilities capabilityFlag
37-
clientExtCapabilities extendedCapabilityFlag
38-
status statusFlag
39-
sequence uint8
40-
compressSequence uint8
41-
parseTime bool
42-
compress bool
27+
buf buffer
28+
netConn net.Conn
29+
rawConn net.Conn // underlying connection when netConn is TLS connection.
30+
result mysqlResult // managed by clearResult() and handleOkPacket().
31+
compIO *compIO
32+
cfg *Config
33+
connector *connector
34+
maxAllowedPacket int
35+
maxWriteSize int
36+
capabilities capabilityFlag
37+
extCapabilities extendedCapabilityFlag
38+
status statusFlag
39+
sequence uint8
40+
compressSequence uint8
41+
parseTime bool
42+
compress bool
4343

4444
// for context support (Go 1.8+)
4545
watching bool
@@ -230,7 +230,7 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
230230
}
231231

232232
if columnCount > 0 {
233-
if mc.clientExtCapabilities&clientCacheMetadata != 0 {
233+
if mc.extCapabilities&clientCacheMetadata != 0 {
234234
if stmt.columns, err = mc.readColumns(int(columnCount)); err != nil {
235235
return nil, err
236236
}

connector.go

+5-3
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
131131
mc.buf = newBuffer()
132132

133133
// Reading Handshake Initialization Packet
134-
authData, serverCapabilities, serverExtendedCapabilities, plugin, err := mc.readHandshakePacket()
134+
authData, serverCapabilities, serverExtCapabilities, plugin, err := mc.readHandshakePacket()
135135
if err != nil {
136136
mc.cleanup()
137137
return nil, err
@@ -157,7 +157,8 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
157157
return nil, err
158158
}
159159
}
160-
if err = mc.writeHandshakeResponsePacket(authResp, serverCapabilities, serverExtendedCapabilities, plugin); err != nil {
160+
mc.initCapabilities(serverCapabilities, serverExtCapabilities, mc.cfg)
161+
if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
161162
mc.cleanup()
162163
return nil, err
163164
}
@@ -171,7 +172,8 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
171172
return nil, err
172173
}
173174

174-
if mc.cfg.compress && mc.clientCapabilities&clientCompress > 0 {
175+
// compression is enabled after auth, not right after sending handshake response.
176+
if mc.capabilities&clientCompress > 0 {
175177
mc.compress = true
176178
mc.compIO = newCompIO(mc)
177179
}

const.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ const (
4242
iERR byte = 0xff
4343
)
4444

45-
// https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags
45+
// https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__capabilities__flags.html
46+
// https://mariadb.com/kb/en/connection/#capabilities
4647
type capabilityFlag uint32
4748

4849
const (

0 commit comments

Comments
 (0)