Skip to content

Commit f0eee81

Browse files
committed
readPacket optimization
Since configuration options doesn't change at runtime, after connection is established, use dedicated function, in order to avoid multiple test test compress, checking ReadTimeout configuration option
1 parent 98f445c commit f0eee81

8 files changed

+105
-35
lines changed

Diff for: AUTHORS

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ Daniel Montoya <dsmontoyam at gmail.com>
3737
Daniel Nichter <nil at codenode.com>
3838
Daniël van Eeden <git at myname.nl>
3939
Dave Protasowski <dprotaso at gmail.com>
40+
Diego Dupin <diego.dupin at gmail.com>
4041
Dirkjan Bussink <d.bussink at gmail.com>
4142
DisposaBoy <disposaboy at dby.me>
4243
Egor Smolyakov <egorsmkv at gmail.com>

Diff for: benchmark_test.go

+41
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,47 @@ func benchmarkQueryHelper(b *testing.B, compr bool) {
113113
}
114114
}
115115

116+
func BenchmarkSelect10000rows(b *testing.B) {
117+
db := initDB(b, false)
118+
defer db.Close()
119+
120+
// Check if we're using MariaDB
121+
var version string
122+
err := db.QueryRow("SELECT @@version").Scan(&version)
123+
if err != nil {
124+
b.Fatalf("Failed to get server version: %v", err)
125+
}
126+
127+
if !strings.Contains(strings.ToLower(version), "mariadb") {
128+
b.Skip("Skipping benchmark as it requires MariaDB sequence table")
129+
return
130+
}
131+
132+
b.StartTimer()
133+
stmt, err := db.Prepare("SELECT * FROM seq_1_to_10000")
134+
if err != nil {
135+
b.Fatalf("Failed to prepare statement: %v", err)
136+
}
137+
defer stmt.Close()
138+
for n := 0; n < b.N; n++ {
139+
rows, err := stmt.Query()
140+
if err != nil {
141+
b.Fatalf("Failed to query 10000rows: %v", err)
142+
}
143+
144+
var id int64
145+
for rows.Next() {
146+
err = rows.Scan(&id)
147+
if err != nil {
148+
rows.Close()
149+
b.Fatalf("Failed to scan row: %v", err)
150+
}
151+
}
152+
rows.Close()
153+
}
154+
b.StopTimer()
155+
}
156+
116157
func BenchmarkExec(b *testing.B) {
117158
tb := (*TB)(b)
118159
b.StopTimer()

Diff for: compress_test.go

+2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte) []by
4040
conn := new(mockConn)
4141
conn.data = compressedPacket
4242
mc.netConn = conn
43+
mc.readNextFunc = mc.compIO.readNext
44+
mc.readFunc = conn.Read
4345

4446
uncompressedPacket, err := mc.readPacket()
4547
if err != nil {

Diff for: connection.go

+3-11
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ type mysqlConn struct {
3939
compressSequence uint8
4040
parseTime bool
4141
compress bool
42+
readFunc func([]byte) (int, error)
43+
readNextFunc func(int, readerFunc) ([]byte, error)
4244

4345
// for context support (Go 1.8+)
4446
watching bool
@@ -64,16 +66,6 @@ func (mc *mysqlConn) log(v ...any) {
6466
mc.cfg.Logger.Print(v...)
6567
}
6668

67-
func (mc *mysqlConn) readWithTimeout(b []byte) (int, error) {
68-
to := mc.cfg.ReadTimeout
69-
if to > 0 {
70-
if err := mc.netConn.SetReadDeadline(time.Now().Add(to)); err != nil {
71-
return 0, err
72-
}
73-
}
74-
return mc.netConn.Read(b)
75-
}
76-
7769
func (mc *mysqlConn) writeWithTimeout(b []byte) (int, error) {
7870
to := mc.cfg.WriteTimeout
7971
if to > 0 {
@@ -247,7 +239,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
247239
// can not take the buffer. Something must be wrong with the connection
248240
mc.cleanup()
249241
// interpolateParams would be called before sending any query.
250-
// So its safe to retry.
242+
// So it's safe to retry.
251243
return "", driver.ErrBadConn
252244
}
253245
buf = buf[:0]

Diff for: connection_test.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,17 @@ import (
1818
)
1919

2020
func TestInterpolateParams(t *testing.T) {
21+
buf := newBuffer()
22+
nc := &net.TCPConn{}
2123
mc := &mysqlConn{
22-
buf: newBuffer(),
24+
buf: buf,
25+
netConn: nc,
2326
maxAllowedPacket: maxPacketSize,
2427
cfg: &Config{
2528
InterpolateParams: true,
2629
},
30+
readNextFunc: buf.readNext,
31+
readFunc: nc.Read,
2732
}
2833

2934
q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"})

Diff for: connector.go

+18
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"os"
1717
"strconv"
1818
"strings"
19+
"time"
1920
)
2021

2122
type connector struct {
@@ -130,6 +131,22 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
130131

131132
mc.buf = newBuffer()
132133

134+
// setting readNext/read functions
135+
mc.readNextFunc = mc.buf.readNext
136+
137+
// Initialize read function based on configuration
138+
if mc.cfg.ReadTimeout > 0 {
139+
mc.readFunc = func(b []byte) (int, error) {
140+
deadline := time.Now().Add(mc.cfg.ReadTimeout)
141+
if err := mc.netConn.SetReadDeadline(deadline); err != nil {
142+
return 0, err
143+
}
144+
return mc.netConn.Read(b)
145+
}
146+
} else {
147+
mc.readFunc = mc.netConn.Read
148+
}
149+
133150
// Reading Handshake Initialization Packet
134151
authData, plugin, err := mc.readHandshakePacket()
135152
if err != nil {
@@ -170,6 +187,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
170187
if mc.cfg.compress && mc.flags&clientCompress == clientCompress {
171188
mc.compress = true
172189
mc.compIO = newCompIO(mc)
190+
mc.readNextFunc = mc.compIO.readNext
173191
}
174192
if mc.cfg.MaxAllowedPacket > 0 {
175193
mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket

Diff for: packets.go

+3-7
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,9 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
3030
var prevData []byte
3131
invalidSequence := false
3232

33-
readNext := mc.buf.readNext
34-
if mc.compress {
35-
readNext = mc.compIO.readNext
36-
}
37-
3833
for {
3934
// read packet header
40-
data, err := readNext(4, mc.readWithTimeout)
35+
data, err := mc.readNextFunc(4, mc.readFunc)
4136
if err != nil {
4237
mc.close()
4338
if cerr := mc.canceled.Value(); cerr != nil {
@@ -85,7 +80,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
8580
}
8681

8782
// read packet body [pktLen bytes]
88-
data, err = readNext(pktLen, mc.readWithTimeout)
83+
data, err = mc.readNextFunc(pktLen, mc.readFunc)
8984
if err != nil {
9085
mc.close()
9186
if cerr := mc.canceled.Value(); cerr != nil {
@@ -369,6 +364,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
369364
return err
370365
}
371366
mc.netConn = tlsConn
367+
mc.readFunc = mc.netConn.Read
372368
}
373369

374370
// User [null terminated string]

Diff for: packets_test.go

+31-16
Original file line numberDiff line numberDiff line change
@@ -97,24 +97,30 @@ var _ net.Conn = new(mockConn)
9797
func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) {
9898
conn := new(mockConn)
9999
connector := newConnector(NewConfig())
100+
buf := newBuffer()
100101
mc := &mysqlConn{
101-
buf: newBuffer(),
102+
buf: buf,
102103
cfg: connector.cfg,
103104
connector: connector,
104105
netConn: conn,
105106
closech: make(chan struct{}),
106107
maxAllowedPacket: defaultMaxAllowedPacket,
107108
sequence: sequence,
109+
readNextFunc: buf.readNext,
110+
readFunc: conn.Read,
108111
}
109112
return conn, mc
110113
}
111114

112115
func TestReadPacketSingleByte(t *testing.T) {
113116
conn := new(mockConn)
117+
buf := newBuffer()
114118
mc := &mysqlConn{
115-
netConn: conn,
116-
buf: newBuffer(),
117-
cfg: NewConfig(),
119+
netConn: conn,
120+
buf: buf,
121+
cfg: NewConfig(),
122+
readNextFunc: buf.readNext,
123+
readFunc: conn.Read,
118124
}
119125

120126
conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
@@ -165,10 +171,13 @@ func TestReadPacketWrongSequenceID(t *testing.T) {
165171

166172
func TestReadPacketSplit(t *testing.T) {
167173
conn := new(mockConn)
174+
buf := newBuffer()
168175
mc := &mysqlConn{
169-
netConn: conn,
170-
buf: newBuffer(),
171-
cfg: NewConfig(),
176+
netConn: conn,
177+
buf: buf,
178+
cfg: NewConfig(),
179+
readNextFunc: buf.readNext,
180+
readFunc: conn.Read,
172181
}
173182

174183
data := make([]byte, maxPacketSize*2+4*3)
@@ -272,11 +281,14 @@ func TestReadPacketSplit(t *testing.T) {
272281

273282
func TestReadPacketFail(t *testing.T) {
274283
conn := new(mockConn)
284+
buf := newBuffer()
275285
mc := &mysqlConn{
276-
netConn: conn,
277-
buf: newBuffer(),
278-
closech: make(chan struct{}),
279-
cfg: NewConfig(),
286+
netConn: conn,
287+
buf: buf,
288+
closech: make(chan struct{}),
289+
cfg: NewConfig(),
290+
readNextFunc: buf.readNext,
291+
readFunc: conn.Read,
280292
}
281293

282294
// illegal empty (stand-alone) packet
@@ -317,12 +329,15 @@ func TestReadPacketFail(t *testing.T) {
317329
// not-NUL terminated plugin_name in init packet
318330
func TestRegression801(t *testing.T) {
319331
conn := new(mockConn)
332+
buf := newBuffer()
320333
mc := &mysqlConn{
321-
netConn: conn,
322-
buf: newBuffer(),
323-
cfg: new(Config),
324-
sequence: 42,
325-
closech: make(chan struct{}),
334+
netConn: conn,
335+
buf: buf,
336+
cfg: new(Config),
337+
sequence: 42,
338+
closech: make(chan struct{}),
339+
readNextFunc: buf.readNext,
340+
readFunc: conn.Read,
326341
}
327342

328343
conn.data = []byte{72, 0, 0, 42, 10, 53, 46, 53, 46, 56, 0, 165, 0, 0, 0,

0 commit comments

Comments
 (0)