Skip to content

Commit 700db22

Browse files
committed
compress: better buffer reuse
1 parent 3062a2f commit 700db22

9 files changed

+51
-60
lines changed

benchmark_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ func BenchmarkInterpolation(b *testing.B) {
236236
maxWriteSize: maxPacketSize - 1,
237237
buf: newBuffer(nil),
238238
}
239-
mc.packetReader = &mc.buf
239+
mc.packetRW = &mc.buf
240240

241241
args := []driver.Value{
242242
int64(42424242),

buffer.go

+6
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,9 @@ func (b *buffer) store(buf []byte) {
155155
b.cachedBuf = buf[:cap(buf)]
156156
}
157157
}
158+
159+
// writePackets is a proxy function to nc.Write.
160+
// This is used to make the buffer type compatible with compressed I/O.
161+
func (b *buffer) writePackets(packets []byte) (int, error) {
162+
return b.nc.Write(packets)
163+
}

compress.go

+23-32
Original file line numberDiff line numberDiff line change
@@ -85,34 +85,28 @@ func zCompress(src []byte, dst io.Writer) error {
8585
return nil
8686
}
8787

88-
type decompressor struct {
89-
mc *mysqlConn
90-
// read buffer (FIFO).
91-
// We can not reuse already-read buffer until dropping Go 1.20 support.
92-
// It is because of database/mysql's weired behavior.
93-
// See https://github.com/go-sql-driver/mysql/issues/1435
94-
bytesBuf []byte
88+
type compIO struct {
89+
mc *mysqlConn
90+
buff bytes.Buffer
9591
}
9692

97-
func newDecompressor(mc *mysqlConn) *decompressor {
98-
return &decompressor{
93+
func newCompIO(mc *mysqlConn) *compIO {
94+
return &compIO{
9995
mc: mc,
10096
}
10197
}
10298

103-
func (c *decompressor) readNext(need int) ([]byte, error) {
104-
for len(c.bytesBuf) < need {
99+
func (c *compIO) readNext(need int) ([]byte, error) {
100+
for c.buff.Len() < need {
105101
if err := c.uncompressPacket(); err != nil {
106102
return nil, err
107103
}
108104
}
109-
110-
data := c.bytesBuf[:need:need] // prevent caller writes into r.bytesBuf
111-
c.bytesBuf = c.bytesBuf[need:]
112-
return data, nil
105+
data := c.buff.Next(need)
106+
return data[:need:need], nil // prevent caller writes into c.buff
113107
}
114108

115-
func (c *decompressor) uncompressPacket() error {
109+
func (c *compIO) uncompressPacket() error {
116110
header, err := c.mc.buf.readNext(7) // size of compressed header
117111
if err != nil {
118112
return err
@@ -147,41 +141,37 @@ func (c *decompressor) uncompressPacket() error {
147141
// if payload is uncompressed, its length will be specified as zero, and its
148142
// true length is contained in comprLength
149143
if uncompressedLength == 0 {
150-
c.bytesBuf = append(c.bytesBuf, comprData...)
144+
c.buff.Write(comprData)
151145
return nil
152146
}
153147

154148
// use existing capacity in bytesBuf if possible
155-
offset := len(c.bytesBuf)
156-
if cap(c.bytesBuf)-offset < uncompressedLength {
157-
old := c.bytesBuf
158-
c.bytesBuf = make([]byte, offset, offset+uncompressedLength)
159-
copy(c.bytesBuf, old)
160-
}
161-
162-
lenRead, err := zDecompress(comprData, c.bytesBuf[offset:offset+uncompressedLength])
149+
c.buff.Grow(uncompressedLength)
150+
dec := c.buff.AvailableBuffer()[:uncompressedLength]
151+
lenRead, err := zDecompress(comprData, dec)
163152
if err != nil {
164153
return err
165154
}
166155
if lenRead != uncompressedLength {
167156
return fmt.Errorf("invalid compressed packet: uncompressed length in header is %d, actual %d",
168157
uncompressedLength, lenRead)
169158
}
170-
c.bytesBuf = c.bytesBuf[:offset+uncompressedLength]
159+
c.buff.Write(dec) // fast copy. See bytes.Buffer.AvailableBuffer() doc.
171160
return nil
172161
}
173162

174163
const maxPayloadLen = maxPacketSize - 4
175164

176-
// writeCompressed sends one or some packets with compression.
165+
// writePackets sends one or some packets with compression.
177166
// Use this instead of mc.netConn.Write() when mc.compress is true.
178-
func (mc *mysqlConn) writeCompressed(packets []byte) (int, error) {
167+
func (c *compIO) writePackets(packets []byte) (int, error) {
179168
totalBytes := len(packets)
180169
dataLen := len(packets)
181170
blankHeader := make([]byte, 7)
182-
var buf bytes.Buffer
171+
buf := &c.buff
183172

184173
for dataLen > 0 {
174+
buf.Reset()
185175
payloadLen := dataLen
186176
if payloadLen > maxPayloadLen {
187177
payloadLen = maxPayloadLen
@@ -200,10 +190,10 @@ func (mc *mysqlConn) writeCompressed(packets []byte) (int, error) {
200190
}
201191
uncompressedLen = 0
202192
} else {
203-
zCompress(payload, &buf)
193+
zCompress(payload, buf)
204194
}
205195

206-
if err := mc.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil {
196+
if err := c.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil {
207197
return 0, err
208198
}
209199
dataLen -= payloadLen
@@ -216,7 +206,8 @@ func (mc *mysqlConn) writeCompressed(packets []byte) (int, error) {
216206

217207
// writeCompressedPacket writes a compressed packet with header.
218208
// data should start with 7 size space for header followed by payload.
219-
func (mc *mysqlConn) writeCompressedPacket(data []byte, uncompressedLen int) error {
209+
func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) error {
210+
mc := c.mc
220211
comprLength := len(data) - 7
221212
if debugTrace {
222213
fmt.Printf(

compress_test.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@ func makeRandByteSlice(size int) []byte {
2626
func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []byte {
2727
conn := new(mockConn)
2828
mc.netConn = conn
29+
comp := newCompIO(mc)
2930

30-
n, err := mc.writeCompressed(uncompressedPacket)
31+
n, err := comp.writePackets(uncompressedPacket)
3132
if err != nil {
3233
t.Fatal(err)
3334
}
@@ -43,7 +44,7 @@ func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expS
4344
conn := new(mockConn)
4445
conn.data = compressedPacket
4546
mc.buf.nc = conn
46-
cr := newDecompressor(mc)
47+
cr := newCompIO(mc)
4748

4849
uncompressedPacket, err := cr.readNext(expSize)
4950
if err != nil {

connection.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ type mysqlConn struct {
2828
netConn net.Conn
2929
rawConn net.Conn // underlying connection when netConn is TLS connection.
3030
result mysqlResult // managed by clearResult() and handleOkPacket().
31-
packetReader packetReader
31+
packetRW packetIO
3232
cfg *Config
3333
connector *connector
3434
maxAllowedPacket int
@@ -65,8 +65,9 @@ func (mc *mysqlConn) log(v ...any) {
6565
mc.cfg.Logger.Print(v...)
6666
}
6767

68-
type packetReader interface {
68+
type packetIO interface {
6969
readNext(need int) ([]byte, error)
70+
writePackets(data []byte) (int, error)
7071
}
7172

7273
func (mc *mysqlConn) resetSequenceNr() {

connection_test.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func TestInterpolateParams(t *testing.T) {
2525
InterpolateParams: true,
2626
},
2727
}
28-
mc.packetReader = &mc.buf
28+
mc.packetRW = &mc.buf
2929

3030
q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"})
3131
if err != nil {
@@ -73,7 +73,7 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
7373
InterpolateParams: true,
7474
},
7575
}
76-
mc.packetReader = &mc.buf
76+
mc.packetRW = &mc.buf
7777

7878
q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42)})
7979
if err != driver.ErrSkip {
@@ -92,7 +92,7 @@ func TestInterpolateParamsPlaceholderInString(t *testing.T) {
9292
},
9393
}
9494

95-
mc.packetReader = &mc.buf
95+
mc.packetRW = &mc.buf
9696

9797
q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)})
9898
// When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42`
@@ -168,7 +168,7 @@ func TestPingMarkBadConnection(t *testing.T) {
168168
mc := &mysqlConn{
169169
netConn: nc,
170170
buf: buf,
171-
packetReader: &buf,
171+
packetRW: &buf,
172172
maxAllowedPacket: defaultMaxAllowedPacket,
173173
closech: make(chan struct{}),
174174
cfg: NewConfig(),
@@ -188,7 +188,7 @@ func TestPingErrInvalidConn(t *testing.T) {
188188
mc := &mysqlConn{
189189
netConn: nc,
190190
buf: buf,
191-
packetReader: &buf,
191+
packetRW: &buf,
192192
maxAllowedPacket: defaultMaxAllowedPacket,
193193
closech: make(chan struct{}),
194194
cfg: NewConfig(),

connector.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
129129

130130
mc.buf = newBuffer(mc.netConn)
131131
// packet reader and writer in handshake are never compressed
132-
mc.packetReader = &mc.buf
132+
mc.packetRW = &mc.buf
133133

134134
// Set I/O timeouts
135135
mc.buf.timeout = mc.cfg.ReadTimeout
@@ -174,7 +174,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
174174

175175
if mc.cfg.compress && mc.flags&clientCompress == clientCompress {
176176
mc.compress = true
177-
mc.packetReader = newDecompressor(mc)
177+
mc.packetRW = newCompIO(mc)
178178
}
179179
if mc.cfg.MaxAllowedPacket > 0 {
180180
mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket

packets.go

+3-11
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
3232

3333
for {
3434
// read packet header
35-
data, err := mc.packetReader.readNext(4)
35+
data, err := mc.packetRW.readNext(4)
3636
if err != nil {
3737
mc.close()
3838
if cerr := mc.canceled.Value(); cerr != nil {
@@ -80,7 +80,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
8080
}
8181

8282
// read packet body [pktLen bytes]
83-
data, err = mc.packetReader.readNext(pktLen)
83+
data, err = mc.packetRW.readNext(pktLen)
8484
if err != nil {
8585
mc.close()
8686
if cerr := mc.canceled.Value(); cerr != nil {
@@ -143,15 +143,7 @@ func (mc *mysqlConn) writePacket(data []byte) error {
143143
}
144144
}
145145

146-
var (
147-
n int
148-
err error
149-
)
150-
if mc.compress {
151-
n, err = mc.writeCompressed(data[:4+size])
152-
} else {
153-
n, err = mc.netConn.Write(data[:4+size])
154-
}
146+
n, err := mc.packetRW.writePackets(data[:4+size])
155147
if err != nil {
156148
mc.cleanup()
157149
if cerr := mc.canceled.Value(); cerr != nil {

packets_test.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) {
100100
buf := newBuffer(conn)
101101
mc := &mysqlConn{
102102
buf: buf,
103-
packetReader: &buf,
103+
packetRW: &buf,
104104
cfg: connector.cfg,
105105
connector: connector,
106106
netConn: conn,
@@ -116,7 +116,7 @@ func TestReadPacketSingleByte(t *testing.T) {
116116
mc := &mysqlConn{
117117
buf: newBuffer(conn),
118118
}
119-
mc.packetReader = &mc.buf
119+
mc.packetRW = &mc.buf
120120

121121
conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
122122
conn.maxReads = 1
@@ -169,7 +169,7 @@ func TestReadPacketSplit(t *testing.T) {
169169
mc := &mysqlConn{
170170
buf: newBuffer(conn),
171171
}
172-
mc.packetReader = &mc.buf
172+
mc.packetRW = &mc.buf
173173

174174
data := make([]byte, maxPacketSize*2+4*3)
175175
const pkt2ofs = maxPacketSize + 4
@@ -277,7 +277,7 @@ func TestReadPacketFail(t *testing.T) {
277277
closech: make(chan struct{}),
278278
cfg: NewConfig(),
279279
}
280-
mc.packetReader = &mc.buf
280+
mc.packetRW = &mc.buf
281281

282282
// illegal empty (stand-alone) packet
283283
conn.data = []byte{0x00, 0x00, 0x00, 0x00}
@@ -323,7 +323,7 @@ func TestRegression801(t *testing.T) {
323323
sequence: 42,
324324
closech: make(chan struct{}),
325325
}
326-
mc.packetReader = &mc.buf
326+
mc.packetRW = &mc.buf
327327

328328
conn.data = []byte{72, 0, 0, 42, 10, 53, 46, 53, 46, 56, 0, 165, 0, 0, 0,
329329
60, 70, 63, 58, 68, 104, 34, 97, 0, 223, 247, 33, 2, 0, 15, 128, 21, 0,

0 commit comments

Comments
 (0)