@@ -3,6 +3,7 @@ package packet
3
3
import (
4
4
"bufio"
5
5
"bytes"
6
+ "compress/zlib"
6
7
"crypto/rand"
7
8
"crypto/rsa"
8
9
"crypto/sha1"
@@ -12,6 +13,7 @@ import (
12
13
"net"
13
14
"sync"
14
15
16
+ "github.com/DataDog/zstd"
15
17
. "github.com/go-mysql-org/go-mysql/mysql"
16
18
"github.com/go-mysql-org/go-mysql/utils"
17
19
"github.com/pingcap/errors"
@@ -56,6 +58,16 @@ type Conn struct {
56
58
header [4 ]byte
57
59
58
60
Sequence uint8
61
+
62
+ Compression uint8
63
+
64
+ CompressedSequence uint8
65
+
66
+ compressedHeader [7 ]byte
67
+
68
+ compressedReaderActive bool
69
+
70
+ compressedReader io.Reader
59
71
}
60
72
61
73
func NewConn (conn net.Conn ) * Conn {
@@ -94,8 +106,43 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) {
94
106
utils .BytesBufferPut (buf )
95
107
}()
96
108
97
- if err := c .ReadPacketTo (buf ); err != nil {
98
- return nil , errors .Trace (err )
109
+ if c .Compression != MYSQL_COMPRESS_NONE {
110
+ if ! c .compressedReaderActive {
111
+ if _ , err := io .ReadFull (c .reader , c .compressedHeader [:7 ]); err != nil {
112
+ return nil , errors .Wrapf (ErrBadConn , "io.ReadFull(compressedHeader) failed. err %v" , err )
113
+ }
114
+
115
+ compressedSequence := c .compressedHeader [3 ]
116
+ uncompressedLength := int (uint32 (c .compressedHeader [4 ]) | uint32 (c .compressedHeader [5 ])<< 8 | uint32 (c .compressedHeader [6 ])<< 16 )
117
+ if compressedSequence != c .CompressedSequence {
118
+ return nil , errors .Errorf ("invalid compressed sequence %d != %d" ,
119
+ compressedSequence , c .CompressedSequence )
120
+ }
121
+
122
+ if uncompressedLength > 0 {
123
+ var err error
124
+ switch c .Compression {
125
+ case MYSQL_COMPRESS_ZLIB :
126
+ c .compressedReader , err = zlib .NewReader (c .reader )
127
+ case MYSQL_COMPRESS_ZSTD :
128
+ c .compressedReader = zstd .NewReader (c .reader )
129
+ }
130
+ if err != nil {
131
+ return nil , err
132
+ }
133
+ }
134
+ c .compressedReaderActive = true
135
+ }
136
+ }
137
+
138
+ if c .compressedReader != nil {
139
+ if err := c .ReadPacketTo (buf , c .compressedReader ); err != nil {
140
+ return nil , errors .Trace (err )
141
+ }
142
+ } else {
143
+ if err := c .ReadPacketTo (buf , c .reader ); err != nil {
144
+ return nil , errors .Trace (err )
145
+ }
99
146
}
100
147
101
148
readBytes := buf .Bytes ()
@@ -145,8 +192,8 @@ func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err
145
192
return written , nil
146
193
}
147
194
148
- func (c * Conn ) ReadPacketTo (w io.Writer ) error {
149
- if _ , err := io .ReadFull (c . reader , c .header [:4 ]); err != nil {
195
+ func (c * Conn ) ReadPacketTo (w io.Writer , r io. Reader ) error {
196
+ if _ , err := io .ReadFull (r , c .header [:4 ]); err != nil {
150
197
return errors .Wrapf (ErrBadConn , "io.ReadFull(header) failed. err %v" , err )
151
198
}
152
199
@@ -164,7 +211,7 @@ func (c *Conn) ReadPacketTo(w io.Writer) error {
164
211
buf .Grow (length )
165
212
}
166
213
167
- if n , err := c .copyN (w , c . reader , int64 (length )); err != nil {
214
+ if n , err := c .copyN (w , r , int64 (length )); err != nil {
168
215
return errors .Wrapf (ErrBadConn , "io.CopyN failed. err %v, copied %v, expected %v" , err , n , length )
169
216
} else if n != int64 (length ) {
170
217
return errors .Wrapf (ErrBadConn , "io.CopyN failed(n != int64(length)). %v bytes copied, while %v expected" , n , length )
@@ -173,7 +220,7 @@ func (c *Conn) ReadPacketTo(w io.Writer) error {
173
220
return nil
174
221
}
175
222
176
- if err : = c .ReadPacketTo (w ); err != nil {
223
+ if err = c .ReadPacketTo (w , r ); err != nil {
177
224
return errors .Wrap (err , "ReadPacketTo failed" )
178
225
}
179
226
}
@@ -209,14 +256,93 @@ func (c *Conn) WritePacket(data []byte) error {
209
256
data [2 ] = byte (length >> 16 )
210
257
data [3 ] = c .Sequence
211
258
212
- if n , err := c .Write (data ); err != nil {
213
- return errors .Wrapf (ErrBadConn , "Write failed. err %v" , err )
214
- } else if n != len (data ) {
215
- return errors .Wrapf (ErrBadConn , "Write failed. only %v bytes written, while %v expected" , n , len (data ))
259
+ switch c .Compression {
260
+ case MYSQL_COMPRESS_NONE :
261
+ if n , err := c .Write (data ); err != nil {
262
+ return errors .Wrapf (ErrBadConn , "Write failed. err %v" , err )
263
+ } else if n != len (data ) {
264
+ return errors .Wrapf (ErrBadConn , "Write failed. only %v bytes written, while %v expected" , n , len (data ))
265
+ }
266
+ case MYSQL_COMPRESS_ZLIB , MYSQL_COMPRESS_ZSTD :
267
+ if n , err := c .writeCompressed (data ); err != nil {
268
+ return errors .Wrapf (ErrBadConn , "Write failed. err %v" , err )
269
+ } else if n != len (data ) {
270
+ return errors .Wrapf (ErrBadConn , "Write failed. only %v bytes written, while %v expected" , n , len (data ))
271
+ }
272
+ c .compressedReader = nil
273
+ c .compressedReaderActive = false
274
+ default :
275
+ return errors .Wrapf (ErrBadConn , "Write failed. Unsuppored compression algorithm set" )
276
+ }
277
+
278
+ c .Sequence ++
279
+ return nil
280
+ }
281
+
282
+ func (c * Conn ) writeCompressed (data []byte ) (n int , err error ) {
283
+ var compressedLength , uncompressedLength int
284
+ var payload , compressedPacket bytes.Buffer
285
+ var w io.WriteCloser
286
+ minCompressLength := 50
287
+ compressedHeader := make ([]byte , 7 )
288
+
289
+ switch c .Compression {
290
+ case MYSQL_COMPRESS_ZLIB :
291
+ w , err = zlib .NewWriterLevel (& payload , zlib .HuffmanOnly )
292
+ case MYSQL_COMPRESS_ZSTD :
293
+ w = zstd .NewWriter (& payload )
294
+ }
295
+ if err != nil {
296
+ return 0 , err
297
+ }
298
+
299
+ if len (data ) > minCompressLength {
300
+ uncompressedLength = len (data )
301
+ n , err = w .Write (data )
302
+ if err != nil {
303
+ return 0 , err
304
+ }
305
+ err = w .Close ()
306
+ if err != nil {
307
+ return 0 , err
308
+ }
309
+ }
310
+
311
+ if len (data ) > minCompressLength {
312
+ compressedLength = len (payload .Bytes ())
313
+ } else {
314
+ compressedLength = len (data )
315
+ }
316
+
317
+ c .CompressedSequence = 0
318
+ compressedHeader [0 ] = byte (compressedLength )
319
+ compressedHeader [1 ] = byte (compressedLength >> 8 )
320
+ compressedHeader [2 ] = byte (compressedLength >> 16 )
321
+ compressedHeader [3 ] = c .CompressedSequence
322
+ compressedHeader [4 ] = byte (uncompressedLength )
323
+ compressedHeader [5 ] = byte (uncompressedLength >> 8 )
324
+ compressedHeader [6 ] = byte (uncompressedLength >> 16 )
325
+ _ , err = compressedPacket .Write (compressedHeader )
326
+ if err != nil {
327
+ return 0 , err
328
+ }
329
+ c .CompressedSequence ++
330
+
331
+ if len (data ) > minCompressLength {
332
+ _ , err = compressedPacket .Write (payload .Bytes ())
216
333
} else {
217
- c .Sequence ++
218
- return nil
334
+ n , err = compressedPacket .Write (data )
335
+ }
336
+ if err != nil {
337
+ return 0 , err
219
338
}
339
+
340
+ _ , err = c .Write (compressedPacket .Bytes ())
341
+ if err != nil {
342
+ return 0 , err
343
+ }
344
+
345
+ return n , nil
220
346
}
221
347
222
348
// WriteClearAuthPacket: Client clear text authentication packet
0 commit comments