Skip to content

Commit 1eaf0c0

Browse files
committed
re-implement TestReadPacketWrongSequenceID
1 parent 0226235 commit 1eaf0c0

File tree

3 files changed

+41
-34
lines changed

3 files changed

+41
-34
lines changed

connection.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ func (mc *mysqlConn) Close() error {
154154
func (mc *mysqlConn) closeContext(ctx context.Context) (err error) {
155155
// Makes Close idempotent
156156
if !mc.closed.Load() {
157-
err = mc.writeCommandPacket(context.Background(), comQuit)
157+
err = mc.writeCommandPacket(ctx, comQuit)
158158
}
159159

160160
mc.cleanup()

packets.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ func (mc *mysqlConn) readPacket(ctx context.Context) ([]byte, error) {
4343
pktLen := int(uint32(mc.data[0]) | uint32(mc.data[1])<<8 | uint32(mc.data[2])<<16)
4444

4545
// check packet sync [8 bit]
46-
if mc.data[3] != mc.sequence {
46+
if seq := mc.data[3]; seq != mc.sequence {
4747
mc.closeContext(ctx)
48-
if mc.data[3] > mc.sequence {
48+
if seq > mc.sequence {
4949
return nil, ErrPktSyncMul
5050
}
5151
return nil, ErrPktSync

packets_test.go

+38-31
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ package mysql
1010

1111
import (
1212
"context"
13+
"io"
1314
"net"
1415
"testing"
1516
)
@@ -52,37 +53,43 @@ func TestReadPacketSingleByte(t *testing.T) {
5253
}
5354
}
5455

55-
// func TestReadPacketWrongSequenceID(t *testing.T) {
56-
// for _, testCase := range []struct {
57-
// ClientSequenceID byte
58-
// ServerSequenceID byte
59-
// ExpectedErr error
60-
// }{
61-
// {
62-
// ClientSequenceID: 1,
63-
// ServerSequenceID: 0,
64-
// ExpectedErr: ErrPktSync,
65-
// },
66-
// {
67-
// ClientSequenceID: 0,
68-
// ServerSequenceID: 0x42,
69-
// ExpectedErr: ErrPktSyncMul,
70-
// },
71-
// } {
72-
// conn, mc := newRWMockConn(testCase.ClientSequenceID)
73-
74-
// conn.data = []byte{0x01, 0x00, 0x00, testCase.ServerSequenceID, 0xff}
75-
// _, err := mc.readPacket()
76-
// if err != testCase.ExpectedErr {
77-
// t.Errorf("expected %v, got %v", testCase.ExpectedErr, err)
78-
// }
79-
80-
// // connection should not be returned to the pool in this state
81-
// if mc.IsValid() {
82-
// t.Errorf("expected IsValid() to be false")
83-
// }
84-
// }
85-
// }
56+
func TestReadPacketWrongSequenceID(t *testing.T) {
57+
for _, testCase := range []struct {
58+
ClientSequenceID byte
59+
ServerSequenceID byte
60+
ExpectedErr error
61+
}{
62+
{
63+
ClientSequenceID: 1,
64+
ServerSequenceID: 0,
65+
ExpectedErr: ErrPktSync,
66+
},
67+
{
68+
ClientSequenceID: 0,
69+
ServerSequenceID: 0x42,
70+
ExpectedErr: ErrPktSyncMul,
71+
},
72+
} {
73+
testCase := testCase
74+
75+
conn, mc := newRWMockConn(t, testCase.ClientSequenceID)
76+
go func() {
77+
io.Copy(io.Discard, conn)
78+
}()
79+
go func() {
80+
conn.Write([]byte{0x01, 0x00, 0x00, testCase.ServerSequenceID, 0xff})
81+
}()
82+
_, err := mc.readPacket(context.Background())
83+
if err != testCase.ExpectedErr {
84+
t.Errorf(`expected "%v", got "%v"`, testCase.ExpectedErr, err)
85+
}
86+
87+
// connection should not be returned to the pool in this state
88+
if mc.IsValid() {
89+
t.Errorf("expected IsValid() to be false")
90+
}
91+
}
92+
}
8693

8794
// func TestReadPacketSplit(t *testing.T) {
8895
// conn := new(mockConn)

0 commit comments

Comments
 (0)