@@ -36,7 +36,7 @@ func init() {
36
36
}
37
37
}
38
38
39
- func zDecompress (src , dst []byte ) (int , error ) {
39
+ func zDecompress (src []byte , dst * bytes. Buffer ) (int , error ) {
40
40
br := bytes .NewReader (src )
41
41
var zr io.ReadCloser
42
42
var err error
@@ -51,27 +51,11 @@ func zDecompress(src, dst []byte) (int, error) {
51
51
return 0 , err
52
52
}
53
53
}
54
- defer func () {
55
- zr .Close ()
56
- zrPool .Put (zr )
57
- }()
58
54
59
- lenRead := 0
60
- size := len (dst )
61
-
62
- for lenRead < size {
63
- n , err := zr .Read (dst [lenRead :])
64
- lenRead += n
65
-
66
- if err == io .EOF {
67
- if lenRead < size {
68
- return lenRead , io .ErrUnexpectedEOF
69
- }
70
- } else if err != nil {
71
- return lenRead , err
72
- }
73
- }
74
- return lenRead , nil
55
+ n , _ := dst .ReadFrom (zr ) // ignore err because zr.Close() will return it again.
56
+ err = zr .Close () // zr.Close() may return chuecksum error.
57
+ zrPool .Put (zr )
58
+ return int (n ), err
75
59
}
76
60
77
61
func zCompress (src []byte , dst io.Writer ) error {
@@ -100,7 +84,7 @@ func (c *compIO) reset() {
100
84
c .buff .Reset ()
101
85
}
102
86
103
- func (c * compIO ) readNext (need int , r readwriteFunc ) ([]byte , error ) {
87
+ func (c * compIO ) readNext (need int , r readerFunc ) ([]byte , error ) {
104
88
for c .buff .Len () < need {
105
89
if err := c .readCompressedPacket (r ); err != nil {
106
90
return nil , err
@@ -110,7 +94,7 @@ func (c *compIO) readNext(need int, r readwriteFunc) ([]byte, error) {
110
94
return data [:need :need ], nil // prevent caller writes into c.buff
111
95
}
112
96
113
- func (c * compIO ) readCompressedPacket (r readwriteFunc ) error {
97
+ func (c * compIO ) readCompressedPacket (r readerFunc ) error {
114
98
header , err := c .mc .buf .readNext (7 , r ) // size of compressed header
115
99
if err != nil {
116
100
return err
@@ -121,19 +105,17 @@ func (c *compIO) readCompressedPacket(r readwriteFunc) error {
121
105
comprLength := getUint24 (header [0 :3 ])
122
106
compressionSequence := uint8 (header [3 ])
123
107
uncompressedLength := getUint24 (header [4 :7 ])
124
- if debugTrace {
108
+ if debug {
125
109
fmt .Printf ("uncompress cmplen=%v uncomplen=%v pkt_cmp_seq=%v expected_cmp_seq=%v\n " ,
126
110
comprLength , uncompressedLength , compressionSequence , c .mc .sequence )
127
111
}
128
- if compressionSequence != c .mc .sequence {
129
- // return ErrPktSync
130
- // server may return error packet (e.g. 1153 Got a packet bigger than 'max_allowed_packet' bytes)
131
- // before receiving all packets from client. In this case, seqnr is younger than expected.
132
- if debugTrace {
133
- fmt .Printf ("WARN: unexpected cmpress seq nr: expected %v, got %v" ,
134
- c .mc .sequence , compressionSequence )
135
- }
136
- // TODO(methane): report error when the packet is not an error packet.
112
+ // Do not return ErrPktSync here.
113
+ // Server may return error packet (e.g. 1153 Got a packet bigger than 'max_allowed_packet' bytes)
114
+ // before receiving all packets from client. In this case, seqnr is younger than expected.
115
+ // NOTE: Both of mariadbclient and mysqlclient do not check seqnr. Only server checks it.
116
+ if debug && compressionSequence != c .mc .sequence {
117
+ fmt .Printf ("WARN: unexpected cmpress seq nr: expected %v, got %v" ,
118
+ c .mc .sequence , compressionSequence )
137
119
}
138
120
c .mc .sequence = compressionSequence + 1
139
121
c .mc .compressSequence = c .mc .sequence
@@ -152,31 +134,29 @@ func (c *compIO) readCompressedPacket(r readwriteFunc) error {
152
134
153
135
// use existing capacity in bytesBuf if possible
154
136
c .buff .Grow (uncompressedLength )
155
- dec := c .buff .AvailableBuffer ()[:uncompressedLength ]
156
- lenRead , err := zDecompress (comprData , dec )
137
+ nread , err := zDecompress (comprData , & c .buff )
157
138
if err != nil {
158
139
return err
159
140
}
160
- if lenRead != uncompressedLength {
141
+ if nread != uncompressedLength {
161
142
return fmt .Errorf ("invalid compressed packet: uncompressed length in header is %d, actual %d" ,
162
- uncompressedLength , lenRead )
143
+ uncompressedLength , nread )
163
144
}
164
- c .buff .Write (dec ) // fast copy. See bytes.Buffer.AvailableBuffer() doc.
165
145
return nil
166
146
}
167
147
148
+ const minCompressLength = 150
168
149
const maxPayloadLen = maxPacketSize - 4
169
150
170
151
// writePackets sends one or some packets with compression.
171
152
// Use this instead of mc.netConn.Write() when mc.compress is true.
172
153
func (c * compIO ) writePackets (packets []byte ) (int , error ) {
173
154
totalBytes := len (packets )
174
- dataLen := len (packets )
175
155
blankHeader := make ([]byte , 7 )
176
156
buf := & c .buff
177
157
178
- for dataLen > 0 {
179
- payloadLen := min (maxPayloadLen , dataLen )
158
+ for len ( packets ) > 0 {
159
+ payloadLen := min (maxPayloadLen , len ( packets ) )
180
160
payload := packets [:payloadLen ]
181
161
uncompressedLen := payloadLen
182
162
@@ -190,8 +170,8 @@ func (c *compIO) writePackets(packets []byte) (int, error) {
190
170
} else {
191
171
zCompress (payload , buf )
192
172
// do not compress if compressed data is larger than uncompressed data
193
- // I intentionally miss 7 byte header in the buf; compress should compress more than 7 bytes.
194
- if buf .Len () > uncompressedLen {
173
+ // I intentionally miss 7 byte header in the buf; compress more than 7 bytes.
174
+ if buf .Len () >= uncompressedLen {
195
175
buf .Reset ()
196
176
buf .Write (blankHeader )
197
177
buf .Write (payload )
@@ -204,7 +184,6 @@ func (c *compIO) writePackets(packets []byte) (int, error) {
204
184
// up compressed bytes that is returned by underlying Write().
205
185
return totalBytes - len (packets ) + n , err
206
186
}
207
- dataLen -= payloadLen
208
187
packets = packets [payloadLen :]
209
188
}
210
189
@@ -216,7 +195,7 @@ func (c *compIO) writePackets(packets []byte) (int, error) {
216
195
func (c * compIO ) writeCompressedPacket (data []byte , uncompressedLen int ) (int , error ) {
217
196
mc := c .mc
218
197
comprLength := len (data ) - 7
219
- if debugTrace {
198
+ if debug {
220
199
fmt .Printf (
221
200
"writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v" ,
222
201
comprLength , uncompressedLen , mc .compressSequence )
@@ -227,8 +206,8 @@ func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) (int, e
227
206
data [3 ] = mc .compressSequence
228
207
putUint24 (data [4 :7 ], uncompressedLen )
229
208
230
- if n , err := mc .writeWithTimeout (data ); err != nil {
231
- // mc.log("writing compressed packet:", err)
209
+ n , err := mc .writeWithTimeout (data )
210
+ if err != nil {
232
211
return n , err
233
212
}
234
213
0 commit comments