Skip to content

Commit 7263925

Browse files
committed
Support MySQL Compressed Protocol
1 parent 850a82b commit 7263925

File tree

4 files changed

+171
-14
lines changed

4 files changed

+171
-14
lines changed

client/auth.go

+12-2
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,8 @@ func (c *Conn) writeAuthHandshake() error {
201201
// in the library are supported here
202202
capability |= c.ccaps&CLIENT_FOUND_ROWS | c.ccaps&CLIENT_IGNORE_SPACE |
203203
c.ccaps&CLIENT_MULTI_STATEMENTS | c.ccaps&CLIENT_MULTI_RESULTS |
204-
c.ccaps&CLIENT_PS_MULTI_RESULTS | c.ccaps&CLIENT_CONNECT_ATTRS
204+
c.ccaps&CLIENT_PS_MULTI_RESULTS | c.ccaps&CLIENT_CONNECT_ATTRS |
205+
c.ccaps&CLIENT_COMPRESS | c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM
205206

206207
// To enable TLS / SSL
207208
if c.tlsConfig != nil {
@@ -247,6 +248,9 @@ func (c *Conn) writeAuthHandshake() error {
247248
capability |= CLIENT_CONNECT_ATTRS
248249
length += len(attrData)
249250
}
251+
if c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 {
252+
length++
253+
}
250254

251255
data := make([]byte, length+4)
252256

@@ -320,7 +324,13 @@ func (c *Conn) writeAuthHandshake() error {
320324

321325
// connection attributes
322326
if len(attrData) > 0 {
323-
copy(data[pos:], attrData)
327+
pos += copy(data[pos:], attrData)
328+
}
329+
330+
if c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 {
331+
// zstd_compression_level
332+
data[pos] = 0x03
333+
pos++
324334
}
325335

326336
return c.WritePacket(data)

client/conn.go

+13
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,12 @@ func ConnectWithDialer(ctx context.Context, network string, addr string, user st
112112
return nil, errors.Trace(err)
113113
}
114114

115+
if c.ccaps&CLIENT_COMPRESS > 0 {
116+
c.Conn.Compression = MYSQL_COMPRESS_ZLIB
117+
} else if c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 {
118+
c.Conn.Compression = MYSQL_COMPRESS_ZSTD
119+
}
120+
115121
return c, nil
116122
}
117123

@@ -140,6 +146,13 @@ func (c *Conn) Close() error {
140146
return c.Conn.Close()
141147
}
142148

149+
func (c *Conn) Quit() error {
150+
if err := c.writeCommand(COM_QUIT); err != nil {
151+
return err
152+
}
153+
return c.Close()
154+
}
155+
143156
func (c *Conn) Ping() error {
144157
if err := c.writeCommand(COM_PING); err != nil {
145158
return errors.Trace(err)

mysql/const.go

+6
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,9 @@ const (
185185
MYSQL_OPTION_MULTI_STATEMENTS_ON = iota
186186
MYSQL_OPTION_MULTI_STATEMENTS_OFF
187187
)
188+
189+
const (
190+
MYSQL_COMPRESS_NONE = iota
191+
MYSQL_COMPRESS_ZLIB
192+
MYSQL_COMPRESS_ZSTD
193+
)

packet/conn.go

+140-12
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package packet
33
import (
44
"bufio"
55
"bytes"
6+
"compress/zlib"
67
"crypto/rand"
78
"crypto/rsa"
89
"crypto/sha1"
@@ -12,6 +13,7 @@ import (
1213
"net"
1314
"sync"
1415

16+
"github.com/DataDog/zstd"
1517
. "github.com/go-mysql-org/go-mysql/mysql"
1618
"github.com/go-mysql-org/go-mysql/utils"
1719
"github.com/pingcap/errors"
@@ -56,6 +58,16 @@ type Conn struct {
5658
header [4]byte
5759

5860
Sequence uint8
61+
62+
Compression uint8
63+
64+
CompressedSequence uint8
65+
66+
compressedHeader [7]byte
67+
68+
compressedReaderActive bool
69+
70+
compressedReader io.Reader
5971
}
6072

6173
func NewConn(conn net.Conn) *Conn {
@@ -94,8 +106,43 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) {
94106
utils.BytesBufferPut(buf)
95107
}()
96108

97-
if err := c.ReadPacketTo(buf); err != nil {
98-
return nil, errors.Trace(err)
109+
if c.Compression != MYSQL_COMPRESS_NONE {
110+
if !c.compressedReaderActive {
111+
if _, err := io.ReadFull(c.reader, c.compressedHeader[:7]); err != nil {
112+
return nil, errors.Wrapf(ErrBadConn, "io.ReadFull(compressedHeader) failed. err %v", err)
113+
}
114+
115+
compressedSequence := c.compressedHeader[3]
116+
uncompressedLength := int(uint32(c.compressedHeader[4]) | uint32(c.compressedHeader[5])<<8 | uint32(c.compressedHeader[6])<<16)
117+
if compressedSequence != c.CompressedSequence {
118+
return nil, errors.Errorf("invalid compressed sequence %d != %d",
119+
compressedSequence, c.CompressedSequence)
120+
}
121+
122+
if uncompressedLength > 0 {
123+
var err error
124+
switch c.Compression {
125+
case MYSQL_COMPRESS_ZLIB:
126+
c.compressedReader, err = zlib.NewReader(c.reader)
127+
case MYSQL_COMPRESS_ZSTD:
128+
c.compressedReader = zstd.NewReader(c.reader)
129+
}
130+
if err != nil {
131+
return nil, err
132+
}
133+
}
134+
c.compressedReaderActive = true
135+
}
136+
}
137+
138+
if c.compressedReader != nil {
139+
if err := c.ReadPacketTo(buf, c.compressedReader); err != nil {
140+
return nil, errors.Trace(err)
141+
}
142+
} else {
143+
if err := c.ReadPacketTo(buf, c.reader); err != nil {
144+
return nil, errors.Trace(err)
145+
}
99146
}
100147

101148
readBytes := buf.Bytes()
@@ -145,8 +192,8 @@ func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err
145192
return written, nil
146193
}
147194

148-
func (c *Conn) ReadPacketTo(w io.Writer) error {
149-
if _, err := io.ReadFull(c.reader, c.header[:4]); err != nil {
195+
func (c *Conn) ReadPacketTo(w io.Writer, r io.Reader) error {
196+
if _, err := io.ReadFull(r, c.header[:4]); err != nil {
150197
return errors.Wrapf(ErrBadConn, "io.ReadFull(header) failed. err %v", err)
151198
}
152199

@@ -164,7 +211,7 @@ func (c *Conn) ReadPacketTo(w io.Writer) error {
164211
buf.Grow(length)
165212
}
166213

167-
if n, err := c.copyN(w, c.reader, int64(length)); err != nil {
214+
if n, err := c.copyN(w, r, int64(length)); err != nil {
168215
return errors.Wrapf(ErrBadConn, "io.CopyN failed. err %v, copied %v, expected %v", err, n, length)
169216
} else if n != int64(length) {
170217
return errors.Wrapf(ErrBadConn, "io.CopyN failed(n != int64(length)). %v bytes copied, while %v expected", n, length)
@@ -173,7 +220,7 @@ func (c *Conn) ReadPacketTo(w io.Writer) error {
173220
return nil
174221
}
175222

176-
if err := c.ReadPacketTo(w); err != nil {
223+
if err = c.ReadPacketTo(w, r); err != nil {
177224
return errors.Wrap(err, "ReadPacketTo failed")
178225
}
179226
}
@@ -209,14 +256,95 @@ func (c *Conn) WritePacket(data []byte) error {
209256
data[2] = byte(length >> 16)
210257
data[3] = c.Sequence
211258

212-
if n, err := c.Write(data); err != nil {
213-
return errors.Wrapf(ErrBadConn, "Write failed. err %v", err)
214-
} else if n != len(data) {
215-
return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data))
259+
switch c.Compression {
260+
case MYSQL_COMPRESS_NONE:
261+
if n, err := c.Write(data); err != nil {
262+
return errors.Wrapf(ErrBadConn, "Write failed. err %v", err)
263+
} else if n != len(data) {
264+
return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data))
265+
}
266+
case MYSQL_COMPRESS_ZLIB:
267+
fallthrough
268+
case MYSQL_COMPRESS_ZSTD:
269+
if n, err := c.writeCompressed(data); err != nil {
270+
return errors.Wrapf(ErrBadConn, "Write failed. err %v", err)
271+
} else if n != len(data) {
272+
return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data))
273+
}
274+
c.compressedReader = nil
275+
c.compressedReaderActive = false
276+
default:
277+
return errors.Wrapf(ErrBadConn, "Write failed. Unsuppored compression algorithm set")
278+
}
279+
280+
c.Sequence++
281+
return nil
282+
}
283+
284+
func (c *Conn) writeCompressed(data []byte) (n int, err error) {
285+
var compressedLength, uncompressedLength int
286+
var payload, compressedPacket bytes.Buffer
287+
var w io.WriteCloser
288+
minCompressLength := 50
289+
compressedHeader := make([]byte, 7)
290+
291+
switch c.Compression {
292+
case MYSQL_COMPRESS_ZLIB:
293+
w, err = zlib.NewWriterLevel(&payload, zlib.HuffmanOnly)
294+
case MYSQL_COMPRESS_ZSTD:
295+
w = zstd.NewWriter(&payload)
296+
}
297+
if err != nil {
298+
return 0, err
299+
}
300+
301+
if len(data) > minCompressLength {
302+
uncompressedLength = len(data)
303+
n, err = w.Write(data)
304+
if err != nil {
305+
return 0, err
306+
}
307+
err = w.Close()
308+
if err != nil {
309+
return 0, err
310+
}
311+
}
312+
313+
if len(data) > minCompressLength {
314+
compressedLength = len(payload.Bytes())
315+
} else {
316+
compressedLength = len(data)
317+
}
318+
319+
c.CompressedSequence = 0
320+
compressedHeader[0] = byte(compressedLength)
321+
compressedHeader[1] = byte(compressedLength >> 8)
322+
compressedHeader[2] = byte(compressedLength >> 16)
323+
compressedHeader[3] = c.CompressedSequence
324+
compressedHeader[4] = byte(uncompressedLength)
325+
compressedHeader[5] = byte(uncompressedLength >> 8)
326+
compressedHeader[6] = byte(uncompressedLength >> 16)
327+
_, err = compressedPacket.Write(compressedHeader)
328+
if err != nil {
329+
return 0, err
330+
}
331+
c.CompressedSequence++
332+
333+
if len(data) > minCompressLength {
334+
_, err = compressedPacket.Write(payload.Bytes())
216335
} else {
217-
c.Sequence++
218-
return nil
336+
n, err = compressedPacket.Write(data)
337+
}
338+
if err != nil {
339+
return 0, err
219340
}
341+
342+
_, err = c.Write(compressedPacket.Bytes())
343+
if err != nil {
344+
return 0, err
345+
}
346+
347+
return n, nil
220348
}
221349

222350
// WriteClearAuthPacket: Client clear text authentication packet

0 commit comments

Comments
 (0)