Skip to content

Commit afbf75a

Browse files
committed
Invalidate connection on busy buffer
It's possible for a connection to enter a state where it is returned to the pool with unread data in the buffer. This can happen when an `ErrPktSync` error is encountered because the packet payload is left unread. When the connection is taken from the pool again, any futher queries will results in an `ErrBusyBuffer` error. The connection can be returned to the pool multiple times before it times out, so this can result in many, many errors. To fix this, this commit updates the `IsValid()` method to return false if the buffer is not empty. The `database/sql` package will use this to decide whether or not to return the connection to the pool, so returning false means that the connection will be discarded.
1 parent f20b286 commit afbf75a

File tree

3 files changed

+51
-28
lines changed

3 files changed

+51
-28
lines changed

buffer.go

+9-3
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ func (b *buffer) takeBuffer(length int) ([]byte, error) {
154154
// known to be smaller than defaultBufSize.
155155
// Only one buffer (total) can be used at a time.
156156
func (b *buffer) takeSmallBuffer(length int) ([]byte, error) {
157-
if b.length > 0 {
157+
if !b.empty() {
158158
return nil, ErrBusyBuffer
159159
}
160160
return b.buf[:length], nil
@@ -165,18 +165,24 @@ func (b *buffer) takeSmallBuffer(length int) ([]byte, error) {
165165
// cap and len of the returned buffer will be equal.
166166
// Only one buffer (total) can be used at a time.
167167
func (b *buffer) takeCompleteBuffer() ([]byte, error) {
168-
if b.length > 0 {
168+
if !b.empty() {
169169
return nil, ErrBusyBuffer
170170
}
171171
return b.buf, nil
172172
}
173173

174174
// store stores buf, an updated buffer, if its suitable to do so.
175175
func (b *buffer) store(buf []byte) error {
176-
if b.length > 0 {
176+
if !b.empty() {
177177
return ErrBusyBuffer
178178
} else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) {
179179
b.buf = buf[:cap(buf)]
180180
}
181181
return nil
182182
}
183+
184+
// empty returns true if the buffer is empty and does not contain
185+
// any unread data.
186+
func (b *buffer) empty() bool {
187+
return b.length == 0
188+
}

connection.go

+9-1
Original file line numberDiff line numberDiff line change
@@ -646,5 +646,13 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error {
646646
// IsValid implements driver.Validator interface
647647
// (From Go 1.15)
648648
func (mc *mysqlConn) IsValid() bool {
649-
return !mc.closed.Load()
649+
if mc.closed.Load() {
650+
return false
651+
}
652+
653+
if !mc.buf.empty() {
654+
return false
655+
}
656+
657+
return true
650658
}

packets_test.go

+33-24
Original file line numberDiff line numberDiff line change
@@ -128,30 +128,39 @@ func TestReadPacketSingleByte(t *testing.T) {
128128
}
129129

130130
func TestReadPacketWrongSequenceID(t *testing.T) {
131-
conn := new(mockConn)
132-
mc := &mysqlConn{
133-
buf: newBuffer(conn),
134-
}
135-
136-
// too low sequence id
137-
conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
138-
conn.maxReads = 1
139-
mc.sequence = 1
140-
_, err := mc.readPacket()
141-
if err != ErrPktSync {
142-
t.Errorf("expected ErrPktSync, got %v", err)
143-
}
144-
145-
// reset
146-
conn.reads = 0
147-
mc.sequence = 0
148-
mc.buf = newBuffer(conn)
149-
150-
// too high sequence id
151-
conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff}
152-
_, err = mc.readPacket()
153-
if err != ErrPktSyncMul {
154-
t.Errorf("expected ErrPktSyncMul, got %v", err)
131+
for _, testCase := range []struct {
132+
ClientSequenceID byte
133+
ServerSequenceID byte
134+
ExpectedErr error
135+
}{
136+
{
137+
ClientSequenceID: 1,
138+
ServerSequenceID: 0,
139+
ExpectedErr: ErrPktSync,
140+
},
141+
{
142+
ClientSequenceID: 0,
143+
ServerSequenceID: 0x42,
144+
ExpectedErr: ErrPktSyncMul,
145+
},
146+
} {
147+
conn := new(mockConn)
148+
mc := mysqlConn{
149+
buf: newBuffer(conn),
150+
closech: make(chan struct{}),
151+
sequence: testCase.ClientSequenceID,
152+
}
153+
154+
conn.data = []byte{0x01, 0x00, 0x00, testCase.ServerSequenceID, 0xff}
155+
_, err := mc.readPacket()
156+
if err != testCase.ExpectedErr {
157+
t.Errorf("expected %v, got %v", testCase.ExpectedErr, err)
158+
}
159+
160+
// connection should not be returned to the pool in this state
161+
if mc.IsValid() {
162+
t.Errorf("expected IsValid() to be false")
163+
}
155164
}
156165
}
157166

0 commit comments

Comments
 (0)