Skip to content

Commit 2546947

Browse files
committed
Add connection attributes
1 parent 4591e42 commit 2546947

File tree

6 files changed

+125
-16
lines changed

6 files changed

+125
-16
lines changed

README.md

+9
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,15 @@ Default: 0
393393

394394
I/O write timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*.
395395

396+
##### `connectionAttributes`
397+
398+
```
399+
Type: comma-delimited string of user-defined "key:value" pairs
400+
Valid Values: (<name1>:<value1>,<name2>:<value2>,...)
401+
Default: none
402+
```
403+
404+
[Connection attributes](https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html) are key-value pairs that application programs can pass to the server at connect time.
396405

397406
##### System Variables
398407

const.go

+11
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,23 @@
88

99
package mysql
1010

11+
import "runtime"
12+
1113
const (
1214
defaultAuthPlugin = "mysql_native_password"
1315
defaultMaxAllowedPacket = 4 << 20 // 4 MiB
1416
minProtocolVersion = 10
1517
maxPacketSize = 1<<24 - 1
1618
timeFormat = "2006-01-02 15:04:05.999999"
19+
20+
// Connection attributes
21+
// See https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html#performance-schema-connection-attributes-available
22+
connAttrClientName = "_client_name"
23+
connAttrClientNameValue = "GO-MySQL-Driver"
24+
connAttrOS = "_os"
25+
connAttrOSValue = runtime.GOOS
26+
connAttrPlatform = "_platform"
27+
connAttrPlatformValue = runtime.GOARCH
1728
)
1829

1930
// MySQL constants documentation:

driver_test.go

+47
Original file line numberDiff line numberDiff line change
@@ -3209,3 +3209,50 @@ func TestConnectorTimeoutsWatchCancel(t *testing.T) {
32093209
t.Errorf("connection not closed")
32103210
}
32113211
}
3212+
3213+
func TestConnectionAttributes(t *testing.T) {
3214+
if !available {
3215+
t.Skipf("MySQL server not running on %s", netAddr)
3216+
}
3217+
3218+
attr1 := "attr1"
3219+
value1 := "value1"
3220+
attr2 := "foo"
3221+
value2 := "boo"
3222+
dsn += fmt.Sprintf("&connectionAttributes=%s:%s,%s:%s", attr1, value1, attr2, value2)
3223+
3224+
var db *sql.DB
3225+
if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation {
3226+
db, err = sql.Open("mysql", dsn)
3227+
if err != nil {
3228+
t.Fatalf("error connecting: %s", err.Error())
3229+
}
3230+
defer db.Close()
3231+
}
3232+
3233+
dbt := &DBTest{t, db}
3234+
3235+
var attrValue string
3236+
queryString := "SELECT ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID() and ATTR_NAME = ?"
3237+
rows := dbt.mustQuery(queryString, connAttrClientName)
3238+
if rows.Next() {
3239+
rows.Scan(&attrValue)
3240+
if attrValue != connAttrClientNameValue {
3241+
dbt.Errorf("expected %q, got %q", connAttrClientNameValue, attrValue)
3242+
}
3243+
} else {
3244+
dbt.Errorf("no data")
3245+
}
3246+
rows.Close()
3247+
3248+
rows = dbt.mustQuery(queryString, attr2)
3249+
if rows.Next() {
3250+
rows.Scan(&attrValue)
3251+
if attrValue != value2 {
3252+
dbt.Errorf("expected %q, got %q", value2, attrValue)
3253+
}
3254+
} else {
3255+
dbt.Errorf("no data")
3256+
}
3257+
rows.Close()
3258+
}

dsn.go

+22-16
Original file line numberDiff line numberDiff line change
@@ -34,22 +34,23 @@ var (
3434
// If a new Config is created instead of being parsed from a DSN string,
3535
// the NewConfig function should be used, which sets default values.
3636
type Config struct {
37-
User string // Username
38-
Passwd string // Password (requires User)
39-
Net string // Network type
40-
Addr string // Network address (requires Net)
41-
DBName string // Database name
42-
Params map[string]string // Connection parameters
43-
Collation string // Connection collation
44-
Loc *time.Location // Location for time.Time values
45-
MaxAllowedPacket int // Max packet size allowed
46-
ServerPubKey string // Server public key name
47-
pubKey *rsa.PublicKey // Server public key
48-
TLSConfig string // TLS configuration name
49-
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
50-
Timeout time.Duration // Dial timeout
51-
ReadTimeout time.Duration // I/O read timeout
52-
WriteTimeout time.Duration // I/O write timeout
37+
User string // Username
38+
Passwd string // Password (requires User)
39+
Net string // Network type
40+
Addr string // Network address (requires Net)
41+
DBName string // Database name
42+
Params map[string]string // Connection parameters
43+
Collation string // Connection collation
44+
Loc *time.Location // Location for time.Time values
45+
MaxAllowedPacket int // Max packet size allowed
46+
ServerPubKey string // Server public key name
47+
pubKey *rsa.PublicKey // Server public key
48+
TLSConfig string // TLS configuration name
49+
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
50+
Timeout time.Duration // Dial timeout
51+
ReadTimeout time.Duration // I/O read timeout
52+
WriteTimeout time.Duration // I/O write timeout
53+
ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs
5354

5455
AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
5556
AllowCleartextPasswords bool // Allows the cleartext client side plugin
@@ -554,6 +555,11 @@ func parseDSNParams(cfg *Config, params string) (err error) {
554555
if err != nil {
555556
return
556557
}
558+
559+
// Connection attributes
560+
case "connectionAttributes":
561+
cfg.ConnectionAttributes = value
562+
557563
default:
558564
// lazy init
559565
if cfg.Params == nil {

packets.go

+31
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"fmt"
1919
"io"
2020
"math"
21+
"strings"
2122
"time"
2223
)
2324

@@ -285,6 +286,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
285286
clientLocalFiles |
286287
clientPluginAuth |
287288
clientMultiResults |
289+
clientConnectAttrs |
288290
mc.flags&clientLongFlag
289291

290292
if mc.cfg.ClientFoundRows {
@@ -318,6 +320,30 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
318320
pktLen += n + 1
319321
}
320322

323+
connAttrsBuf := make([]byte, 0, 100)
324+
325+
// default connection attributes
326+
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientName)
327+
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientNameValue)
328+
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOS)
329+
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOSValue)
330+
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatform)
331+
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatformValue)
332+
333+
// user-defined connection attributes
334+
for _, connAttr := range strings.Split(mc.cfg.ConnectionAttributes, ",") {
335+
attr := strings.Split(connAttr, ":")
336+
if len(attr) != 2 {
337+
continue
338+
}
339+
for _, v := range attr {
340+
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, v)
341+
}
342+
}
343+
344+
// 1 byte to store length of all key-values
345+
pktLen += len(connAttrsBuf) + 1
346+
321347
// Calculate packet length and get buffer with that size
322348
data, err := mc.buf.takeSmallBuffer(pktLen + 4)
323349
if err != nil {
@@ -394,6 +420,11 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
394420
data[pos] = 0x00
395421
pos++
396422

423+
// Connection Attributes
424+
data[pos] = byte(len(connAttrsBuf))
425+
pos++
426+
pos += copy(data[pos:], connAttrsBuf)
427+
397428
// Send Auth packet
398429
return mc.writePacket(data[:pos])
399430
}

utils.go

+5
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,11 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte {
616616
byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
617617
}
618618

619+
func appendLengthEncodedString(b []byte, s string) []byte {
620+
b = appendLengthEncodedInteger(b, uint64(len(s)))
621+
return append(b, s...)
622+
}
623+
619624
// reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize.
620625
// If cap(buf) is not enough, reallocate new buffer.
621626
func reserveBuffer(buf []byte, appendSize int) []byte {

0 commit comments

Comments
 (0)