diff --git a/buffer.go b/buffer.go index 19486bd6f..2fa5e4a7e 100644 --- a/buffer.go +++ b/buffer.go @@ -14,7 +14,8 @@ import ( "time" ) -const defaultBufSize = 4096 +const minBufSize = 4096 +const defaultBufSize = 16 * 1024 // A buffer which is used for both reading and writing. // This is possible since communication on each connection is synchronous. @@ -37,6 +38,27 @@ func newBuffer(nc net.Conn) buffer { } } +// discard trims b.buf[:b.idx] to prohibit it reused. +// +// This is required by Rows.Close(). +// See https://github.com/golang/go/commit/651ddbdb5056ded455f47f9c494c67b389622a47 +func (b *buffer) discard() { + if len(b.buf)-b.idx >= minBufSize { + b.buf = b.buf[b.idx:] + b.idx = 0 + return + } + + bufSize := defaultBufSize + if bufSize < b.length { + bufSize = b.length + } + newBuf := make([]byte, bufSize) + copy(newBuf, b.buf[b.idx:b.idx+b.length]) + b.buf = newBuf + b.idx = 0 +} + // fill reads into the buffer until at least _need_ bytes are in it func (b *buffer) fill(need int) error { n := b.length @@ -47,11 +69,9 @@ func (b *buffer) fill(need int) error { } // grow buffer if necessary - // TODO: let the buffer shrink again at some point - // Maybe keep the org buf slice and swap back? if need > len(b.buf) { // Round up to the next multiple of the default size - newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize) + newBuf := make([]byte, ((need/minBufSize)+1)*minBufSize) copy(newBuf, b.buf) b.buf = newBuf } @@ -129,7 +149,7 @@ func (b *buffer) takeBuffer(length int) ([]byte, error) { } // takeSmallBuffer is shortcut which can be used if length is -// known to be smaller than defaultBufSize. +// known to be smaller than minBufSize. // Only one buffer (total) can be used at a time. func (b *buffer) takeSmallBuffer(length int) ([]byte, error) { if b.length > 0 { diff --git a/rows.go b/rows.go index d3b1e2822..5fc6f23e4 100644 --- a/rows.go +++ b/rows.go @@ -111,6 +111,10 @@ func (rows *mysqlRows) Close() (err error) { return err } + // We can't reuse receive buffer when rows.Close() is called. + // See https://github.com/golang/go/commit/651ddbdb5056ded455f47f9c494c67b389622a47 + mc.buf.discard() + // Remove unread packets from stream if !rows.rs.done { err = mc.readUntilEOF()