diff --git a/client/auth.go b/client/auth.go index ff7ebcd01..e9c25e78b 100644 --- a/client/auth.go +++ b/client/auth.go @@ -201,7 +201,8 @@ func (c *Conn) writeAuthHandshake() error { // in the library are supported here capability |= c.ccaps&CLIENT_FOUND_ROWS | c.ccaps&CLIENT_IGNORE_SPACE | c.ccaps&CLIENT_MULTI_STATEMENTS | c.ccaps&CLIENT_MULTI_RESULTS | - c.ccaps&CLIENT_PS_MULTI_RESULTS | c.ccaps&CLIENT_CONNECT_ATTRS + c.ccaps&CLIENT_PS_MULTI_RESULTS | c.ccaps&CLIENT_CONNECT_ATTRS | + c.ccaps&CLIENT_COMPRESS | c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM // To enable TLS / SSL if c.tlsConfig != nil { @@ -247,6 +248,9 @@ func (c *Conn) writeAuthHandshake() error { capability |= CLIENT_CONNECT_ATTRS length += len(attrData) } + if c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 { + length++ + } data := make([]byte, length+4) @@ -320,7 +324,12 @@ func (c *Conn) writeAuthHandshake() error { // connection attributes if len(attrData) > 0 { - copy(data[pos:], attrData) + pos += copy(data[pos:], attrData) + } + + if c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 { + // zstd_compression_level + data[pos] = 0x03 } return c.WritePacket(data) diff --git a/client/conn.go b/client/conn.go index 13a517636..b1f3e52d1 100644 --- a/client/conn.go +++ b/client/conn.go @@ -121,6 +121,12 @@ func ConnectWithDialer(ctx context.Context, network string, addr string, user st return nil, errors.Trace(err) } + if c.ccaps&CLIENT_COMPRESS > 0 { + c.Conn.Compression = MYSQL_COMPRESS_ZLIB + } else if c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 { + c.Conn.Compression = MYSQL_COMPRESS_ZSTD + } + return c, nil } @@ -149,6 +155,13 @@ func (c *Conn) Close() error { return c.Conn.Close() } +func (c *Conn) Quit() error { + if err := c.writeCommand(COM_QUIT); err != nil { + return err + } + return c.Close() +} + func (c *Conn) Ping() error { if err := c.writeCommand(COM_PING); err != nil { return errors.Trace(err) diff --git a/mysql/const.go b/mysql/const.go index a1a5bde42..34661294a 100644 --- a/mysql/const.go +++ b/mysql/const.go @@ -185,3 +185,9 @@ const ( MYSQL_OPTION_MULTI_STATEMENTS_ON = iota MYSQL_OPTION_MULTI_STATEMENTS_OFF ) + +const ( + MYSQL_COMPRESS_NONE = iota + MYSQL_COMPRESS_ZLIB + MYSQL_COMPRESS_ZSTD +) diff --git a/packet/conn.go b/packet/conn.go index 8d020fe92..963b89e98 100644 --- a/packet/conn.go +++ b/packet/conn.go @@ -3,6 +3,7 @@ package packet import ( "bufio" "bytes" + "compress/zlib" "crypto/rand" "crypto/rsa" "crypto/sha1" @@ -12,6 +13,7 @@ import ( "net" "sync" + "github.com/DataDog/zstd" . "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/utils" "github.com/pingcap/errors" @@ -56,6 +58,16 @@ type Conn struct { header [4]byte Sequence uint8 + + Compression uint8 + + CompressedSequence uint8 + + compressedHeader [7]byte + + compressedReaderActive bool + + compressedReader io.Reader } func NewConn(conn net.Conn) *Conn { @@ -94,8 +106,43 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) { utils.BytesBufferPut(buf) }() - if err := c.ReadPacketTo(buf); err != nil { - return nil, errors.Trace(err) + if c.Compression != MYSQL_COMPRESS_NONE { + if !c.compressedReaderActive { + if _, err := io.ReadFull(c.reader, c.compressedHeader[:7]); err != nil { + return nil, errors.Wrapf(ErrBadConn, "io.ReadFull(compressedHeader) failed. err %v", err) + } + + compressedSequence := c.compressedHeader[3] + uncompressedLength := int(uint32(c.compressedHeader[4]) | uint32(c.compressedHeader[5])<<8 | uint32(c.compressedHeader[6])<<16) + if compressedSequence != c.CompressedSequence { + return nil, errors.Errorf("invalid compressed sequence %d != %d", + compressedSequence, c.CompressedSequence) + } + + if uncompressedLength > 0 { + var err error + switch c.Compression { + case MYSQL_COMPRESS_ZLIB: + c.compressedReader, err = zlib.NewReader(c.reader) + case MYSQL_COMPRESS_ZSTD: + c.compressedReader = zstd.NewReader(c.reader) + } + if err != nil { + return nil, err + } + } + c.compressedReaderActive = true + } + } + + if c.compressedReader != nil { + if err := c.ReadPacketTo(buf, c.compressedReader); err != nil { + return nil, errors.Trace(err) + } + } else { + if err := c.ReadPacketTo(buf, c.reader); err != nil { + return nil, errors.Trace(err) + } } readBytes := buf.Bytes() @@ -145,8 +192,8 @@ func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err return written, nil } -func (c *Conn) ReadPacketTo(w io.Writer) error { - if _, err := io.ReadFull(c.reader, c.header[:4]); err != nil { +func (c *Conn) ReadPacketTo(w io.Writer, r io.Reader) error { + if _, err := io.ReadFull(r, c.header[:4]); err != nil { return errors.Wrapf(ErrBadConn, "io.ReadFull(header) failed. err %v", err) } @@ -164,7 +211,7 @@ func (c *Conn) ReadPacketTo(w io.Writer) error { buf.Grow(length) } - if n, err := c.copyN(w, c.reader, int64(length)); err != nil { + if n, err := c.copyN(w, r, int64(length)); err != nil { return errors.Wrapf(ErrBadConn, "io.CopyN failed. err %v, copied %v, expected %v", err, n, length) } else if n != int64(length) { return errors.Wrapf(ErrBadConn, "io.CopyN failed(n != int64(length)). %v bytes copied, while %v expected", n, length) @@ -173,7 +220,7 @@ func (c *Conn) ReadPacketTo(w io.Writer) error { return nil } - if err := c.ReadPacketTo(w); err != nil { + if err = c.ReadPacketTo(w, r); err != nil { return errors.Wrap(err, "ReadPacketTo failed") } } @@ -209,14 +256,93 @@ func (c *Conn) WritePacket(data []byte) error { data[2] = byte(length >> 16) data[3] = c.Sequence - if n, err := c.Write(data); err != nil { - return errors.Wrapf(ErrBadConn, "Write failed. err %v", err) - } else if n != len(data) { - return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data)) + switch c.Compression { + case MYSQL_COMPRESS_NONE: + if n, err := c.Write(data); err != nil { + return errors.Wrapf(ErrBadConn, "Write failed. err %v", err) + } else if n != len(data) { + return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data)) + } + case MYSQL_COMPRESS_ZLIB, MYSQL_COMPRESS_ZSTD: + if n, err := c.writeCompressed(data); err != nil { + return errors.Wrapf(ErrBadConn, "Write failed. err %v", err) + } else if n != len(data) { + return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data)) + } + c.compressedReader = nil + c.compressedReaderActive = false + default: + return errors.Wrapf(ErrBadConn, "Write failed. Unsuppored compression algorithm set") + } + + c.Sequence++ + return nil +} + +func (c *Conn) writeCompressed(data []byte) (n int, err error) { + var compressedLength, uncompressedLength int + var payload, compressedPacket bytes.Buffer + var w io.WriteCloser + minCompressLength := 50 + compressedHeader := make([]byte, 7) + + switch c.Compression { + case MYSQL_COMPRESS_ZLIB: + w, err = zlib.NewWriterLevel(&payload, zlib.HuffmanOnly) + case MYSQL_COMPRESS_ZSTD: + w = zstd.NewWriter(&payload) + } + if err != nil { + return 0, err + } + + if len(data) > minCompressLength { + uncompressedLength = len(data) + n, err = w.Write(data) + if err != nil { + return 0, err + } + err = w.Close() + if err != nil { + return 0, err + } + } + + if len(data) > minCompressLength { + compressedLength = len(payload.Bytes()) + } else { + compressedLength = len(data) + } + + c.CompressedSequence = 0 + compressedHeader[0] = byte(compressedLength) + compressedHeader[1] = byte(compressedLength >> 8) + compressedHeader[2] = byte(compressedLength >> 16) + compressedHeader[3] = c.CompressedSequence + compressedHeader[4] = byte(uncompressedLength) + compressedHeader[5] = byte(uncompressedLength >> 8) + compressedHeader[6] = byte(uncompressedLength >> 16) + _, err = compressedPacket.Write(compressedHeader) + if err != nil { + return 0, err + } + c.CompressedSequence++ + + if len(data) > minCompressLength { + _, err = compressedPacket.Write(payload.Bytes()) } else { - c.Sequence++ - return nil + n, err = compressedPacket.Write(data) + } + if err != nil { + return 0, err } + + _, err = c.Write(compressedPacket.Bytes()) + if err != nil { + return 0, err + } + + return n, nil } // WriteClearAuthPacket: Client clear text authentication packet