Skip to content

Commit 6611466

Browse files
Daemonxiaomethane
authored and
Pavlo
committed
Send connection attributes (go-sql-driver#1389)
Co-authored-by: Inada Naoki <[email protected]>
1 parent f5e5b2a commit 6611466

12 files changed

+186
-39
lines changed

Diff for: .github/workflows/test.yml

+1
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ jobs:
7878
; TestConcurrent fails if max_connections is too large
7979
max_connections=50
8080
local_infile=1
81+
performance_schema=on
8182
- name: setup database
8283
run: |
8384
mysql --user 'root' --host '127.0.0.1' -e 'create database gotest;'

Diff for: README.md

+9
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,15 @@ Default: 0
382382

383383
I/O write timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*.
384384

385+
##### `connectionAttributes`
386+
387+
```
388+
Type: comma-delimited string of user-defined "key:value" pairs
389+
Valid Values: (<name1>:<value1>,<name2>:<value2>,...)
390+
Default: none
391+
```
392+
393+
[Connection attributes](https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html) are key-value pairs that application programs can pass to the server at connect time.
385394

386395
##### System Variables
387396

Diff for: connection.go

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ type mysqlConn struct {
2727
affectedRows uint64
2828
insertId uint64
2929
cfg *Config
30+
connector *connector
3031
maxAllowedPacket int
3132
maxWriteSize int
3233
writeTimeout time.Duration

Diff for: connector.go

+45-1
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,54 @@ package mysql
1111
import (
1212
"context"
1313
"database/sql/driver"
14+
"fmt"
1415
"net"
16+
"os"
17+
"strconv"
18+
"strings"
1519
)
1620

1721
type connector struct {
18-
cfg *Config // immutable private copy.
22+
cfg *Config // immutable private copy.
23+
encodedAttributes string // Encoded connection attributes.
24+
}
25+
26+
func encodeConnectionAttributes(textAttributes string) string {
27+
connAttrsBuf := make([]byte, 0, 251)
28+
29+
// default connection attributes
30+
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientName)
31+
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientNameValue)
32+
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOS)
33+
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOSValue)
34+
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatform)
35+
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatformValue)
36+
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPid)
37+
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, strconv.Itoa(os.Getpid()))
38+
39+
// user-defined connection attributes
40+
for _, connAttr := range strings.Split(textAttributes, ",") {
41+
attr := strings.SplitN(connAttr, ":", 2)
42+
if len(attr) != 2 {
43+
continue
44+
}
45+
for _, v := range attr {
46+
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, v)
47+
}
48+
}
49+
50+
return string(connAttrsBuf)
51+
}
52+
53+
func newConnector(cfg *Config) (*connector, error) {
54+
encodedAttributes := encodeConnectionAttributes(cfg.ConnectionAttributes)
55+
if len(encodedAttributes) > 250 {
56+
return nil, fmt.Errorf("connection attributes are longer than 250 bytes: %dbytes (%q)", len(encodedAttributes), cfg.ConnectionAttributes)
57+
}
58+
return &connector{
59+
cfg: cfg,
60+
encodedAttributes: encodedAttributes,
61+
}, nil
1962
}
2063

2164
// Connect implements driver.Connector interface.
@@ -29,6 +72,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
2972
maxWriteSize: maxPacketSize - 1,
3073
closech: make(chan struct{}),
3174
cfg: c.cfg,
75+
connector: c,
3276
}
3377
mc.parseTime = mc.cfg.ParseTime
3478

Diff for: connector_test.go

+6-3
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,16 @@ import (
88
)
99

1010
func TestConnectorReturnsTimeout(t *testing.T) {
11-
connector := &connector{&Config{
11+
connector, err := newConnector(&Config{
1212
Net: "tcp",
1313
Addr: "1.1.1.1:1234",
1414
Timeout: 10 * time.Millisecond,
15-
}}
15+
})
16+
if err != nil {
17+
t.Fatal(err)
18+
}
1619

17-
_, err := connector.Connect(context.Background())
20+
_, err = connector.Connect(context.Background())
1821
if err == nil {
1922
t.Fatal("error expected")
2023
}

Diff for: const.go

+12
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,24 @@
88

99
package mysql
1010

11+
import "runtime"
12+
1113
const (
1214
defaultAuthPlugin = "mysql_native_password"
1315
defaultMaxAllowedPacket = 4 << 20 // 4 MiB
1416
minProtocolVersion = 10
1517
maxPacketSize = 1<<24 - 1
1618
timeFormat = "2006-01-02 15:04:05.999999"
19+
20+
// Connection attributes
21+
// See https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html#performance-schema-connection-attributes-available
22+
connAttrClientName = "_client_name"
23+
connAttrClientNameValue = "Go-MySQL-Driver"
24+
connAttrOS = "_os"
25+
connAttrOSValue = runtime.GOOS
26+
connAttrPlatform = "_platform"
27+
connAttrPlatformValue = runtime.GOARCH
28+
connAttrPid = "_pid"
1729
)
1830

1931
// MySQL constants documentation:

Diff for: driver.go

+5-6
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,9 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
7474
if err != nil {
7575
return nil, err
7676
}
77-
c := &connector{
78-
cfg: cfg,
77+
c, err := newConnector(cfg)
78+
if err != nil {
79+
return nil, err
7980
}
8081
return c.Connect(context.Background())
8182
}
@@ -92,7 +93,7 @@ func NewConnector(cfg *Config) (driver.Connector, error) {
9293
if err := cfg.normalize(); err != nil {
9394
return nil, err
9495
}
95-
return &connector{cfg: cfg}, nil
96+
return newConnector(cfg)
9697
}
9798

9899
// OpenConnector implements driver.DriverContext.
@@ -101,7 +102,5 @@ func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) {
101102
if err != nil {
102103
return nil, err
103104
}
104-
return &connector{
105-
cfg: cfg,
106-
}, nil
105+
return newConnector(cfg)
107106
}

Diff for: driver_test.go

+47
Original file line numberDiff line numberDiff line change
@@ -3209,3 +3209,50 @@ func TestConnectorTimeoutsWatchCancel(t *testing.T) {
32093209
t.Errorf("connection not closed")
32103210
}
32113211
}
3212+
3213+
func TestConnectionAttributes(t *testing.T) {
3214+
if !available {
3215+
t.Skipf("MySQL server not running on %s", netAddr)
3216+
}
3217+
3218+
attr1 := "attr1"
3219+
value1 := "value1"
3220+
attr2 := "foo"
3221+
value2 := "boo"
3222+
dsn += fmt.Sprintf("&connectionAttributes=%s:%s,%s:%s", attr1, value1, attr2, value2)
3223+
3224+
var db *sql.DB
3225+
if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation {
3226+
db, err = sql.Open("mysql", dsn)
3227+
if err != nil {
3228+
t.Fatalf("error connecting: %s", err.Error())
3229+
}
3230+
defer db.Close()
3231+
}
3232+
3233+
dbt := &DBTest{t, db}
3234+
3235+
var attrValue string
3236+
queryString := "SELECT ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID() and ATTR_NAME = ?"
3237+
rows := dbt.mustQuery(queryString, connAttrClientName)
3238+
if rows.Next() {
3239+
rows.Scan(&attrValue)
3240+
if attrValue != connAttrClientNameValue {
3241+
dbt.Errorf("expected %q, got %q", connAttrClientNameValue, attrValue)
3242+
}
3243+
} else {
3244+
dbt.Errorf("no data")
3245+
}
3246+
rows.Close()
3247+
3248+
rows = dbt.mustQuery(queryString, attr2)
3249+
if rows.Next() {
3250+
rows.Scan(&attrValue)
3251+
if attrValue != value2 {
3252+
dbt.Errorf("expected %q, got %q", value2, attrValue)
3253+
}
3254+
} else {
3255+
dbt.Errorf("no data")
3256+
}
3257+
rows.Close()
3258+
}

Diff for: dsn.go

+36-28
Original file line numberDiff line numberDiff line change
@@ -34,34 +34,37 @@ var (
3434
// If a new Config is created instead of being parsed from a DSN string,
3535
// the NewConfig function should be used, which sets default values.
3636
type Config struct {
37-
User string // Username
38-
Passwd string // Password (requires User)
39-
Net string // Network type
40-
Addr string // Network address (requires Net)
41-
DBName string // Database name
42-
Params map[string]string // Connection parameters
43-
Collation string // Connection collation
44-
Loc *time.Location // Location for time.Time values
45-
MaxAllowedPacket int // Max packet size allowed
46-
ServerPubKey string // Server public key name
47-
pubKey *rsa.PublicKey // Server public key
48-
TLSConfig string // TLS configuration name
49-
tls *tls.Config // TLS configuration
50-
Timeout time.Duration // Dial timeout
51-
ReadTimeout time.Duration // I/O read timeout
52-
WriteTimeout time.Duration // I/O write timeout
53-
54-
AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
55-
AllowCleartextPasswords bool // Allows the cleartext client side plugin
56-
AllowNativePasswords bool // Allows the native password authentication method
57-
AllowOldPasswords bool // Allows the old insecure password method
58-
CheckConnLiveness bool // Check connections for liveness before using them
59-
ClientFoundRows bool // Return number of matching rows instead of rows changed
60-
ColumnsWithAlias bool // Prepend table alias to column names
61-
InterpolateParams bool // Interpolate placeholders into query string
62-
MultiStatements bool // Allow multiple statements in one query
63-
ParseTime bool // Parse time values to time.Time
64-
RejectReadOnly bool // Reject read-only connections
37+
User string // Username
38+
Passwd string // Password (requires User)
39+
Net string // Network type
40+
Addr string // Network address (requires Net)
41+
DBName string // Database name
42+
Params map[string]string // Connection parameters
43+
ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs
44+
Collation string // Connection collation
45+
Loc *time.Location // Location for time.Time values
46+
MaxAllowedPacket int // Max packet size allowed
47+
ServerPubKey string // Server public key name
48+
pubKey *rsa.PublicKey // Server public key
49+
TLSConfig string // TLS configuration name
50+
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
51+
Timeout time.Duration // Dial timeout
52+
ReadTimeout time.Duration // I/O read timeout
53+
WriteTimeout time.Duration // I/O write timeout
54+
Logger Logger // Logger
55+
56+
AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
57+
AllowCleartextPasswords bool // Allows the cleartext client side plugin
58+
AllowFallbackToPlaintext bool // Allows fallback to unencrypted connection if server does not support TLS
59+
AllowNativePasswords bool // Allows the native password authentication method
60+
AllowOldPasswords bool // Allows the old insecure password method
61+
CheckConnLiveness bool // Check connections for liveness before using them
62+
ClientFoundRows bool // Return number of matching rows instead of rows changed
63+
ColumnsWithAlias bool // Prepend table alias to column names
64+
InterpolateParams bool // Interpolate placeholders into query string
65+
MultiStatements bool // Allow multiple statements in one query
66+
ParseTime bool // Parse time values to time.Time
67+
RejectReadOnly bool // Reject read-only connections
6568
}
6669

6770
// NewConfig creates a new Config and sets default values.
@@ -537,6 +540,11 @@ func parseDSNParams(cfg *Config, params string) (err error) {
537540
if err != nil {
538541
return
539542
}
543+
544+
// Connection attributes
545+
case "connectionAttributes":
546+
cfg.ConnectionAttributes = value
547+
540548
default:
541549
// lazy init
542550
if cfg.Params == nil {

Diff for: packets.go

+13
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
286286
clientLocalFiles |
287287
clientPluginAuth |
288288
clientMultiResults |
289+
clientConnectAttrs |
289290
mc.flags&clientLongFlag
290291

291292
if mc.cfg.ClientFoundRows {
@@ -319,6 +320,13 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
319320
pktLen += n + 1
320321
}
321322

323+
// 1 byte to store length of all key-values
324+
// NOTE: Actually, this is length encoded integer.
325+
// But we support only len(connAttrBuf) < 251 for now because takeSmallBuffer
326+
// doesn't support buffer size more than 4096 bytes.
327+
// TODO(methane): Rewrite buffer management.
328+
pktLen += 1 + len(mc.connector.encodedAttributes)
329+
322330
// Calculate packet length and get buffer with that size
323331
data, err := mc.buf.takeSmallBuffer(pktLen + 4)
324332
if err != nil {
@@ -395,6 +403,11 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
395403
data[pos] = 0x00
396404
pos++
397405

406+
// Connection Attributes
407+
data[pos] = byte(len(mc.connector.encodedAttributes))
408+
pos++
409+
pos += copy(data[pos:], []byte(mc.connector.encodedAttributes))
410+
398411
// Send Auth packet
399412
return mc.writePacket(data[:pos])
400413
}

Diff for: packets_test.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,14 @@ var _ net.Conn = new(mockConn)
9696

9797
func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) {
9898
conn := new(mockConn)
99+
connector, err := newConnector(NewConfig())
100+
if err != nil {
101+
panic(err)
102+
}
99103
mc := &mysqlConn{
100104
buf: newBuffer(conn),
101-
cfg: NewConfig(),
105+
cfg: connector.cfg,
106+
connector: connector,
102107
netConn: conn,
103108
closech: make(chan struct{}),
104109
maxAllowedPacket: defaultMaxAllowedPacket,

Diff for: utils.go

+5
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,11 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte {
626626
byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
627627
}
628628

629+
func appendLengthEncodedString(b []byte, s string) []byte {
630+
b = appendLengthEncodedInteger(b, uint64(len(s)))
631+
return append(b, s...)
632+
}
633+
629634
// reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize.
630635
// If cap(buf) is not enough, reallocate new buffer.
631636
func reserveBuffer(buf []byte, appendSize int) []byte {

0 commit comments

Comments
 (0)