@@ -179,22 +179,26 @@ func (c *Conn) currentPacketReader() io.Reader {
179
179
}
180
180
}
181
181
182
- func (c * Conn ) copyN (dst io.Writer , src io.Reader , n int64 ) (written int64 , err error ) {
182
+ func (c * Conn ) copyN (dst io.Writer , n int64 ) (int64 , error ) {
183
+ var written int64
184
+
183
185
for n > 0 {
184
186
bcap := cap (c .copyNBuf )
185
187
if int64 (bcap ) > n {
186
188
bcap = int (n )
187
189
}
188
190
buf := c .copyNBuf [:bcap ]
189
191
190
- var rd int
191
- rd , err = io .ReadAtLeast (src , buf , bcap )
192
+ // Call ReadAtLeast with the currentPacketReader as it may change on every iteration
193
+ // of this loop.
194
+ rd , err := io .ReadAtLeast (c .currentPacketReader (), buf , bcap )
192
195
193
196
n -= int64 (rd )
194
197
195
- // if we've read to EOF, and we have compression then advance the sequence number
196
- // and reset the compressed reader to continue reading the remaining bytes
197
- // in the next compressed packet.
198
+ // ReadAtLeast will return EOF or ErrUnexpectedEOF when fewer than the min
199
+ // bytes are read. In this case, and when we have compression then advance
200
+ // the sequence number and reset the compressed reader to continue reading
201
+ // the remaining bytes in the next compressed packet.
198
202
if c .Compression != MYSQL_COMPRESS_NONE && rd < bcap &&
199
203
(goErrors .Is (err , io .ErrUnexpectedEOF ) || goErrors .Is (err , io .EOF )) {
200
204
// we have read to EOF and read an incomplete uncompressed packet
@@ -204,17 +208,14 @@ func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err
204
208
if c .compressedReader , err = c .newCompressedPacketReader (); err != nil {
205
209
return written , errors .Trace (err )
206
210
}
207
-
208
- // now read the remaining bytes into the buffer containing the first read bytes
209
- rd , err = io .ReadAtLeast (c .currentPacketReader (), buf [rd :], bcap - rd )
210
- n -= int64 (rd )
211
211
}
212
212
213
213
if err != nil {
214
214
return written , errors .Trace (err )
215
215
}
216
216
217
- wr , err := dst .Write (buf )
217
+ // careful to only write from the buffer the number of bytes read
218
+ wr , err := dst .Write (buf [:rd ])
218
219
written += int64 (wr )
219
220
if err != nil {
220
221
return written , errors .Trace (err )
@@ -234,7 +235,7 @@ func (c *Conn) ReadPacketTo(w io.Writer) error {
234
235
// so use the copyN function to read the packet header into a
235
236
// buffer, since copyN is capable of getting the next compressed
236
237
// packet and updating the Conn state with a new compressedReader.
237
- if _ , err := c .copyN (b , c . currentPacketReader (), 4 ); err != nil {
238
+ if _ , err := c .copyN (b , 4 ); err != nil {
238
239
return errors .Wrapf (ErrBadConn , "io.ReadFull(header) failed. err %v" , err )
239
240
} else {
240
241
// copy was successful so copy the 4 bytes from the buffer to the header
@@ -255,7 +256,7 @@ func (c *Conn) ReadPacketTo(w io.Writer) error {
255
256
buf .Grow (length )
256
257
}
257
258
258
- if n , err := c .copyN (w , c . currentPacketReader (), int64 (length )); err != nil {
259
+ if n , err := c .copyN (w , int64 (length )); err != nil {
259
260
return errors .Wrapf (ErrBadConn , "io.CopyN failed. err %v, copied %v, expected %v" , err , n , length )
260
261
} else if n != int64 (length ) {
261
262
return errors .Wrapf (ErrBadConn , "io.CopyN failed(n != int64(length)). %v bytes copied, while %v expected" , n , length )
0 commit comments