Skip to content

Commit 29bc749

Browse files
authored
Merge pull request #787 from dveeden/mysql_proto_compressed
Support MySQL Compressed Protocol
2 parents cfe8571 + 419796e commit 29bc749

File tree

4 files changed

+168
-14
lines changed

4 files changed

+168
-14
lines changed

client/auth.go

+11-2
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,8 @@ func (c *Conn) writeAuthHandshake() error {
201201
// in the library are supported here
202202
capability |= c.ccaps&CLIENT_FOUND_ROWS | c.ccaps&CLIENT_IGNORE_SPACE |
203203
c.ccaps&CLIENT_MULTI_STATEMENTS | c.ccaps&CLIENT_MULTI_RESULTS |
204-
c.ccaps&CLIENT_PS_MULTI_RESULTS | c.ccaps&CLIENT_CONNECT_ATTRS
204+
c.ccaps&CLIENT_PS_MULTI_RESULTS | c.ccaps&CLIENT_CONNECT_ATTRS |
205+
c.ccaps&CLIENT_COMPRESS | c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM
205206

206207
// To enable TLS / SSL
207208
if c.tlsConfig != nil {
@@ -247,6 +248,9 @@ func (c *Conn) writeAuthHandshake() error {
247248
capability |= CLIENT_CONNECT_ATTRS
248249
length += len(attrData)
249250
}
251+
if c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 {
252+
length++
253+
}
250254

251255
data := make([]byte, length+4)
252256

@@ -320,7 +324,12 @@ func (c *Conn) writeAuthHandshake() error {
320324

321325
// connection attributes
322326
if len(attrData) > 0 {
323-
copy(data[pos:], attrData)
327+
pos += copy(data[pos:], attrData)
328+
}
329+
330+
if c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 {
331+
// zstd_compression_level
332+
data[pos] = 0x03
324333
}
325334

326335
return c.WritePacket(data)

client/conn.go

+13
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,12 @@ func ConnectWithDialer(ctx context.Context, network string, addr string, user st
121121
return nil, errors.Trace(err)
122122
}
123123

124+
if c.ccaps&CLIENT_COMPRESS > 0 {
125+
c.Conn.Compression = MYSQL_COMPRESS_ZLIB
126+
} else if c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 {
127+
c.Conn.Compression = MYSQL_COMPRESS_ZSTD
128+
}
129+
124130
return c, nil
125131
}
126132

@@ -149,6 +155,13 @@ func (c *Conn) Close() error {
149155
return c.Conn.Close()
150156
}
151157

158+
func (c *Conn) Quit() error {
159+
if err := c.writeCommand(COM_QUIT); err != nil {
160+
return err
161+
}
162+
return c.Close()
163+
}
164+
152165
func (c *Conn) Ping() error {
153166
if err := c.writeCommand(COM_PING); err != nil {
154167
return errors.Trace(err)

mysql/const.go

+6
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,9 @@ const (
185185
MYSQL_OPTION_MULTI_STATEMENTS_ON = iota
186186
MYSQL_OPTION_MULTI_STATEMENTS_OFF
187187
)
188+
189+
const (
190+
MYSQL_COMPRESS_NONE = iota
191+
MYSQL_COMPRESS_ZLIB
192+
MYSQL_COMPRESS_ZSTD
193+
)

packet/conn.go

+138-12
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package packet
33
import (
44
"bufio"
55
"bytes"
6+
"compress/zlib"
67
"crypto/rand"
78
"crypto/rsa"
89
"crypto/sha1"
@@ -12,6 +13,7 @@ import (
1213
"net"
1314
"sync"
1415

16+
"github.com/DataDog/zstd"
1517
. "github.com/go-mysql-org/go-mysql/mysql"
1618
"github.com/go-mysql-org/go-mysql/utils"
1719
"github.com/pingcap/errors"
@@ -56,6 +58,16 @@ type Conn struct {
5658
header [4]byte
5759

5860
Sequence uint8
61+
62+
Compression uint8
63+
64+
CompressedSequence uint8
65+
66+
compressedHeader [7]byte
67+
68+
compressedReaderActive bool
69+
70+
compressedReader io.Reader
5971
}
6072

6173
func NewConn(conn net.Conn) *Conn {
@@ -94,8 +106,43 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) {
94106
utils.BytesBufferPut(buf)
95107
}()
96108

97-
if err := c.ReadPacketTo(buf); err != nil {
98-
return nil, errors.Trace(err)
109+
if c.Compression != MYSQL_COMPRESS_NONE {
110+
if !c.compressedReaderActive {
111+
if _, err := io.ReadFull(c.reader, c.compressedHeader[:7]); err != nil {
112+
return nil, errors.Wrapf(ErrBadConn, "io.ReadFull(compressedHeader) failed. err %v", err)
113+
}
114+
115+
compressedSequence := c.compressedHeader[3]
116+
uncompressedLength := int(uint32(c.compressedHeader[4]) | uint32(c.compressedHeader[5])<<8 | uint32(c.compressedHeader[6])<<16)
117+
if compressedSequence != c.CompressedSequence {
118+
return nil, errors.Errorf("invalid compressed sequence %d != %d",
119+
compressedSequence, c.CompressedSequence)
120+
}
121+
122+
if uncompressedLength > 0 {
123+
var err error
124+
switch c.Compression {
125+
case MYSQL_COMPRESS_ZLIB:
126+
c.compressedReader, err = zlib.NewReader(c.reader)
127+
case MYSQL_COMPRESS_ZSTD:
128+
c.compressedReader = zstd.NewReader(c.reader)
129+
}
130+
if err != nil {
131+
return nil, err
132+
}
133+
}
134+
c.compressedReaderActive = true
135+
}
136+
}
137+
138+
if c.compressedReader != nil {
139+
if err := c.ReadPacketTo(buf, c.compressedReader); err != nil {
140+
return nil, errors.Trace(err)
141+
}
142+
} else {
143+
if err := c.ReadPacketTo(buf, c.reader); err != nil {
144+
return nil, errors.Trace(err)
145+
}
99146
}
100147

101148
readBytes := buf.Bytes()
@@ -145,8 +192,8 @@ func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err
145192
return written, nil
146193
}
147194

148-
func (c *Conn) ReadPacketTo(w io.Writer) error {
149-
if _, err := io.ReadFull(c.reader, c.header[:4]); err != nil {
195+
func (c *Conn) ReadPacketTo(w io.Writer, r io.Reader) error {
196+
if _, err := io.ReadFull(r, c.header[:4]); err != nil {
150197
return errors.Wrapf(ErrBadConn, "io.ReadFull(header) failed. err %v", err)
151198
}
152199

@@ -164,7 +211,7 @@ func (c *Conn) ReadPacketTo(w io.Writer) error {
164211
buf.Grow(length)
165212
}
166213

167-
if n, err := c.copyN(w, c.reader, int64(length)); err != nil {
214+
if n, err := c.copyN(w, r, int64(length)); err != nil {
168215
return errors.Wrapf(ErrBadConn, "io.CopyN failed. err %v, copied %v, expected %v", err, n, length)
169216
} else if n != int64(length) {
170217
return errors.Wrapf(ErrBadConn, "io.CopyN failed(n != int64(length)). %v bytes copied, while %v expected", n, length)
@@ -173,7 +220,7 @@ func (c *Conn) ReadPacketTo(w io.Writer) error {
173220
return nil
174221
}
175222

176-
if err := c.ReadPacketTo(w); err != nil {
223+
if err = c.ReadPacketTo(w, r); err != nil {
177224
return errors.Wrap(err, "ReadPacketTo failed")
178225
}
179226
}
@@ -209,14 +256,93 @@ func (c *Conn) WritePacket(data []byte) error {
209256
data[2] = byte(length >> 16)
210257
data[3] = c.Sequence
211258

212-
if n, err := c.Write(data); err != nil {
213-
return errors.Wrapf(ErrBadConn, "Write failed. err %v", err)
214-
} else if n != len(data) {
215-
return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data))
259+
switch c.Compression {
260+
case MYSQL_COMPRESS_NONE:
261+
if n, err := c.Write(data); err != nil {
262+
return errors.Wrapf(ErrBadConn, "Write failed. err %v", err)
263+
} else if n != len(data) {
264+
return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data))
265+
}
266+
case MYSQL_COMPRESS_ZLIB, MYSQL_COMPRESS_ZSTD:
267+
if n, err := c.writeCompressed(data); err != nil {
268+
return errors.Wrapf(ErrBadConn, "Write failed. err %v", err)
269+
} else if n != len(data) {
270+
return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data))
271+
}
272+
c.compressedReader = nil
273+
c.compressedReaderActive = false
274+
default:
275+
return errors.Wrapf(ErrBadConn, "Write failed. Unsuppored compression algorithm set")
276+
}
277+
278+
c.Sequence++
279+
return nil
280+
}
281+
282+
func (c *Conn) writeCompressed(data []byte) (n int, err error) {
283+
var compressedLength, uncompressedLength int
284+
var payload, compressedPacket bytes.Buffer
285+
var w io.WriteCloser
286+
minCompressLength := 50
287+
compressedHeader := make([]byte, 7)
288+
289+
switch c.Compression {
290+
case MYSQL_COMPRESS_ZLIB:
291+
w, err = zlib.NewWriterLevel(&payload, zlib.HuffmanOnly)
292+
case MYSQL_COMPRESS_ZSTD:
293+
w = zstd.NewWriter(&payload)
294+
}
295+
if err != nil {
296+
return 0, err
297+
}
298+
299+
if len(data) > minCompressLength {
300+
uncompressedLength = len(data)
301+
n, err = w.Write(data)
302+
if err != nil {
303+
return 0, err
304+
}
305+
err = w.Close()
306+
if err != nil {
307+
return 0, err
308+
}
309+
}
310+
311+
if len(data) > minCompressLength {
312+
compressedLength = len(payload.Bytes())
313+
} else {
314+
compressedLength = len(data)
315+
}
316+
317+
c.CompressedSequence = 0
318+
compressedHeader[0] = byte(compressedLength)
319+
compressedHeader[1] = byte(compressedLength >> 8)
320+
compressedHeader[2] = byte(compressedLength >> 16)
321+
compressedHeader[3] = c.CompressedSequence
322+
compressedHeader[4] = byte(uncompressedLength)
323+
compressedHeader[5] = byte(uncompressedLength >> 8)
324+
compressedHeader[6] = byte(uncompressedLength >> 16)
325+
_, err = compressedPacket.Write(compressedHeader)
326+
if err != nil {
327+
return 0, err
328+
}
329+
c.CompressedSequence++
330+
331+
if len(data) > minCompressLength {
332+
_, err = compressedPacket.Write(payload.Bytes())
216333
} else {
217-
c.Sequence++
218-
return nil
334+
n, err = compressedPacket.Write(data)
335+
}
336+
if err != nil {
337+
return 0, err
219338
}
339+
340+
_, err = c.Write(compressedPacket.Bytes())
341+
if err != nil {
342+
return 0, err
343+
}
344+
345+
return n, nil
220346
}
221347

222348
// WriteClearAuthPacket: Client clear text authentication packet

0 commit comments

Comments
 (0)