-
Notifications
You must be signed in to change notification settings - Fork 1k
fixing bad connection error when reading large compressed packets #863
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
d082741
88f99a9
4941c07
3f855bf
3d097f7
5895dda
22dbb42
5970c88
e372b06
c84b0fc
3848abb
c2232da
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ import ( | |
"crypto/sha1" | ||
"crypto/x509" | ||
"encoding/pem" | ||
goErrors "errors" | ||
"io" | ||
"net" | ||
"sync" | ||
|
@@ -108,41 +109,17 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) { | |
|
||
if c.Compression != MYSQL_COMPRESS_NONE { | ||
if !c.compressedReaderActive { | ||
if _, err := io.ReadFull(c.reader, c.compressedHeader[:7]); err != nil { | ||
return nil, errors.Wrapf(ErrBadConn, "io.ReadFull(compressedHeader) failed. err %v", err) | ||
} | ||
|
||
compressedSequence := c.compressedHeader[3] | ||
uncompressedLength := int(uint32(c.compressedHeader[4]) | uint32(c.compressedHeader[5])<<8 | uint32(c.compressedHeader[6])<<16) | ||
if compressedSequence != c.CompressedSequence { | ||
return nil, errors.Errorf("invalid compressed sequence %d != %d", | ||
compressedSequence, c.CompressedSequence) | ||
} | ||
|
||
if uncompressedLength > 0 { | ||
var err error | ||
switch c.Compression { | ||
case MYSQL_COMPRESS_ZLIB: | ||
c.compressedReader, err = zlib.NewReader(c.reader) | ||
case MYSQL_COMPRESS_ZSTD: | ||
c.compressedReader, err = zstd.NewReader(c.reader) | ||
} | ||
if err != nil { | ||
return nil, err | ||
} | ||
var err error | ||
c.compressedReader, err = c.newCompressedPacketReader() | ||
if err != nil { | ||
return nil, err | ||
} | ||
c.compressedReaderActive = true | ||
} | ||
} | ||
|
||
if c.compressedReader != nil { | ||
if err := c.ReadPacketTo(buf, c.compressedReader); err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
} else { | ||
if err := c.ReadPacketTo(buf, c.reader); err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
if err := c.ReadPacketTo(buf); err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
|
||
readBytes := buf.Bytes() | ||
|
@@ -167,6 +144,41 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) { | |
return result, nil | ||
} | ||
|
||
// newCompressedPacketReader creates a new compressed packet reader. | ||
func (c *Conn) newCompressedPacketReader() (io.Reader, error) { | ||
if _, err := io.ReadFull(c.reader, c.compressedHeader[:7]); err != nil { | ||
return nil, errors.Wrapf(ErrBadConn, "io.ReadFull(compressedHeader) failed. err %v", err) | ||
} | ||
|
||
compressedSequence := c.compressedHeader[3] | ||
if compressedSequence != c.CompressedSequence { | ||
return nil, errors.Errorf("invalid compressed sequence %d != %d", | ||
compressedSequence, c.CompressedSequence) | ||
} | ||
|
||
compressedLength := int(uint32(c.compressedHeader[0]) | uint32(c.compressedHeader[1])<<8 | uint32(c.compressedHeader[2])<<16) | ||
uncompressedLength := int(uint32(c.compressedHeader[4]) | uint32(c.compressedHeader[5])<<8 | uint32(c.compressedHeader[6])<<16) | ||
if uncompressedLength > 0 { | ||
limitedReader := io.LimitReader(c.reader, int64(compressedLength)) | ||
switch c.Compression { | ||
case MYSQL_COMPRESS_ZLIB: | ||
return zlib.NewReader(limitedReader) | ||
case MYSQL_COMPRESS_ZSTD: | ||
return zstd.NewReader(limitedReader) | ||
} | ||
} | ||
|
||
return nil, nil | ||
} | ||
|
||
func (c *Conn) currentPacketReader() io.Reader { | ||
if c.Compression == MYSQL_COMPRESS_NONE || c.compressedReader == nil { | ||
return c.reader | ||
} else { | ||
return c.compressedReader | ||
} | ||
} | ||
|
||
func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err error) { | ||
for n > 0 { | ||
bcap := cap(c.copyNBuf) | ||
|
@@ -175,9 +187,29 @@ func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err | |
} | ||
buf := c.copyNBuf[:bcap] | ||
|
||
rd, err := io.ReadAtLeast(src, buf, bcap) | ||
var rd int | ||
rd, err = io.ReadAtLeast(src, buf, bcap) | ||
|
||
n -= int64(rd) | ||
|
||
// if we've read to EOF, and we have compression then advance the sequence number | ||
// and reset the compressed reader to continue reading the remaining bytes | ||
// in the next compressed packet. | ||
if c.Compression != MYSQL_COMPRESS_NONE && rd < bcap && | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just read the comment of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes agree with you after reviewing the |
||
(goErrors.Is(err, io.ErrUnexpectedEOF) || goErrors.Is(err, io.EOF)) { | ||
// we have read to EOF and read an incomplete uncompressed packet | ||
// so advance the compressed sequence number and reset the compressed reader | ||
// to get the remaining unread uncompressed bytes from the next compressed packet. | ||
c.CompressedSequence++ | ||
if c.compressedReader, err = c.newCompressedPacketReader(); err != nil { | ||
return written, errors.Trace(err) | ||
} | ||
|
||
// now read the remaining bytes into the buffer containing the first read bytes | ||
rd, err = io.ReadAtLeast(c.currentPacketReader(), buf[rd:], bcap-rd) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we delete this reading and let the outer loop read it? because it may still meet the EOF error like line 191. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes terrific suggestion, i've applied this change. |
||
n -= int64(rd) | ||
} | ||
|
||
if err != nil { | ||
return written, errors.Trace(err) | ||
} | ||
|
@@ -192,9 +224,21 @@ func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err | |
return written, nil | ||
} | ||
|
||
func (c *Conn) ReadPacketTo(w io.Writer, r io.Reader) error { | ||
if _, err := io.ReadFull(r, c.header[:4]); err != nil { | ||
func (c *Conn) ReadPacketTo(w io.Writer) error { | ||
b := utils.BytesBufferGet() | ||
defer func() { | ||
utils.BytesBufferPut(b) | ||
}() | ||
|
||
// packets that come in a compressed packet may be partial | ||
// so use the copyN function to read the packet header into a | ||
// buffer, since copyN is capable of getting the next compressed | ||
// packet and updating the Conn state with a new compressedReader. | ||
if _, err := c.copyN(b, c.currentPacketReader(), 4); err != nil { | ||
return errors.Wrapf(ErrBadConn, "io.ReadFull(header) failed. err %v", err) | ||
} else { | ||
// copy was successful so copy the 4 bytes from the buffer to the header | ||
copy(c.header[:4], b.Bytes()[:4]) | ||
} | ||
|
||
length := int(uint32(c.header[0]) | uint32(c.header[1])<<8 | uint32(c.header[2])<<16) | ||
|
@@ -211,7 +255,7 @@ func (c *Conn) ReadPacketTo(w io.Writer, r io.Reader) error { | |
buf.Grow(length) | ||
} | ||
|
||
if n, err := c.copyN(w, r, int64(length)); err != nil { | ||
if n, err := c.copyN(w, c.currentPacketReader(), int64(length)); err != nil { | ||
return errors.Wrapf(ErrBadConn, "io.CopyN failed. err %v, copied %v, expected %v", err, n, length) | ||
} else if n != int64(length) { | ||
return errors.Wrapf(ErrBadConn, "io.CopyN failed(n != int64(length)). %v bytes copied, while %v expected", n, length) | ||
|
@@ -220,7 +264,7 @@ func (c *Conn) ReadPacketTo(w io.Writer, r io.Reader) error { | |
return nil | ||
} | ||
|
||
if err = c.ReadPacketTo(w, r); err != nil { | ||
if err = c.ReadPacketTo(w); err != nil { | ||
return errors.Wrap(err, "ReadPacketTo failed") | ||
} | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there're
in
Conn
. Seems we can directly checkc.compressedReader == nil
as the returned reader forcurrentPacketReader
. AndcompressedReaderActive
always has the same value forc.compressedReader == nil
so we can delete it.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When I attempted to delete
compressedReaderActive
the tests in theclient
package all began failing when I ran them with compression enabled. I think this is becausecompressedReaderActive
is reset to false inWritePacket
after writing the compressed packet. So I don't think I can delete it, or at least I feel deleting it is out of scope for this PR.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've updated this PR with your suggestion. It was 2am for me and I wasn't thinking clearly, but after more sleep, I realized I could easily remove the
compressedReaderActive
boolean property fromConn
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
take care of your health ❤️