Skip to content

Commit 9b0f283

Browse files
authored
applied small refactor writeResultSet to minimize code duplication (#612)
1 parent 403f48f commit 9b0f283

File tree

3 files changed

+274
-12
lines changed

3 files changed

+274
-12
lines changed

server/resp.go

+6-12
Original file line numberDiff line numberDiff line change
@@ -125,15 +125,7 @@ func (c *Conn) writeResultset(r *Resultset) error {
125125
return err
126126
}
127127

128-
for _, v := range r.Fields {
129-
data = data[0:4]
130-
data = append(data, v.Dump()...)
131-
if err := c.WritePacket(data); err != nil {
132-
return err
133-
}
134-
}
135-
136-
if err := c.writeEOF(); err != nil {
128+
if err := c.writeFieldList(r.Fields, data); err != nil {
137129
return err
138130
}
139131

@@ -152,8 +144,10 @@ func (c *Conn) writeResultset(r *Resultset) error {
152144
return nil
153145
}
154146

155-
func (c *Conn) writeFieldList(fs []*Field) error {
156-
data := make([]byte, 4, 1024)
147+
func (c *Conn) writeFieldList(fs []*Field, data []byte) error {
148+
if data == nil {
149+
data = make([]byte, 4, 1024)
150+
}
157151

158152
for _, v := range fs {
159153
data = data[0:4]
@@ -189,7 +183,7 @@ func (c *Conn) writeValue(value interface{}) error {
189183
return c.writeOK(v)
190184
}
191185
case []*Field:
192-
return c.writeFieldList(v)
186+
return c.writeFieldList(v, nil)
193187
case *Stmt:
194188
return c.writePrepare(v)
195189
default:

server/resp_test.go

+191
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
package server
2+
3+
import (
4+
"errors"
5+
6+
"github.com/go-mysql-org/go-mysql/mysql"
7+
"github.com/go-mysql-org/go-mysql/packet"
8+
mockconn "github.com/go-mysql-org/go-mysql/test_util/conn"
9+
"github.com/pingcap/check"
10+
)
11+
12+
type respConnTestSuite struct{}
13+
14+
var _ = check.Suite(&respConnTestSuite{})
15+
16+
func (t *respConnTestSuite) TestConnWriteOK(c *check.C) {
17+
clientConn := &mockconn.MockConn{}
18+
conn := &Conn{Conn: packet.NewConn(clientConn)}
19+
20+
result := &mysql.Result{
21+
AffectedRows: 1,
22+
InsertId: 2,
23+
}
24+
25+
// write ok with insertid and affectedrows set
26+
err := conn.writeOK(result)
27+
c.Assert(err, check.IsNil)
28+
c.Assert(clientConn.WriteBuffered, check.BytesEquals, []byte{3, 0, 0, 0, mysql.OK_HEADER, 1, 2})
29+
30+
// set capability for CLIENT_PROTOCOL_41
31+
conn.SetCapability(mysql.CLIENT_PROTOCOL_41)
32+
conn.SetStatus(mysql.SERVER_QUERY_WAS_SLOW)
33+
err = conn.writeOK(result)
34+
c.Assert(err, check.IsNil)
35+
c.Assert(clientConn.WriteBuffered, check.BytesEquals, []byte{7, 0, 0, 1, mysql.OK_HEADER, 1, 2, 0, 8, 0, 0})
36+
}
37+
38+
func (t *respConnTestSuite) TestConnWriteEOF(c *check.C) {
39+
clientConn := &mockconn.MockConn{}
40+
conn := &Conn{Conn: packet.NewConn(clientConn)}
41+
42+
// write regular EOF
43+
err := conn.writeEOF()
44+
c.Assert(err, check.IsNil)
45+
c.Assert(clientConn.WriteBuffered, check.BytesEquals, []byte{1, 0, 0, 0, mysql.EOF_HEADER})
46+
47+
// set capability for CLIENT_PROTOCOL_41
48+
conn.SetCapability(mysql.CLIENT_PROTOCOL_41)
49+
conn.SetStatus(mysql.SERVER_MORE_RESULTS_EXISTS)
50+
err = conn.writeEOF()
51+
c.Assert(err, check.IsNil)
52+
c.Assert(clientConn.WriteBuffered, check.BytesEquals, []byte{5, 0, 0, 1, mysql.EOF_HEADER,
53+
0, 0, 8, 0})
54+
}
55+
56+
func (t *respConnTestSuite) TestConnWriteError(c *check.C) {
57+
clientConn := &mockconn.MockConn{}
58+
conn := &Conn{Conn: packet.NewConn(clientConn)}
59+
merr := mysql.NewDefaultError(mysql.ER_YES) // nice and short error message
60+
61+
// write regular Error
62+
err := conn.writeError(merr)
63+
c.Assert(err, check.IsNil)
64+
c.Assert(clientConn.WriteBuffered, check.BytesEquals, []byte{6, 0, 0, 0, mysql.ERR_HEADER,
65+
235, 3, 89, 69, 83})
66+
67+
// set capability for CLIENT_PROTOCOL_41
68+
conn.SetCapability(mysql.CLIENT_PROTOCOL_41)
69+
err = conn.writeError(merr)
70+
c.Assert(err, check.IsNil)
71+
c.Assert(clientConn.WriteBuffered, check.BytesEquals, []byte{12, 0, 0, 1, mysql.ERR_HEADER,
72+
235, 3, 35, 72, 89, 48, 48, 48, 89, 69, 83})
73+
74+
// unknown error
75+
err = conn.writeError(errors.New("test"))
76+
c.Assert(err, check.IsNil)
77+
c.Assert(clientConn.WriteBuffered, check.BytesEquals, []byte{13, 0, 0, 2, mysql.ERR_HEADER,
78+
81, 4, 35, 72, 89, 48, 48, 48, 116, 101, 115, 116})
79+
}
80+
81+
func (t *respConnTestSuite) TestConnWriteAuthSwitchRequest(c *check.C) {
82+
clientConn := &mockconn.MockConn{}
83+
conn := &Conn{Conn: packet.NewConn(clientConn)}
84+
85+
err := conn.writeAuthSwitchRequest("test")
86+
c.Assert(err, check.IsNil)
87+
// first 10 bytes are static, then there is a part random, ending with a \0
88+
c.Assert(clientConn.WriteBuffered[:10], check.BytesEquals, []byte{27, 0, 0, 0, mysql.EOF_HEADER,
89+
116, 101, 115, 116, 0})
90+
c.Assert(clientConn.WriteBuffered[len(clientConn.WriteBuffered)-1:], check.BytesEquals, []byte{0})
91+
}
92+
93+
func (t *respConnTestSuite) TestConnReadAuthSwitchRequestResponse(c *check.C) {
94+
clientConn := &mockconn.MockConn{}
95+
conn := &Conn{Conn: packet.NewConn(clientConn)}
96+
97+
// prepare response for \NUL
98+
clientConn.SetResponse([][]byte{{1, 0, 0, 0, 0}})
99+
data, err := conn.readAuthSwitchRequestResponse()
100+
c.Assert(err, check.IsNil)
101+
c.Assert(data, check.BytesEquals, []byte{})
102+
103+
// prepare response for some auth switch data
104+
clientConn.SetResponse([][]byte{{4, 0, 0, 0, 1, 2, 3, 4}})
105+
conn = &Conn{Conn: packet.NewConn(clientConn)}
106+
107+
data, err = conn.readAuthSwitchRequestResponse()
108+
c.Assert(err, check.IsNil)
109+
c.Assert(data, check.BytesEquals, []byte{1, 2, 3, 4})
110+
}
111+
112+
func (t *respConnTestSuite) TestConnWriteAuthMoreDataPubkey(c *check.C) {
113+
clientConn := &mockconn.MockConn{}
114+
conn := &Conn{
115+
Conn: packet.NewConn(clientConn),
116+
serverConf: &Server{
117+
pubKey: []byte{1, 2, 3, 4},
118+
},
119+
}
120+
121+
err := conn.writeAuthMoreDataPubkey()
122+
c.Assert(err, check.IsNil)
123+
c.Assert(clientConn.WriteBuffered, check.BytesEquals, []byte{5, 0, 0, 0, mysql.MORE_DATE_HEADER,
124+
1, 2, 3, 4})
125+
}
126+
127+
func (t *respConnTestSuite) TestConnWriteAuthMoreDataFullAuth(c *check.C) {
128+
clientConn := &mockconn.MockConn{}
129+
conn := &Conn{Conn: packet.NewConn(clientConn)}
130+
131+
err := conn.writeAuthMoreDataFullAuth()
132+
c.Assert(err, check.IsNil)
133+
c.Assert(clientConn.WriteBuffered, check.BytesEquals, []byte{2, 0, 0, 0, mysql.MORE_DATE_HEADER,
134+
mysql.CACHE_SHA2_FULL_AUTH})
135+
}
136+
137+
func (t *respConnTestSuite) TestConnWriteAuthMoreDataFastAuth(c *check.C) {
138+
clientConn := &mockconn.MockConn{}
139+
conn := &Conn{Conn: packet.NewConn(clientConn)}
140+
141+
err := conn.writeAuthMoreDataFastAuth()
142+
c.Assert(err, check.IsNil)
143+
c.Assert(clientConn.WriteBuffered, check.BytesEquals, []byte{2, 0, 0, 0, mysql.MORE_DATE_HEADER,
144+
mysql.CACHE_SHA2_FAST_AUTH})
145+
}
146+
147+
func (t *respConnTestSuite) TestConnWriteResultset(c *check.C) {
148+
clientConn := &mockconn.MockConn{MultiWrite: true}
149+
conn := &Conn{Conn: packet.NewConn(clientConn)}
150+
151+
r := mysql.NewResultset(0)
152+
153+
// write minimalistic resultset
154+
err := conn.writeResultset(r)
155+
c.Assert(err, check.IsNil)
156+
// column length 0
157+
c.Assert(clientConn.WriteBuffered[:5], check.BytesEquals, []byte{1, 0, 0, 0, 0})
158+
// no fields and an EOF
159+
c.Assert(clientConn.WriteBuffered[5:10], check.BytesEquals, []byte{1, 0, 0, 1, mysql.EOF_HEADER})
160+
// no rows and another EOF
161+
c.Assert(clientConn.WriteBuffered[10:], check.BytesEquals, []byte{1, 0, 0, 2, mysql.EOF_HEADER})
162+
163+
// reset write buffer and fill up the resultset with (little) data
164+
clientConn.WriteBuffered = []byte{}
165+
r, err = mysql.BuildSimpleTextResultset([]string{"a"}, [][]interface{}{{"b"}})
166+
c.Assert(err, check.IsNil)
167+
err = conn.writeResultset(r)
168+
c.Assert(err, check.IsNil)
169+
// column length 1
170+
c.Assert(clientConn.WriteBuffered[:5], check.BytesEquals, []byte{1, 0, 0, 3, 1})
171+
// fields and EOF
172+
c.Assert(clientConn.WriteBuffered[5:32], check.BytesEquals, []byte{23, 0, 0, 4, 3, 100, 101, 102, 0, 0, 0, 1, 'a', 0, 12, 33, 0, 0, 0, 0, 0, 253, 0, 0, 0, 0, 0})
173+
c.Assert(clientConn.WriteBuffered[32:37], check.BytesEquals, []byte{1, 0, 0, 5, mysql.EOF_HEADER})
174+
// rowdata and EOF
175+
c.Assert(clientConn.WriteBuffered[37:43], check.BytesEquals, []byte{2, 0, 0, 6, 1, 'b'})
176+
c.Assert(clientConn.WriteBuffered[43:], check.BytesEquals, []byte{1, 0, 0, 7, mysql.EOF_HEADER})
177+
}
178+
179+
func (t *respConnTestSuite) TestConnWriteFieldList(c *check.C) {
180+
clientConn := &mockconn.MockConn{MultiWrite: true}
181+
conn := &Conn{Conn: packet.NewConn(clientConn)}
182+
183+
r, err := mysql.BuildSimpleTextResultset([]string{"c"}, [][]interface{}{{"d"}})
184+
c.Assert(err, check.IsNil)
185+
err = conn.writeFieldList(r.Fields, nil)
186+
c.Assert(err, check.IsNil)
187+
188+
// column length 1
189+
c.Assert(clientConn.WriteBuffered[:27], check.BytesEquals, []byte{23, 0, 0, 0, 3, 100, 101, 102, 0, 0, 0, 1, 'c', 0, 12, 33, 0, 0, 0, 0, 0, 253, 0, 0, 0, 0, 0})
190+
c.Assert(clientConn.WriteBuffered[27:], check.BytesEquals, []byte{1, 0, 0, 1, mysql.EOF_HEADER})
191+
}

test_util/conn/mockconn.go

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package conn
2+
3+
import (
4+
"errors"
5+
"net"
6+
"time"
7+
)
8+
9+
type mockAddr struct{}
10+
11+
func (m mockAddr) String() string { return "mocking" }
12+
func (m mockAddr) Network() string { return "mocking" }
13+
14+
// MockConn is a simple struct implementing net.Conn that allows us to read what
15+
// was written to it and feed data it will read from
16+
type MockConn struct {
17+
readResponses [][]byte
18+
WriteBuffered []byte
19+
Closed bool
20+
21+
MultiWrite bool
22+
}
23+
24+
func (m *MockConn) SetResponse(r [][]byte) {
25+
m.readResponses = r
26+
}
27+
28+
func (m *MockConn) Read(p []byte) (n int, err error) {
29+
if m.Closed {
30+
return -1, errors.New("connection closed")
31+
}
32+
33+
if len(m.readResponses) == 0 {
34+
return -1, errors.New("no response left")
35+
}
36+
37+
copy(p, m.readResponses[0])
38+
m.readResponses = m.readResponses[1:]
39+
40+
return len(p), nil
41+
}
42+
43+
func (m *MockConn) Write(p []byte) (n int, err error) {
44+
if m.Closed {
45+
return -1, errors.New("connection closed")
46+
}
47+
48+
if m.MultiWrite {
49+
m.WriteBuffered = append(m.WriteBuffered, p...)
50+
} else {
51+
m.WriteBuffered = make([]byte, len(p))
52+
copy(m.WriteBuffered, p)
53+
}
54+
55+
return len(p), nil
56+
}
57+
58+
func (m MockConn) LocalAddr() net.Addr { return mockAddr{} }
59+
func (m MockConn) RemoteAddr() net.Addr { return mockAddr{} }
60+
61+
func (m *MockConn) Close() error {
62+
m.Closed = true
63+
64+
return nil
65+
}
66+
67+
func (m MockConn) SetDeadline(t time.Time) error {
68+
return errors.New("not implemented")
69+
}
70+
71+
func (m MockConn) SetReadDeadline(t time.Time) error {
72+
return errors.New("not implemented")
73+
}
74+
75+
func (m MockConn) SetWriteDeadline(t time.Time) error {
76+
return errors.New("not implemented")
77+
}

0 commit comments

Comments
 (0)