Skip to content

Commit 4b6c230

Browse files
committed
Encode connection attribute only once.
1 parent 4e1c200 commit 4b6c230

File tree

5 files changed

+63
-38
lines changed

5 files changed

+63
-38
lines changed

Diff for: connection.go

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ type mysqlConn struct {
2727
affectedRows uint64
2828
insertId uint64
2929
cfg *Config
30+
connector *connector
3031
maxAllowedPacket int
3132
maxWriteSize int
3233
writeTimeout time.Duration

Diff for: connector.go

+45-1
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,54 @@ package mysql
1111
import (
1212
"context"
1313
"database/sql/driver"
14+
"fmt"
1415
"net"
16+
"os"
17+
"strconv"
18+
"strings"
1519
)
1620

1721
type connector struct {
18-
cfg *Config // immutable private copy.
22+
cfg *Config // immutable private copy.
23+
encodedAttributes string // Encoded connection attributes.
24+
}
25+
26+
func encodeConnectionAttributes(textAttributes string) string {
27+
connAttrsBuf := make([]byte, 0, 251)
28+
29+
// default connection attributes
30+
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientName)
31+
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientNameValue)
32+
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOS)
33+
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOSValue)
34+
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatform)
35+
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatformValue)
36+
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPid)
37+
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, strconv.Itoa(os.Getpid()))
38+
39+
// user-defined connection attributes
40+
for _, connAttr := range strings.Split(textAttributes, ",") {
41+
attr := strings.SplitN(connAttr, ":", 2)
42+
if len(attr) != 2 {
43+
continue
44+
}
45+
for _, v := range attr {
46+
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, v)
47+
}
48+
}
49+
50+
return string(connAttrsBuf)
51+
}
52+
53+
func newConnector(cfg *Config) (*connector, error) {
54+
encodedAttributes := encodeConnectionAttributes(cfg.ConnectionAttributes)
55+
if len(encodedAttributes) > 250 {
56+
return nil, fmt.Errorf("connection attributes are longer than 250 bytes: %dbytes (%q)", len(encodedAttributes), cfg.ConnectionAttributes)
57+
}
58+
return &connector{
59+
cfg: cfg,
60+
encodedAttributes: encodedAttributes,
61+
}, nil
1962
}
2063

2164
// Connect implements driver.Connector interface.
@@ -29,6 +72,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
2972
maxWriteSize: maxPacketSize - 1,
3073
closech: make(chan struct{}),
3174
cfg: c.cfg,
75+
connector: c,
3276
}
3377
mc.parseTime = mc.cfg.ParseTime
3478

Diff for: connector_test.go

+5-2
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,14 @@ import (
88
)
99

1010
func TestConnectorReturnsTimeout(t *testing.T) {
11-
connector := &connector{&Config{
11+
connector, err := newConnector(&Config{
1212
Net: "tcp",
1313
Addr: "1.1.1.1:1234",
1414
Timeout: 10 * time.Millisecond,
15-
}}
15+
})
16+
if err != nil {
17+
t.Fatal(err)
18+
}
1619

1720
_, err := connector.Connect(context.Background())
1821
if err == nil {

Diff for: driver.go

+5-6
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,9 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
8585
if err != nil {
8686
return nil, err
8787
}
88-
c := &connector{
89-
cfg: cfg,
88+
c, err := newConnector(cfg)
89+
if err != nil {
90+
return nil, err
9091
}
9192
return c.Connect(context.Background())
9293
}
@@ -103,7 +104,7 @@ func NewConnector(cfg *Config) (driver.Connector, error) {
103104
if err := cfg.normalize(); err != nil {
104105
return nil, err
105106
}
106-
return &connector{cfg: cfg}, nil
107+
return newConnector(cfg)
107108
}
108109

109110
// OpenConnector implements driver.DriverContext.
@@ -112,7 +113,5 @@ func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) {
112113
if err != nil {
113114
return nil, err
114115
}
115-
return &connector{
116-
cfg: cfg,
117-
}, nil
116+
return newConnector(cfg)
118117
}

Diff for: packets.go

+7-29
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@ import (
1818
"fmt"
1919
"io"
2020
"math"
21-
"os"
22-
"strconv"
23-
"strings"
2421
"time"
2522
)
2623

@@ -322,31 +319,12 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
322319
pktLen += n + 1
323320
}
324321

325-
connAttrsBuf := make([]byte, 0, 100)
326-
327-
// default connection attributes
328-
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientName)
329-
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientNameValue)
330-
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOS)
331-
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOSValue)
332-
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatform)
333-
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatformValue)
334-
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPid)
335-
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, strconv.Itoa(os.Getpid()))
336-
337-
// user-defined connection attributes
338-
for _, connAttr := range strings.Split(mc.cfg.ConnectionAttributes, ",") {
339-
attr := strings.Split(connAttr, ":")
340-
if len(attr) != 2 {
341-
continue
342-
}
343-
for _, v := range attr {
344-
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, v)
345-
}
346-
}
347-
348322
// 1 byte to store length of all key-values
349-
pktLen += len(connAttrsBuf) + 1
323+
// NOTE: Actually, this is length encoded integer.
324+
// But we support only len(connAttrBuf) < 251 for now because takeSmallBuffer
325+
// doesn't support buffer size more than 4096 bytes.
326+
// TODO(methane): Rewrite buffer management.
327+
pktLen += 1 + len(mc.connector.encodedAttributes)
350328

351329
// Calculate packet length and get buffer with that size
352330
data, err := mc.buf.takeSmallBuffer(pktLen + 4)
@@ -425,9 +403,9 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
425403
pos++
426404

427405
// Connection Attributes
428-
data[pos] = byte(len(connAttrsBuf))
406+
data[pos] = byte(len(mc.connector.encodedAttributes))
429407
pos++
430-
pos += copy(data[pos:], connAttrsBuf)
408+
pos += copy(data[pos:], []byte(mc.connector.encodedAttributes))
431409

432410
// Send Auth packet
433411
return mc.writePacket(data[:pos])

0 commit comments

Comments
 (0)