Skip to content

Commit e372b06

Browse files
committed
addressing PR feedback
1 parent 5970c88 commit e372b06

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

packet/conn.go

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -179,22 +179,26 @@ func (c *Conn) currentPacketReader() io.Reader {
179179
}
180180
}
181181

182-
func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err error) {
182+
func (c *Conn) copyN(dst io.Writer, n int64) (int64, error) {
183+
var written int64
184+
183185
for n > 0 {
184186
bcap := cap(c.copyNBuf)
185187
if int64(bcap) > n {
186188
bcap = int(n)
187189
}
188190
buf := c.copyNBuf[:bcap]
189191

190-
var rd int
191-
rd, err = io.ReadAtLeast(src, buf, bcap)
192+
// Call ReadAtLeast with the currentPacketReader as it may change on every iteration
193+
// of this loop.
194+
rd, err := io.ReadAtLeast(c.currentPacketReader(), buf, bcap)
192195

193196
n -= int64(rd)
194197

195-
// if we've read to EOF, and we have compression then advance the sequence number
196-
// and reset the compressed reader to continue reading the remaining bytes
197-
// in the next compressed packet.
198+
// ReadAtLeast will return EOF or ErrUnexpectedEOF when fewer than the min
199+
// bytes are read. In this case, and when we have compression then advance
200+
// the sequence number and reset the compressed reader to continue reading
201+
// the remaining bytes in the next compressed packet.
198202
if c.Compression != MYSQL_COMPRESS_NONE && rd < bcap &&
199203
(goErrors.Is(err, io.ErrUnexpectedEOF) || goErrors.Is(err, io.EOF)) {
200204
// we have read to EOF and read an incomplete uncompressed packet
@@ -204,17 +208,14 @@ func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err
204208
if c.compressedReader, err = c.newCompressedPacketReader(); err != nil {
205209
return written, errors.Trace(err)
206210
}
207-
208-
// now read the remaining bytes into the buffer containing the first read bytes
209-
rd, err = io.ReadAtLeast(c.currentPacketReader(), buf[rd:], bcap-rd)
210-
n -= int64(rd)
211211
}
212212

213213
if err != nil {
214214
return written, errors.Trace(err)
215215
}
216216

217-
wr, err := dst.Write(buf)
217+
// careful to only write from the buffer the number of bytes read
218+
wr, err := dst.Write(buf[:rd])
218219
written += int64(wr)
219220
if err != nil {
220221
return written, errors.Trace(err)
@@ -234,7 +235,7 @@ func (c *Conn) ReadPacketTo(w io.Writer) error {
234235
// so use the copyN function to read the packet header into a
235236
// buffer, since copyN is capable of getting the next compressed
236237
// packet and updating the Conn state with a new compressedReader.
237-
if _, err := c.copyN(b, c.currentPacketReader(), 4); err != nil {
238+
if _, err := c.copyN(b, 4); err != nil {
238239
return errors.Wrapf(ErrBadConn, "io.ReadFull(header) failed. err %v", err)
239240
} else {
240241
// copy was successful so copy the 4 bytes from the buffer to the header
@@ -255,7 +256,7 @@ func (c *Conn) ReadPacketTo(w io.Writer) error {
255256
buf.Grow(length)
256257
}
257258

258-
if n, err := c.copyN(w, c.currentPacketReader(), int64(length)); err != nil {
259+
if n, err := c.copyN(w, int64(length)); err != nil {
259260
return errors.Wrapf(ErrBadConn, "io.CopyN failed. err %v, copied %v, expected %v", err, n, length)
260261
} else if n != int64(length) {
261262
return errors.Wrapf(ErrBadConn, "io.CopyN failed(n != int64(length)). %v bytes copied, while %v expected", n, length)

0 commit comments

Comments
 (0)