Skip to content

Commit 7b42091

Browse files
author
Brigitte Lamarche
committed
packets: implemented compression protocol
1 parent 3955978 commit 7b42091

10 files changed

+493
-17
lines changed

AUTHORS

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Aaron Hopkins <go-sql-driver at die.net>
1515
Achille Roussel <achille.roussel at gmail.com>
1616
Arne Hormann <arnehormann at gmail.com>
1717
Asta Xie <xiemengjun at gmail.com>
18+
B Lamarche <blam413 at gmail.com>
1819
Bulat Gaifullin <gaifullinbf at gmail.com>
1920
Carlos Nieto <jose.carlos at menteslibres.net>
2021
Chris Moos <chris at tech9computers.com>

benchmark_test.go

+1
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ func BenchmarkInterpolation(b *testing.B) {
224224
maxWriteSize: maxPacketSize - 1,
225225
buf: newBuffer(nil),
226226
}
227+
mc.reader = &mc.buf
227228

228229
args := []driver.Value{
229230
int64(42424242),

compress.go

+217
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
package mysql
2+
3+
import (
4+
"bytes"
5+
"compress/zlib"
6+
"io"
7+
)
8+
9+
const (
10+
minCompressLength = 50
11+
)
12+
13+
type packetReader interface {
14+
readNext(need int) ([]byte, error)
15+
}
16+
17+
type compressedReader struct {
18+
buf packetReader
19+
bytesBuf []byte
20+
mc *mysqlConn
21+
}
22+
23+
type compressedWriter struct {
24+
connWriter io.Writer
25+
mc *mysqlConn
26+
}
27+
28+
func NewCompressedReader(buf packetReader, mc *mysqlConn) *compressedReader {
29+
return &compressedReader{
30+
buf: buf,
31+
bytesBuf: make([]byte, 0),
32+
mc: mc,
33+
}
34+
}
35+
36+
func NewCompressedWriter(connWriter io.Writer, mc *mysqlConn) *compressedWriter {
37+
return &compressedWriter{
38+
connWriter: connWriter,
39+
mc: mc,
40+
}
41+
}
42+
43+
func (cr *compressedReader) readNext(need int) ([]byte, error) {
44+
for len(cr.bytesBuf) < need {
45+
err := cr.uncompressPacket()
46+
if err != nil {
47+
return nil, err
48+
}
49+
}
50+
51+
data := make([]byte, need)
52+
53+
copy(data, cr.bytesBuf[:len(data)])
54+
55+
cr.bytesBuf = cr.bytesBuf[len(data):]
56+
57+
return data, nil
58+
}
59+
60+
func (cr *compressedReader) uncompressPacket() error {
61+
header, err := cr.buf.readNext(7) // size of compressed header
62+
63+
if err != nil {
64+
return err
65+
}
66+
67+
// compressed header structure
68+
comprLength := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
69+
uncompressedLength := int(uint32(header[4]) | uint32(header[5])<<8 | uint32(header[6])<<16)
70+
compressionSequence := uint8(header[3])
71+
72+
if compressionSequence != cr.mc.compressionSequence {
73+
return ErrPktSync
74+
}
75+
76+
cr.mc.compressionSequence++
77+
78+
comprData, err := cr.buf.readNext(comprLength)
79+
if err != nil {
80+
return err
81+
}
82+
83+
// if payload is uncompressed, its length will be specified as zero, and its
84+
// true length is contained in comprLength
85+
if uncompressedLength == 0 {
86+
cr.bytesBuf = append(cr.bytesBuf, comprData...)
87+
return nil
88+
}
89+
90+
// write comprData to a bytes.buffer, then read it using zlib into data
91+
var b bytes.Buffer
92+
b.Write(comprData)
93+
r, err := zlib.NewReader(&b)
94+
95+
if r != nil {
96+
defer r.Close()
97+
}
98+
99+
if err != nil {
100+
return err
101+
}
102+
103+
data := make([]byte, uncompressedLength)
104+
lenRead := 0
105+
106+
// http://grokbase.com/t/gg/golang-nuts/146y9ppn6b/go-nuts-stream-compression-with-compress-flate
107+
for lenRead < uncompressedLength {
108+
109+
tmp := data[lenRead:]
110+
111+
n, err := r.Read(tmp)
112+
lenRead += n
113+
114+
if err == io.EOF {
115+
if lenRead < uncompressedLength {
116+
return io.ErrUnexpectedEOF
117+
}
118+
break
119+
}
120+
121+
if err != nil {
122+
return err
123+
}
124+
}
125+
126+
cr.bytesBuf = append(cr.bytesBuf, data...)
127+
128+
return nil
129+
}
130+
131+
func (cw *compressedWriter) Write(data []byte) (int, error) {
132+
// when asked to write an empty packet, do nothing
133+
if len(data) == 0 {
134+
return 0, nil
135+
}
136+
totalBytes := len(data)
137+
138+
length := len(data) - 4
139+
140+
maxPayloadLength := maxPacketSize - 4
141+
142+
for length >= maxPayloadLength {
143+
// cut off a slice of size max payload length
144+
dataSmall := data[:maxPayloadLength]
145+
lenSmall := len(dataSmall)
146+
147+
var b bytes.Buffer
148+
writer := zlib.NewWriter(&b)
149+
_, err := writer.Write(dataSmall)
150+
writer.Close()
151+
if err != nil {
152+
return 0, err
153+
}
154+
155+
err = cw.writeComprPacketToNetwork(b.Bytes(), lenSmall)
156+
if err != nil {
157+
return 0, err
158+
}
159+
160+
length -= maxPayloadLength
161+
data = data[maxPayloadLength:]
162+
}
163+
164+
lenSmall := len(data)
165+
166+
// do not compress if packet is too small
167+
if lenSmall < minCompressLength {
168+
err := cw.writeComprPacketToNetwork(data, 0)
169+
if err != nil {
170+
return 0, err
171+
}
172+
173+
return totalBytes, nil
174+
}
175+
176+
var b bytes.Buffer
177+
writer := zlib.NewWriter(&b)
178+
179+
_, err := writer.Write(data)
180+
writer.Close()
181+
182+
if err != nil {
183+
return 0, err
184+
}
185+
186+
err = cw.writeComprPacketToNetwork(b.Bytes(), lenSmall)
187+
188+
if err != nil {
189+
return 0, err
190+
}
191+
return totalBytes, nil
192+
}
193+
194+
func (cw *compressedWriter) writeComprPacketToNetwork(data []byte, uncomprLength int) error {
195+
data = append([]byte{0, 0, 0, 0, 0, 0, 0}, data...)
196+
197+
comprLength := len(data) - 7
198+
199+
// compression header
200+
data[0] = byte(0xff & comprLength)
201+
data[1] = byte(0xff & (comprLength >> 8))
202+
data[2] = byte(0xff & (comprLength >> 16))
203+
204+
data[3] = cw.mc.compressionSequence
205+
206+
//this value is never greater than maxPayloadLength
207+
data[4] = byte(0xff & uncomprLength)
208+
data[5] = byte(0xff & (uncomprLength >> 8))
209+
data[6] = byte(0xff & (uncomprLength >> 16))
210+
211+
if _, err := cw.connWriter.Write(data); err != nil {
212+
return err
213+
}
214+
215+
cw.mc.compressionSequence++
216+
return nil
217+
}

0 commit comments

Comments
 (0)