Skip to content

Commit 5d80303

Browse files
committed
By default, if not explicitly set, user must be the current os user (mandatory for some authentication plugin, like GSSAPI for example).
1 parent 18a7ce2 commit 5d80303

File tree

5 files changed

+68
-8
lines changed

5 files changed

+68
-8
lines changed

Diff for: README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ db.SetMaxIdleConns(10)
101101

102102
The Data Source Name has a common format, like e.g. [PEAR DB](http://pear.php.net/manual/en/package.database.db.intro-dsn.php) uses it, but without type-prefix (optional parts marked by squared brackets):
103103
```
104-
[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...&paramN=valueN]
104+
[[username][:password]@][protocol[(address)]]/dbname[?param1=value1&...&paramN=valueN]
105105
```
106106

107107
A DSN in its fullest form:

Diff for: auth_test.go

+43
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ import (
1616
"encoding/pem"
1717
"fmt"
1818
"testing"
19+
20+
osuser "os/user"
1921
)
2022

2123
var testPubKey = []byte("-----BEGIN PUBLIC KEY-----\n" +
@@ -79,6 +81,47 @@ func TestScrambleSHA256Pass(t *testing.T) {
7981
}
8082
}
8183

84+
func TestDefaultUser(t *testing.T) {
85+
conn, mc := newRWMockConn(1)
86+
mc.cfg.User = ""
87+
mc.cfg.Passwd = "secret"
88+
89+
authData := []byte{90, 105, 74, 126, 30, 48, 37, 56, 3, 23, 115, 127, 69,
90+
22, 41, 84, 32, 123, 43, 118}
91+
plugin := "mysql_native_password"
92+
93+
// Send Client Authentication Packet
94+
authPlugin, exists := globalPluginRegistry.GetPlugin(plugin)
95+
if !exists {
96+
t.Fatalf("plugin not registered")
97+
}
98+
var expectedUsername string
99+
currentUser, err := osuser.Current()
100+
if err != nil {
101+
expectedUsername = ""
102+
} else {
103+
expectedUsername = currentUser.Username
104+
}
105+
106+
authResp, err := authPlugin.InitAuth(authData, mc.cfg)
107+
if err != nil {
108+
t.Fatal(err)
109+
}
110+
err = mc.writeHandshakeResponsePacket(authResp, plugin)
111+
if err != nil {
112+
t.Fatal(err)
113+
}
114+
115+
// check written auth response
116+
authRespStart := 4 + 4 + 4 + 1 + 23
117+
authRespEnd := authRespStart + len(expectedUsername)
118+
writtenAuthResp := conn.written[authRespStart:authRespEnd]
119+
expectedAuthResp := []byte(expectedUsername)
120+
if !bytes.Equal(writtenAuthResp, expectedAuthResp) || conn.written[authRespEnd] != 0 {
121+
t.Fatalf("unexpected written auth response: %v", writtenAuthResp)
122+
}
123+
}
124+
82125
func TestAuthFastCachingSHA256PasswordCached(t *testing.T) {
83126
conn, mc := newRWMockConn(1)
84127
mc.cfg.User = "root"

Diff for: dsn.go

+8-4
Original file line numberDiff line numberDiff line change
@@ -256,14 +256,18 @@ func writeDSNParam(buf *bytes.Buffer, hasParam *bool, name, value string) {
256256
func (cfg *Config) FormatDSN() string {
257257
var buf bytes.Buffer
258258

259-
// [username[:password]@]
259+
// [[username][:password]@]
260260
if len(cfg.User) > 0 {
261261
buf.WriteString(cfg.User)
262262
if len(cfg.Passwd) > 0 {
263263
buf.WriteByte(':')
264264
buf.WriteString(cfg.Passwd)
265265
}
266266
buf.WriteByte('@')
267+
} else if len(cfg.Passwd) > 0 {
268+
buf.WriteByte(':')
269+
buf.WriteString(cfg.Passwd)
270+
buf.WriteByte('@')
267271
}
268272

269273
// [protocol[(address)]]
@@ -408,7 +412,7 @@ func ParseDSN(dsn string) (cfg *Config, err error) {
408412
// New config with some default values
409413
cfg = NewConfig()
410414

411-
// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
415+
// [[username][:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
412416
// Find the last '/' (since the password or the net addr might contain a '/')
413417
foundSlash := false
414418
for i := len(dsn) - 1; i >= 0; i-- {
@@ -418,11 +422,11 @@ func ParseDSN(dsn string) (cfg *Config, err error) {
418422

419423
// left part is empty if i <= 0
420424
if i > 0 {
421-
// [username[:password]@][protocol[(address)]]
425+
// [[username][:password]@][protocol[(address)]]
422426
// Find the last '@' in dsn[:i]
423427
for j = i; j >= 0; j-- {
424428
if dsn[j] == '@' {
425-
// username[:password]
429+
// [username][:password]
426430
// Find the first ':' in dsn[:j]
427431
for k = 0; k < j; k++ {
428432
if dsn[k] == ':' {

Diff for: dsn_test.go

+3
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ var testDSNs = []struct {
4747
}, {
4848
"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local",
4949
&Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true},
50+
}, {
51+
":p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local",
52+
&Config{Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true},
5053
}, {
5154
"/dbname",
5255
&Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true},

Diff for: packets.go

+13-3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"fmt"
1818
"io"
1919
"math"
20+
osuser "os/user"
2021
"strconv"
2122
"time"
2223
)
@@ -303,8 +304,17 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
303304
// length encoded integer
304305
clientFlags |= clientPluginAuthLenEncClientData
305306
}
307+
var userName string
308+
if len(mc.cfg.User) > 0 {
309+
userName = mc.cfg.User
310+
} else {
311+
// Get current user if username is empty
312+
if currentUser, err := osuser.Current(); err == nil {
313+
userName = currentUser.Username
314+
}
315+
}
306316

307-
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1
317+
pktLen := 4 + 4 + 1 + 23 + len(userName) + 1 + len(authRespLEI) + len(authResp) + 21 + 1
308318

309319
// To specify a db name
310320
if n := len(mc.cfg.DBName); n > 0 {
@@ -372,8 +382,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
372382
}
373383

374384
// User [null terminated string]
375-
if len(mc.cfg.User) > 0 {
376-
pos += copy(data[pos:], mc.cfg.User)
385+
if len(userName) > 0 {
386+
pos += copy(data[pos:], userName)
377387
}
378388
data[pos] = 0x00
379389
pos++

0 commit comments

Comments
 (0)