Skip to content

Commit 372d17b

Browse files
committed
change to another impl
Signed-off-by: lance6716 <[email protected]>
1 parent 2a5ac9b commit 372d17b

File tree

5 files changed

+70
-52
lines changed

5 files changed

+70
-52
lines changed

auth.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
275275
}
276276
// unlike caching_sha2_password, sha256_password does not accept
277277
// cleartext password on unix transport.
278-
if mc.cfg.tls != nil {
278+
if mc.cfg.TLS != nil {
279279
// write cleartext auth packet
280280
return append([]byte(mc.cfg.Passwd), 0), nil
281281
}
@@ -351,7 +351,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
351351
}
352352

353353
case cachingSha2PasswordPerformFullAuthentication:
354-
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
354+
if mc.cfg.TLS != nil || mc.cfg.Net == "unix" {
355355
// write cleartext auth packet
356356
err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0))
357357
if err != nil {

auth_test.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) {
291291

292292
// Hack to make the caching_sha2_password plugin believe that the connection
293293
// is secure
294-
mc.cfg.tls = &tls.Config{InsecureSkipVerify: true}
294+
mc.cfg.TLS = &tls.Config{InsecureSkipVerify: true}
295295

296296
// check written auth response
297297
authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
@@ -663,7 +663,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) {
663663

664664
// hack to make the caching_sha2_password plugin believe that the connection
665665
// is secure
666-
mc.cfg.tls = &tls.Config{InsecureSkipVerify: true}
666+
mc.cfg.TLS = &tls.Config{InsecureSkipVerify: true}
667667

668668
authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81,
669669
62, 94, 83, 80, 52, 85}
@@ -676,7 +676,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) {
676676
}
677677

678678
// unset TLS config to prevent the actual establishment of a TLS wrapper
679-
mc.cfg.tls = nil
679+
mc.cfg.TLS = nil
680680

681681
err = mc.writeHandshakeResponsePacket(authResp, plugin)
682682
if err != nil {
@@ -866,7 +866,7 @@ func TestAuthSwitchCachingSHA256PasswordFullSecure(t *testing.T) {
866866

867867
// Hack to make the caching_sha2_password plugin believe that the connection
868868
// is secure
869-
mc.cfg.tls = &tls.Config{InsecureSkipVerify: true}
869+
mc.cfg.TLS = &tls.Config{InsecureSkipVerify: true}
870870

871871
// auth switch request
872872
conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95,
@@ -1299,7 +1299,7 @@ func TestAuthSwitchSHA256PasswordSecure(t *testing.T) {
12991299

13001300
// Hack to make the caching_sha2_password plugin believe that the connection
13011301
// is secure
1302-
mc.cfg.tls = &tls.Config{InsecureSkipVerify: true}
1302+
mc.cfg.TLS = &tls.Config{InsecureSkipVerify: true}
13031303

13041304
// auth switch request
13051305
conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97,

dsn.go

+34-16
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,14 @@ type Config struct {
4646
ServerPubKey string // Server public key name
4747
pubKey *rsa.PublicKey // Server public key
4848
TLSConfig string // TLS configuration name
49-
tls *tls.Config // TLS configuration
49+
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
5050
Timeout time.Duration // Dial timeout
5151
ReadTimeout time.Duration // I/O read timeout
5252
WriteTimeout time.Duration // I/O write timeout
5353

5454
AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
5555
AllowCleartextPasswords bool // Allows the cleartext client side plugin
56+
AllowFallbackToNoTLS bool // Allows fallback to unencrypted connection if server does not support TLS
5657
AllowNativePasswords bool // Allows the native password authentication method
5758
AllowOldPasswords bool // Allows the old insecure password method
5859
CheckConnLiveness bool // Check connections for liveness before using them
@@ -77,8 +78,8 @@ func NewConfig() *Config {
7778

7879
func (cfg *Config) Clone() *Config {
7980
cp := *cfg
80-
if cp.tls != nil {
81-
cp.tls = cfg.tls.Clone()
81+
if cp.TLS != nil {
82+
cp.TLS = cfg.TLS.Clone()
8283
}
8384
if len(cp.Params) > 0 {
8485
cp.Params = make(map[string]string, len(cfg.Params))
@@ -119,24 +120,29 @@ func (cfg *Config) normalize() error {
119120
cfg.Addr = ensureHavePort(cfg.Addr)
120121
}
121122

122-
switch cfg.TLSConfig {
123-
case "false", "":
124-
// don't set anything
125-
case "true":
126-
cfg.tls = &tls.Config{}
127-
case "skip-verify", "preferred":
128-
cfg.tls = &tls.Config{InsecureSkipVerify: true}
129-
default:
130-
cfg.tls = getTLSConfigClone(cfg.TLSConfig)
131-
if cfg.tls == nil {
132-
return errors.New("invalid value / unknown config name: " + cfg.TLSConfig)
123+
if cfg.TLS == nil {
124+
switch cfg.TLSConfig {
125+
case "false", "":
126+
// don't set anything
127+
case "true":
128+
cfg.TLS = &tls.Config{}
129+
case "skip-verify":
130+
cfg.TLS = &tls.Config{InsecureSkipVerify: true}
131+
case "preferred":
132+
cfg.TLS = &tls.Config{InsecureSkipVerify: true}
133+
cfg.AllowFallbackToNoTLS = true
134+
default:
135+
cfg.TLS = getTLSConfigClone(cfg.TLSConfig)
136+
if cfg.TLS == nil {
137+
return errors.New("invalid value / unknown config name: " + cfg.TLSConfig)
138+
}
133139
}
134140
}
135141

136-
if cfg.tls != nil && cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {
142+
if cfg.TLS != nil && cfg.TLS.ServerName == "" && !cfg.TLS.InsecureSkipVerify {
137143
host, _, err := net.SplitHostPort(cfg.Addr)
138144
if err == nil {
139-
cfg.tls.ServerName = host
145+
cfg.TLS.ServerName = host
140146
}
141147
}
142148

@@ -204,6 +210,10 @@ func (cfg *Config) FormatDSN() string {
204210
writeDSNParam(&buf, &hasParam, "allowCleartextPasswords", "true")
205211
}
206212

213+
if cfg.AllowFallbackToNoTLS {
214+
writeDSNParam(&buf, &hasParam, "allowFallbackToNoTLS", "true")
215+
}
216+
207217
if !cfg.AllowNativePasswords {
208218
writeDSNParam(&buf, &hasParam, "allowNativePasswords", "false")
209219
}
@@ -391,6 +401,14 @@ func parseDSNParams(cfg *Config, params string) (err error) {
391401
return errors.New("invalid bool value: " + value)
392402
}
393403

404+
// Allow fallback to unencrypted connection if server does not support TLS
405+
case "allowFallbackToNoTLS":
406+
var isBool bool
407+
cfg.AllowFallbackToNoTLS, isBool = readBool(value)
408+
if !isBool {
409+
return errors.New("invalid bool value: " + value)
410+
}
411+
394412
// Use native password authentication
395413
case "allowNativePasswords":
396414
var isBool bool

dsn_test.go

+23-23
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ var testDSNs = []struct {
4242
"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216&tls=false&allowCleartextPasswords=true&parseTime=true&rejectReadOnly=true",
4343
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, TLSConfig: "false", AllowCleartextPasswords: true, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, CheckConnLiveness: true, ClientFoundRows: true, MaxAllowedPacket: 16777216, ParseTime: true, RejectReadOnly: true},
4444
}, {
45-
"user:password@/dbname?allowNativePasswords=false&checkConnLiveness=false&maxAllowedPacket=0",
46-
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false, CheckConnLiveness: false},
45+
"user:password@/dbname?allowNativePasswords=false&checkConnLiveness=false&maxAllowedPacket=0&allowFallbackToNoTLS=true",
46+
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowFallbackToNoTLS: true, AllowNativePasswords: false, CheckConnLiveness: false},
4747
}, {
4848
"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local",
4949
&Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true},
@@ -82,7 +82,7 @@ func TestDSNParser(t *testing.T) {
8282
}
8383

8484
// pointer not static
85-
cfg.tls = nil
85+
cfg.TLS = nil
8686

8787
if !reflect.DeepEqual(cfg, tst.out) {
8888
t.Errorf("%d. ParseDSN(%q) mismatch:\ngot %+v\nwant %+v", i, tst.in, cfg, tst.out)
@@ -118,7 +118,7 @@ func TestDSNReformat(t *testing.T) {
118118
t.Error(err.Error())
119119
continue
120120
}
121-
cfg1.tls = nil // pointer not static
121+
cfg1.TLS = nil // pointer not static
122122
res1 := fmt.Sprintf("%+v", cfg1)
123123

124124
dsn2 := cfg1.FormatDSN()
@@ -127,7 +127,7 @@ func TestDSNReformat(t *testing.T) {
127127
t.Error(err.Error())
128128
continue
129129
}
130-
cfg2.tls = nil // pointer not static
130+
cfg2.TLS = nil // pointer not static
131131
res2 := fmt.Sprintf("%+v", cfg2)
132132

133133
if res1 != res2 {
@@ -203,7 +203,7 @@ func TestDSNWithCustomTLS(t *testing.T) {
203203

204204
if err != nil {
205205
t.Error(err.Error())
206-
} else if cfg.tls.ServerName != name {
206+
} else if cfg.TLS.ServerName != name {
207207
t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, tst)
208208
}
209209

@@ -214,7 +214,7 @@ func TestDSNWithCustomTLS(t *testing.T) {
214214

215215
if err != nil {
216216
t.Error(err.Error())
217-
} else if cfg.tls.ServerName != name {
217+
} else if cfg.TLS.ServerName != name {
218218
t.Errorf("did not get the correct ServerName (%s) parsing DSN (%s).", name, tst)
219219
} else if tlsCfg.ServerName != "" {
220220
t.Errorf("tlsCfg was mutated ServerName (%s) should be empty parsing DSN (%s).", name, tst)
@@ -229,23 +229,23 @@ func TestDSNTLSConfig(t *testing.T) {
229229
if err != nil {
230230
t.Error(err.Error())
231231
}
232-
if cfg.tls == nil {
232+
if cfg.TLS == nil {
233233
t.Error("cfg.tls should not be nil")
234234
}
235-
if cfg.tls.ServerName != expectedServerName {
236-
t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName)
235+
if cfg.TLS.ServerName != expectedServerName {
236+
t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.TLS.ServerName)
237237
}
238238

239239
dsn = "tcp(example.com)/?tls=true"
240240
cfg, err = ParseDSN(dsn)
241241
if err != nil {
242242
t.Error(err.Error())
243243
}
244-
if cfg.tls == nil {
244+
if cfg.TLS == nil {
245245
t.Error("cfg.tls should not be nil")
246246
}
247-
if cfg.tls.ServerName != expectedServerName {
248-
t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.tls.ServerName)
247+
if cfg.TLS.ServerName != expectedServerName {
248+
t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.TLS.ServerName)
249249
}
250250
}
251251

@@ -262,7 +262,7 @@ func TestDSNWithCustomTLSQueryEscape(t *testing.T) {
262262

263263
if err != nil {
264264
t.Error(err.Error())
265-
} else if cfg.tls.ServerName != name {
265+
} else if cfg.TLS.ServerName != name {
266266
t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, dsn)
267267
}
268268
}
@@ -335,12 +335,12 @@ func TestCloneConfig(t *testing.T) {
335335
t.Errorf("Config.Clone did not create a separate config struct")
336336
}
337337

338-
if cfg2.tls.ServerName != expectedServerName {
339-
t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName)
338+
if cfg2.TLS.ServerName != expectedServerName {
339+
t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.TLS.ServerName)
340340
}
341341

342-
cfg2.tls.ServerName = "example2.com"
343-
if cfg.tls.ServerName == cfg2.tls.ServerName {
342+
cfg2.TLS.ServerName = "example2.com"
343+
if cfg.TLS.ServerName == cfg2.TLS.ServerName {
344344
t.Errorf("changed cfg.tls.Server name should not propagate to original Config")
345345
}
346346

@@ -384,20 +384,20 @@ func TestNormalizeTLSConfig(t *testing.T) {
384384

385385
cfg.normalize()
386386

387-
if cfg.tls == nil {
387+
if cfg.TLS == nil {
388388
if tc.want != nil {
389389
t.Fatal("wanted a tls config but got nil instead")
390390
}
391391
return
392392
}
393393

394-
if cfg.tls.ServerName != tc.want.ServerName {
394+
if cfg.TLS.ServerName != tc.want.ServerName {
395395
t.Errorf("tls.ServerName doesn't match (want: '%s', got: '%s')",
396-
tc.want.ServerName, cfg.tls.ServerName)
396+
tc.want.ServerName, cfg.TLS.ServerName)
397397
}
398-
if cfg.tls.InsecureSkipVerify != tc.want.InsecureSkipVerify {
398+
if cfg.TLS.InsecureSkipVerify != tc.want.InsecureSkipVerify {
399399
t.Errorf("tls.InsecureSkipVerify doesn't match (want: %T, got :%T)",
400-
tc.want.InsecureSkipVerify, cfg.tls.InsecureSkipVerify)
400+
tc.want.InsecureSkipVerify, cfg.TLS.InsecureSkipVerify)
401401
}
402402
})
403403
}

packets.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -222,9 +222,9 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
222222
if mc.flags&clientProtocol41 == 0 {
223223
return nil, "", ErrOldProtocol
224224
}
225-
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
226-
if mc.cfg.TLSConfig == "preferred" {
227-
mc.cfg.tls = nil
225+
if mc.flags&clientSSL == 0 && mc.cfg.TLS != nil {
226+
if mc.cfg.AllowFallbackToNoTLS {
227+
mc.cfg.TLS = nil
228228
} else {
229229
return nil, "", ErrNoTLS
230230
}
@@ -292,7 +292,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
292292
}
293293

294294
// To enable TLS / SSL
295-
if mc.cfg.tls != nil {
295+
if mc.cfg.TLS != nil {
296296
clientFlags |= clientSSL
297297
}
298298

@@ -356,14 +356,14 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
356356

357357
// SSL Connection Request Packet
358358
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
359-
if mc.cfg.tls != nil {
359+
if mc.cfg.TLS != nil {
360360
// Send TLS / SSL request packet
361361
if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
362362
return err
363363
}
364364

365365
// Switch to TLS
366-
tlsConn := tls.Client(mc.netConn, mc.cfg.tls)
366+
tlsConn := tls.Client(mc.netConn, mc.cfg.TLS)
367367
if err := tlsConn.Handshake(); err != nil {
368368
return err
369369
}

0 commit comments

Comments
 (0)