Skip to content

fixing bad connection error when reading large compressed packets #863

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 7, 2024
Merged
3 changes: 0 additions & 3 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@ linters:
disable-all: true
enable:
# All code is ready for:
- deadcode
- errcheck
- staticcheck
- structcheck
- typecheck
- unused
- varcheck
- misspell
- nolintlint
- goimports
Expand Down
114 changes: 79 additions & 35 deletions packet/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"crypto/sha1"
"crypto/x509"
"encoding/pem"
goErrors "errors"
"io"
"net"
"sync"
Expand Down Expand Up @@ -108,41 +109,17 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) {

if c.Compression != MYSQL_COMPRESS_NONE {
if !c.compressedReaderActive {
if _, err := io.ReadFull(c.reader, c.compressedHeader[:7]); err != nil {
return nil, errors.Wrapf(ErrBadConn, "io.ReadFull(compressedHeader) failed. err %v", err)
}

compressedSequence := c.compressedHeader[3]
uncompressedLength := int(uint32(c.compressedHeader[4]) | uint32(c.compressedHeader[5])<<8 | uint32(c.compressedHeader[6])<<16)
if compressedSequence != c.CompressedSequence {
return nil, errors.Errorf("invalid compressed sequence %d != %d",
compressedSequence, c.CompressedSequence)
}

if uncompressedLength > 0 {
var err error
switch c.Compression {
case MYSQL_COMPRESS_ZLIB:
c.compressedReader, err = zlib.NewReader(c.reader)
case MYSQL_COMPRESS_ZSTD:
c.compressedReader, err = zstd.NewReader(c.reader)
}
if err != nil {
return nil, err
}
var err error
c.compressedReader, err = c.newCompressedPacketReader()
if err != nil {
return nil, err
}
c.compressedReaderActive = true
}
}

if c.compressedReader != nil {
if err := c.ReadPacketTo(buf, c.compressedReader); err != nil {
return nil, errors.Trace(err)
}
} else {
if err := c.ReadPacketTo(buf, c.reader); err != nil {
return nil, errors.Trace(err)
}
if err := c.ReadPacketTo(buf); err != nil {
return nil, errors.Trace(err)
}

readBytes := buf.Bytes()
Expand All @@ -167,6 +144,41 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) {
return result, nil
}

// newCompressedPacketReader creates a new compressed packet reader.
func (c *Conn) newCompressedPacketReader() (io.Reader, error) {
if _, err := io.ReadFull(c.reader, c.compressedHeader[:7]); err != nil {
return nil, errors.Wrapf(ErrBadConn, "io.ReadFull(compressedHeader) failed. err %v", err)
}

compressedSequence := c.compressedHeader[3]
if compressedSequence != c.CompressedSequence {
return nil, errors.Errorf("invalid compressed sequence %d != %d",
compressedSequence, c.CompressedSequence)
}

compressedLength := int(uint32(c.compressedHeader[0]) | uint32(c.compressedHeader[1])<<8 | uint32(c.compressedHeader[2])<<16)
uncompressedLength := int(uint32(c.compressedHeader[4]) | uint32(c.compressedHeader[5])<<8 | uint32(c.compressedHeader[6])<<16)
if uncompressedLength > 0 {
limitedReader := io.LimitReader(c.reader, int64(compressedLength))
switch c.Compression {
case MYSQL_COMPRESS_ZLIB:
return zlib.NewReader(limitedReader)
case MYSQL_COMPRESS_ZSTD:
return zstd.NewReader(limitedReader)
}
}

return nil, nil
}

func (c *Conn) currentPacketReader() io.Reader {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there're

	compressedReaderActive bool

	compressedReader io.Reader

in Conn. Seems we can directly check c.compressedReader == nil as the returned reader for currentPacketReader. And compressedReaderActive always has the same value for c.compressedReader == nil so we can delete it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I attempted to delete compressedReaderActive the tests in the client package all began failing when I ran them with compression enabled. I think this is because compressedReaderActive is reset to false in WritePacket after writing the compressed packet. So I don't think I can delete it, or at least I feel deleting it is out of scope for this PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've updated this PR with your suggestion. It was 2am for me and I wasn't thinking clearly, but after more sleep, I realized I could easily remove the compressedReaderActive boolean property from Conn.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

take care of your health ❤️

if c.Compression == MYSQL_COMPRESS_NONE || c.compressedReader == nil {
return c.reader
} else {
return c.compressedReader
}
}

func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err error) {
for n > 0 {
bcap := cap(c.copyNBuf)
Expand All @@ -175,9 +187,29 @@ func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err
}
buf := c.copyNBuf[:bcap]

rd, err := io.ReadAtLeast(src, buf, bcap)
var rd int
rd, err = io.ReadAtLeast(src, buf, bcap)

n -= int64(rd)

// if we've read to EOF, and we have compression then advance the sequence number
// and reset the compressed reader to continue reading the remaining bytes
// in the next compressed packet.
if c.Compression != MYSQL_COMPRESS_NONE && rd < bcap &&
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just read the comment of ReadAtLeast, seems for goErrors.Is(err, io.ErrUnexpectedEOF) || goErrors.Is(err, io.EOF) we don't need to check rd < bcap

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes agree with you after reviewing the ReadAtLeast documentation. I've applied the change suggested.

(goErrors.Is(err, io.ErrUnexpectedEOF) || goErrors.Is(err, io.EOF)) {
// we have read to EOF and read an incomplete uncompressed packet
// so advance the compressed sequence number and reset the compressed reader
// to get the remaining unread uncompressed bytes from the next compressed packet.
c.CompressedSequence++
if c.compressedReader, err = c.newCompressedPacketReader(); err != nil {
return written, errors.Trace(err)
}

// now read the remaining bytes into the buffer containing the first read bytes
rd, err = io.ReadAtLeast(c.currentPacketReader(), buf[rd:], bcap-rd)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we delete this reading and let the outer loop read it? because it may still meet the EOF error like line 191.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes terrific suggestion, i've applied this change.

n -= int64(rd)
}

if err != nil {
return written, errors.Trace(err)
}
Expand All @@ -192,9 +224,21 @@ func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err
return written, nil
}

func (c *Conn) ReadPacketTo(w io.Writer, r io.Reader) error {
if _, err := io.ReadFull(r, c.header[:4]); err != nil {
func (c *Conn) ReadPacketTo(w io.Writer) error {
b := utils.BytesBufferGet()
defer func() {
utils.BytesBufferPut(b)
}()

// packets that come in a compressed packet may be partial
// so use the copyN function to read the packet header into a
// buffer, since copyN is capable of getting the next compressed
// packet and updating the Conn state with a new compressedReader.
if _, err := c.copyN(b, c.currentPacketReader(), 4); err != nil {
return errors.Wrapf(ErrBadConn, "io.ReadFull(header) failed. err %v", err)
} else {
// copy was successful so copy the 4 bytes from the buffer to the header
copy(c.header[:4], b.Bytes()[:4])
}

length := int(uint32(c.header[0]) | uint32(c.header[1])<<8 | uint32(c.header[2])<<16)
Expand All @@ -211,7 +255,7 @@ func (c *Conn) ReadPacketTo(w io.Writer, r io.Reader) error {
buf.Grow(length)
}

if n, err := c.copyN(w, r, int64(length)); err != nil {
if n, err := c.copyN(w, c.currentPacketReader(), int64(length)); err != nil {
return errors.Wrapf(ErrBadConn, "io.CopyN failed. err %v, copied %v, expected %v", err, n, length)
} else if n != int64(length) {
return errors.Wrapf(ErrBadConn, "io.CopyN failed(n != int64(length)). %v bytes copied, while %v expected", n, length)
Expand All @@ -220,7 +264,7 @@ func (c *Conn) ReadPacketTo(w io.Writer, r io.Reader) error {
return nil
}

if err = c.ReadPacketTo(w, r); err != nil {
if err = c.ReadPacketTo(w); err != nil {
return errors.Wrap(err, "ReadPacketTo failed")
}
}
Expand Down
Loading