Skip to content

Commit d2ecd57

Browse files
committed
simplify
1 parent 0bc8145 commit d2ecd57

File tree

4 files changed

+44
-66
lines changed

4 files changed

+44
-66
lines changed

buffer.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ import (
1515
const defaultBufSize = 4096
1616
const maxCachedBufSize = 256 * 1024
1717

18-
// readwriteFunc is a function that compatible with io.Reader and io.Writer.
19-
// We use this function type instead of io.ReadWriter because we want to
20-
// just pass mc.readWithTimeout or mc.writeWithTimeout functions.
21-
type readwriteFunc func([]byte) (int, error)
18+
// readerFunc is a function that compatible with io.Reader.
19+
// We use this function type instead of io.Reader because we want to
20+
// just pass mc.readWithTimeout.
21+
type readerFunc func([]byte) (int, error)
2222

2323
// A buffer which is used for both reading and writing.
2424
// This is possible since communication on each connection is synchronous.
@@ -43,7 +43,7 @@ func (b *buffer) busy() bool {
4343
}
4444

4545
// fill reads into the read buffer until at least _need_ bytes are in it.
46-
func (b *buffer) fill(need int, r readwriteFunc) error {
46+
func (b *buffer) fill(need int, r readerFunc) error {
4747
// we'll move the contents of the current buffer to dest before filling it.
4848
dest := b.cachedBuf
4949

@@ -86,7 +86,7 @@ func (b *buffer) fill(need int, r readwriteFunc) error {
8686

8787
// returns next N bytes from buffer.
8888
// The returned slice is only guaranteed to be valid until the next read
89-
func (b *buffer) readNext(need int, r readwriteFunc) ([]byte, error) {
89+
func (b *buffer) readNext(need int, r readerFunc) ([]byte, error) {
9090
if len(b.buf) < need {
9191
// refill
9292
if err := b.fill(need, r); err != nil {

compress.go

+26-47
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ func init() {
3636
}
3737
}
3838

39-
func zDecompress(src, dst []byte) (int, error) {
39+
func zDecompress(src []byte, dst *bytes.Buffer) (int, error) {
4040
br := bytes.NewReader(src)
4141
var zr io.ReadCloser
4242
var err error
@@ -51,27 +51,11 @@ func zDecompress(src, dst []byte) (int, error) {
5151
return 0, err
5252
}
5353
}
54-
defer func() {
55-
zr.Close()
56-
zrPool.Put(zr)
57-
}()
5854

59-
lenRead := 0
60-
size := len(dst)
61-
62-
for lenRead < size {
63-
n, err := zr.Read(dst[lenRead:])
64-
lenRead += n
65-
66-
if err == io.EOF {
67-
if lenRead < size {
68-
return lenRead, io.ErrUnexpectedEOF
69-
}
70-
} else if err != nil {
71-
return lenRead, err
72-
}
73-
}
74-
return lenRead, nil
55+
n, _ := dst.ReadFrom(zr) // ignore err because zr.Close() will return it again.
56+
err = zr.Close() // zr.Close() may return chuecksum error.
57+
zrPool.Put(zr)
58+
return int(n), err
7559
}
7660

7761
func zCompress(src []byte, dst io.Writer) error {
@@ -100,7 +84,7 @@ func (c *compIO) reset() {
10084
c.buff.Reset()
10185
}
10286

103-
func (c *compIO) readNext(need int, r readwriteFunc) ([]byte, error) {
87+
func (c *compIO) readNext(need int, r readerFunc) ([]byte, error) {
10488
for c.buff.Len() < need {
10589
if err := c.readCompressedPacket(r); err != nil {
10690
return nil, err
@@ -110,7 +94,7 @@ func (c *compIO) readNext(need int, r readwriteFunc) ([]byte, error) {
11094
return data[:need:need], nil // prevent caller writes into c.buff
11195
}
11296

113-
func (c *compIO) readCompressedPacket(r readwriteFunc) error {
97+
func (c *compIO) readCompressedPacket(r readerFunc) error {
11498
header, err := c.mc.buf.readNext(7, r) // size of compressed header
11599
if err != nil {
116100
return err
@@ -121,19 +105,17 @@ func (c *compIO) readCompressedPacket(r readwriteFunc) error {
121105
comprLength := getUint24(header[0:3])
122106
compressionSequence := uint8(header[3])
123107
uncompressedLength := getUint24(header[4:7])
124-
if debugTrace {
108+
if debug {
125109
fmt.Printf("uncompress cmplen=%v uncomplen=%v pkt_cmp_seq=%v expected_cmp_seq=%v\n",
126110
comprLength, uncompressedLength, compressionSequence, c.mc.sequence)
127111
}
128-
if compressionSequence != c.mc.sequence {
129-
// return ErrPktSync
130-
// server may return error packet (e.g. 1153 Got a packet bigger than 'max_allowed_packet' bytes)
131-
// before receiving all packets from client. In this case, seqnr is younger than expected.
132-
if debugTrace {
133-
fmt.Printf("WARN: unexpected cmpress seq nr: expected %v, got %v",
134-
c.mc.sequence, compressionSequence)
135-
}
136-
// TODO(methane): report error when the packet is not an error packet.
112+
// Do not return ErrPktSync here.
113+
// Server may return error packet (e.g. 1153 Got a packet bigger than 'max_allowed_packet' bytes)
114+
// before receiving all packets from client. In this case, seqnr is younger than expected.
115+
// NOTE: Both of mariadbclient and mysqlclient do not check seqnr. Only server checks it.
116+
if debug && compressionSequence != c.mc.sequence {
117+
fmt.Printf("WARN: unexpected cmpress seq nr: expected %v, got %v",
118+
c.mc.sequence, compressionSequence)
137119
}
138120
c.mc.sequence = compressionSequence + 1
139121
c.mc.compressSequence = c.mc.sequence
@@ -152,31 +134,29 @@ func (c *compIO) readCompressedPacket(r readwriteFunc) error {
152134

153135
// use existing capacity in bytesBuf if possible
154136
c.buff.Grow(uncompressedLength)
155-
dec := c.buff.AvailableBuffer()[:uncompressedLength]
156-
lenRead, err := zDecompress(comprData, dec)
137+
nread, err := zDecompress(comprData, &c.buff)
157138
if err != nil {
158139
return err
159140
}
160-
if lenRead != uncompressedLength {
141+
if nread != uncompressedLength {
161142
return fmt.Errorf("invalid compressed packet: uncompressed length in header is %d, actual %d",
162-
uncompressedLength, lenRead)
143+
uncompressedLength, nread)
163144
}
164-
c.buff.Write(dec) // fast copy. See bytes.Buffer.AvailableBuffer() doc.
165145
return nil
166146
}
167147

148+
const minCompressLength = 150
168149
const maxPayloadLen = maxPacketSize - 4
169150

170151
// writePackets sends one or some packets with compression.
171152
// Use this instead of mc.netConn.Write() when mc.compress is true.
172153
func (c *compIO) writePackets(packets []byte) (int, error) {
173154
totalBytes := len(packets)
174-
dataLen := len(packets)
175155
blankHeader := make([]byte, 7)
176156
buf := &c.buff
177157

178-
for dataLen > 0 {
179-
payloadLen := min(maxPayloadLen, dataLen)
158+
for len(packets) > 0 {
159+
payloadLen := min(maxPayloadLen, len(packets))
180160
payload := packets[:payloadLen]
181161
uncompressedLen := payloadLen
182162

@@ -190,8 +170,8 @@ func (c *compIO) writePackets(packets []byte) (int, error) {
190170
} else {
191171
zCompress(payload, buf)
192172
// do not compress if compressed data is larger than uncompressed data
193-
// I intentionally miss 7 byte header in the buf; compress should compress more than 7 bytes.
194-
if buf.Len() > uncompressedLen {
173+
// I intentionally miss 7 byte header in the buf; compress more than 7 bytes.
174+
if buf.Len() >= uncompressedLen {
195175
buf.Reset()
196176
buf.Write(blankHeader)
197177
buf.Write(payload)
@@ -204,7 +184,6 @@ func (c *compIO) writePackets(packets []byte) (int, error) {
204184
// up compressed bytes that is returned by underlying Write().
205185
return totalBytes - len(packets) + n, err
206186
}
207-
dataLen -= payloadLen
208187
packets = packets[payloadLen:]
209188
}
210189

@@ -216,7 +195,7 @@ func (c *compIO) writePackets(packets []byte) (int, error) {
216195
func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) (int, error) {
217196
mc := c.mc
218197
comprLength := len(data) - 7
219-
if debugTrace {
198+
if debug {
220199
fmt.Printf(
221200
"writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v",
222201
comprLength, uncompressedLen, mc.compressSequence)
@@ -227,8 +206,8 @@ func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) (int, e
227206
data[3] = mc.compressSequence
228207
putUint24(data[4:7], uncompressedLen)
229208

230-
if n, err := mc.writeWithTimeout(data); err != nil {
231-
// mc.log("writing compressed packet:", err)
209+
n, err := mc.writeWithTimeout(data)
210+
if err != nil {
232211
return n, err
233212
}
234213

const.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,13 @@ package mysql
1111
import "runtime"
1212

1313
const (
14-
debugTrace = false // for debugging wire protocol.
14+
debug = false // for debugging. Set true only in development.
1515

1616
defaultAuthPlugin = "mysql_native_password"
1717
defaultMaxAllowedPacket = 64 << 20 // 64 MiB. See https://github.com/go-sql-driver/mysql/issues/1355
1818
minProtocolVersion = 10
1919
maxPacketSize = 1<<24 - 1
2020
timeFormat = "2006-01-02 15:04:05.999999"
21-
minCompressLength = 150
2221

2322
// Connection attributes
2423
// See https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html#performance-schema-connection-attributes-available

packets.go

+11-11
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,14 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
3030
var prevData []byte
3131
invalid := false
3232

33-
readFunc := mc.buf.readNext
33+
readNext := mc.buf.readNext
3434
if mc.compress {
35-
readFunc = mc.compIO.readNext
35+
readNext = mc.compIO.readNext
3636
}
3737

3838
for {
3939
// read packet header
40-
data, err := readFunc(4, mc.readWithTimeout)
40+
data, err := readNext(4, mc.readWithTimeout)
4141
if err != nil {
4242
mc.close()
4343
if cerr := mc.canceled.Value(); cerr != nil {
@@ -49,19 +49,19 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
4949

5050
// packet length [24 bit]
5151
pktLen := getUint24(data[:3])
52-
seqNr := data[3]
52+
seq := data[3]
5353

5454
if mc.compress {
5555
// MySQL and MariaDB doesn't check packet nr in compressed packet.
56-
if debugTrace && seqNr != mc.compressSequence {
56+
if debug && seq != mc.compressSequence {
5757
fmt.Printf("[debug] mismatched compression sequence nr: expected: %v, got %v",
58-
mc.compressSequence, seqNr)
58+
mc.compressSequence, seq)
5959
}
60-
mc.compressSequence = seqNr + 1
60+
mc.compressSequence = seq + 1
6161
} else {
6262
// check packet sync [8 bit]
63-
if seqNr != mc.sequence {
64-
mc.log(fmt.Sprintf("[warn] unexpected seq nr: expected %v, got %v", mc.sequence, seqNr))
63+
if seq != mc.sequence {
64+
mc.log(fmt.Sprintf("[warn] unexpected seq nr: expected %v, got %v", mc.sequence, seq))
6565
// For large packets, we stop reading as soon as sync error.
6666
if len(prevData) > 0 {
6767
mc.close()
@@ -86,7 +86,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
8686
}
8787

8888
// read packet body [pktLen bytes]
89-
data, err = readFunc(pktLen, mc.readWithTimeout)
89+
data, err = readNext(pktLen, mc.readWithTimeout)
9090
if err != nil {
9191
mc.close()
9292
if cerr := mc.canceled.Value(); cerr != nil {
@@ -136,7 +136,7 @@ func (mc *mysqlConn) writePacket(data []byte) error {
136136
data[3] = mc.sequence
137137

138138
// Write packet
139-
if debugTrace {
139+
if debug {
140140
fmt.Printf("writePacket: size=%v seq=%v", size, mc.sequence)
141141
}
142142

0 commit comments

Comments
 (0)