diff --git a/AUTHORS b/AUTHORS index 14e8398fd..291c0f2fd 100644 --- a/AUTHORS +++ b/AUTHORS @@ -17,6 +17,7 @@ Alexey Palazhchenko Andrew Reid Arne Hormann Asta Xie +B Lamarche Bulat Gaifullin Carlos Nieto Chris Moos diff --git a/benchmark_go18_test.go b/benchmark_go18_test.go index d6a7e9d6e..5522ab9cf 100644 --- a/benchmark_go18_test.go +++ b/benchmark_go18_test.go @@ -42,7 +42,7 @@ func benchmarkQueryContext(b *testing.B, db *sql.DB, p int) { } func BenchmarkQueryContext(b *testing.B) { - db := initDB(b, + db := initDB(b, false, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", `INSERT INTO foo VALUES (1, "one")`, @@ -78,7 +78,7 @@ func benchmarkExecContext(b *testing.B, db *sql.DB, p int) { } func BenchmarkExecContext(b *testing.B) { - db := initDB(b, + db := initDB(b, false, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", `INSERT INTO foo VALUES (1, "one")`, diff --git a/benchmark_test.go b/benchmark_test.go index c1de8672b..af7b3971c 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -43,9 +43,13 @@ func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt { return stmt } -func initDB(b *testing.B, queries ...string) *sql.DB { +func initDB(b *testing.B, useCompression bool, queries ...string) *sql.DB { tb := (*TB)(b) - db := tb.checkDB(sql.Open("mysql", dsn)) + comprStr := "" + if useCompression { + comprStr = "&compress=1" + } + db := tb.checkDB(sql.Open("mysql", dsn+comprStr)) for _, query := range queries { if _, err := db.Exec(query); err != nil { b.Fatalf("error on %q: %v", query, err) @@ -57,10 +61,19 @@ func initDB(b *testing.B, queries ...string) *sql.DB { const concurrencyLevel = 10 func BenchmarkQuery(b *testing.B) { + benchmarkQueryHelper(b, false) +} + +func BenchmarkQueryCompression(b *testing.B) { + benchmarkQueryHelper(b, true) +} + +func benchmarkQueryHelper(b *testing.B, compr bool) { + tb := (*TB)(b) b.StopTimer() b.ReportAllocs() - db := initDB(b, + db := 
initDB(b, compr, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", `INSERT INTO foo VALUES (1, "one")`, @@ -218,9 +231,11 @@ func BenchmarkInterpolation(b *testing.B) { }, maxAllowedPacket: maxPacketSize, maxWriteSize: maxPacketSize - 1, - buf: newBuffer(nil), } + buf := newBuffer(nil) + mc.reader = &buf + args := []driver.Value{ int64(42424242), float64(math.Pi), diff --git a/buffer.go b/buffer.go index 2001feacd..82ffe5197 100644 --- a/buffer.go +++ b/buffer.go @@ -109,7 +109,11 @@ func (b *buffer) readNext(need int) ([]byte, error) { // If possible, a slice from the existing buffer is returned. // Otherwise a bigger buffer is made. // Only one buffer (total) can be used at a time. -func (b *buffer) takeBuffer(length int) []byte { +func (b *buffer) reuseBuffer(length int) []byte { + if length == -1 { + return b.takeCompleteBuffer() + } + if b.length > 0 { return nil } @@ -126,16 +130,6 @@ func (b *buffer) takeBuffer(length int) []byte { return make([]byte, length) } -// shortcut which can be used if the requested buffer is guaranteed to be -// smaller than defaultBufSize -// Only one buffer (total) can be used at a time. -func (b *buffer) takeSmallBuffer(length int) []byte { - if b.length == 0 { - return b.buf[:length] - } - return nil -} - // takeCompleteBuffer returns the complete existing buffer. // This can be used if the necessary buffer size is unknown. // Only one buffer (total) can be used at a time. 
diff --git a/compress.go b/compress.go new file mode 100644 index 000000000..6c45ebeeb --- /dev/null +++ b/compress.go @@ -0,0 +1,236 @@ +package mysql + +import ( + "bytes" + "compress/zlib" + "io" +) + +const ( + minCompressLength = 50 +) + +type compressedReader struct { + buf packetReader + bytesBuf []byte + mc *mysqlConn + zr io.ReadCloser +} + +type compressedWriter struct { + connWriter io.Writer + mc *mysqlConn + zw *zlib.Writer +} + +func newCompressedReader(buf packetReader, mc *mysqlConn) *compressedReader { + return &compressedReader{ + buf: buf, + bytesBuf: make([]byte, 0), + mc: mc, + } +} + +func newCompressedWriter(connWriter io.Writer, mc *mysqlConn) *compressedWriter { + return &compressedWriter{ + connWriter: connWriter, + mc: mc, + zw: zlib.NewWriter(new(bytes.Buffer)), + } +} + +func (cr *compressedReader) readNext(need int) ([]byte, error) { + for len(cr.bytesBuf) < need { + err := cr.uncompressPacket() + if err != nil { + return nil, err + } + } + + data := cr.bytesBuf[:need] + cr.bytesBuf = cr.bytesBuf[need:] + return data, nil +} + +func (cr *compressedReader) reuseBuffer(length int) []byte { + return cr.buf.reuseBuffer(length) +} + +func (cr *compressedReader) uncompressPacket() error { + header, err := cr.buf.readNext(7) // size of compressed header + + if err != nil { + return err + } + + // compressed header structure + comprLength := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) + uncompressedLength := int(uint32(header[4]) | uint32(header[5])<<8 | uint32(header[6])<<16) + compressionSequence := uint8(header[3]) + + if compressionSequence != cr.mc.compressionSequence { + return ErrPktSync + } + + cr.mc.compressionSequence++ + + comprData, err := cr.buf.readNext(comprLength) + if err != nil { + return err + } + + // if payload is uncompressed, its length will be specified as zero, and its + // true length is contained in comprLength + if uncompressedLength == 0 { + cr.bytesBuf = append(cr.bytesBuf, comprData...) 
+ return nil + } + + // write comprData to a bytes.buffer, then read it using zlib into data + br := bytes.NewReader(comprData) + + if cr.zr == nil { + cr.zr, err = zlib.NewReader(br) + } else { + err = cr.zr.(zlib.Resetter).Reset(br, nil) + } + + if err != nil { + return err + } + + defer cr.zr.Close() + + // use existing capacity in bytesBuf if possible + offset := len(cr.bytesBuf) + if cap(cr.bytesBuf)-offset < uncompressedLength { + old := cr.bytesBuf + cr.bytesBuf = make([]byte, offset, offset+uncompressedLength) + copy(cr.bytesBuf, old) + } + + data := cr.bytesBuf[offset : offset+uncompressedLength] + + lenRead := 0 + + // http://grokbase.com/t/gg/golang-nuts/146y9ppn6b/go-nuts-stream-compression-with-compress-flate + for lenRead < uncompressedLength { + n, err := cr.zr.Read(data[lenRead:]) + lenRead += n + + if err == io.EOF { + if lenRead < uncompressedLength { + return io.ErrUnexpectedEOF + } + break + } + + if err != nil { + return err + } + } + + cr.bytesBuf = append(cr.bytesBuf, data...) + + return nil +} + +func (cw *compressedWriter) Write(data []byte) (int, error) { + // when asked to write an empty packet, do nothing + if len(data) == 0 { + return 0, nil + } + + totalBytes := len(data) + length := len(data) - 4 + maxPayloadLength := maxPacketSize - 4 + blankHeader := make([]byte, 7) + + for length >= maxPayloadLength { + payload := data[:maxPayloadLength] + payloadLen := len(payload) + + bytesBuf := &bytes.Buffer{} + bytesBuf.Write(blankHeader) + cw.zw.Reset(bytesBuf) + _, err := cw.zw.Write(payload) + if err != nil { + return 0, err + } + cw.zw.Close() + + // if compression expands the payload, do not compress + compressedPayload := bytesBuf.Bytes() + if len(compressedPayload) > maxPayloadLength { + compressedPayload = append(blankHeader, payload...) 
+ payloadLen = 0 + } + + err = cw.writeToNetwork(compressedPayload, payloadLen) + + if err != nil { + return 0, err + } + + length -= maxPayloadLength + data = data[maxPayloadLength:] + } + + payloadLen := len(data) + + // do not attempt compression if packet is too small + if payloadLen < minCompressLength { + err := cw.writeToNetwork(append(blankHeader, data...), 0) + if err != nil { + return 0, err + } + return totalBytes, nil + } + + bytesBuf := &bytes.Buffer{} + bytesBuf.Write(blankHeader) + cw.zw.Reset(bytesBuf) + _, err := cw.zw.Write(data) + if err != nil { + return 0, err + } + cw.zw.Close() + + compressedPayload := bytesBuf.Bytes() + + if len(compressedPayload) > len(data) { + compressedPayload = append(blankHeader, data...) + payloadLen = 0 + } + + // add header and send over the wire + err = cw.writeToNetwork(compressedPayload, payloadLen) + if err != nil { + return 0, err + } + + return totalBytes, nil +} + +func (cw *compressedWriter) writeToNetwork(data []byte, uncomprLength int) error { + comprLength := len(data) - 7 + + // compression header + data[0] = byte(0xff & comprLength) + data[1] = byte(0xff & (comprLength >> 8)) + data[2] = byte(0xff & (comprLength >> 16)) + + data[3] = cw.mc.compressionSequence + + // this value is never greater than maxPayloadLength + data[4] = byte(0xff & uncomprLength) + data[5] = byte(0xff & (uncomprLength >> 8)) + data[6] = byte(0xff & (uncomprLength >> 16)) + + if _, err := cw.connWriter.Write(data); err != nil { + return err + } + + cw.mc.compressionSequence++ + return nil +} diff --git a/compress_test.go b/compress_test.go new file mode 100644 index 000000000..0e98599c3 --- /dev/null +++ b/compress_test.go @@ -0,0 +1,221 @@ +package mysql + +import ( + "bytes" + "crypto/rand" + "fmt" + "io" + "testing" +) + +func makeRandByteSlice(size int) []byte { + randBytes := make([]byte, size) + rand.Read(randBytes) + return randBytes +} + +func newMockConn() *mysqlConn { + newConn := &mysqlConn{} + return newConn +} + +type 
mockBuf struct { + reader io.Reader +} + +func newMockBuf(reader io.Reader) *mockBuf { + return &mockBuf{ + reader: reader, + } +} + +func (mb *mockBuf) readNext(need int) ([]byte, error) { + + data := make([]byte, need) + _, err := mb.reader.Read(data) + if err != nil { + return nil, err + } + return data, nil +} + +func (mb *mockBuf) reuseBuffer(length int) []byte { + return make([]byte, length) //just give them a new buffer +} + +// compressHelper compresses uncompressedPacket and checks state variables +func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []byte { + // get status variables + + cs := mc.compressionSequence + + var b bytes.Buffer + connWriter := &b + + cw := newCompressedWriter(connWriter, mc) + + n, err := cw.Write(uncompressedPacket) + + if err != nil { + t.Fatal(err.Error()) + } + + if n != len(uncompressedPacket) { + t.Fatal(fmt.Sprintf("expected to write %d bytes, wrote %d bytes", len(uncompressedPacket), n)) + } + + if len(uncompressedPacket) > 0 { + + if mc.compressionSequence != (cs + 1) { + t.Fatal(fmt.Sprintf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compressionSequence)) + } + + } else { + if mc.compressionSequence != cs { + t.Fatal(fmt.Sprintf("mc.compressionSequence updated incorrectly for case of empty write, expected %d and saw %d", cs, mc.compressionSequence)) + } + } + + return b.Bytes() +} + +// roundtripHelper compresses then uncompresses uncompressedPacket and checks state variables +func roundtripHelper(t *testing.T, cSend *mysqlConn, cReceive *mysqlConn, uncompressedPacket []byte) []byte { + compressed := compressHelper(t, cSend, uncompressedPacket) + return uncompressHelper(t, cReceive, compressed, len(uncompressedPacket)) +} + +// uncompressHelper uncompresses compressedPacket and checks state variables +func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expSize int) []byte { + // get status variables + cs := mc.compressionSequence + + 
// mocking out buf variable + mockConnReader := bytes.NewReader(compressedPacket) + mockBuf := newMockBuf(mockConnReader) + + cr := newCompressedReader(mockBuf, mc) + + uncompressedPacket, err := cr.readNext(expSize) + if err != nil { + if err != io.EOF { + t.Fatal(fmt.Sprintf("non-nil/non-EOF error when reading contents: %s", err.Error())) + } + } + + if expSize > 0 { + if mc.compressionSequence != (cs + 1) { + t.Fatal(fmt.Sprintf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compressionSequence)) + } + } else { + if mc.compressionSequence != cs { + t.Fatal(fmt.Sprintf("mc.compressionSequence updated incorrectly for case of empty read, expected %d and saw %d", cs, mc.compressionSequence)) + } + } + return uncompressedPacket +} + +// TestCompressedReaderThenWriter tests reader and writer separately. +func TestCompressedReaderThenWriter(t *testing.T) { + + makeTestUncompressedPacket := func(size int) []byte { + uncompressedHeader := make([]byte, 4) + uncompressedHeader[0] = byte(size) + uncompressedHeader[1] = byte(size >> 8) + uncompressedHeader[2] = byte(size >> 16) + + payload := make([]byte, size) + for i := range payload { + payload[i] = 'b' + } + + uncompressedPacket := append(uncompressedHeader, payload...) 
+ return uncompressedPacket + } + + tests := []struct { + compressed []byte + uncompressed []byte + desc string + }{ + {compressed: []byte{5, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 'a'}, + uncompressed: []byte{1, 0, 0, 0, 'a'}, + desc: "a"}, + {compressed: []byte{10, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 'g', 'o', 'l', 'a', 'n', 'g'}, + uncompressed: []byte{6, 0, 0, 0, 'g', 'o', 'l', 'a', 'n', 'g'}, + desc: "golang"}, + {compressed: []byte{19, 0, 0, 0, 104, 0, 0, 120, 156, 74, 97, 96, 96, 72, 162, 3, 0, 4, 0, 0, 255, 255, 182, 165, 38, 173}, + uncompressed: makeTestUncompressedPacket(100), + desc: "100 bytes letter b"}, + {compressed: []byte{63, 0, 0, 0, 236, 128, 0, 120, 156, 236, 192, 129, 0, 0, 0, 8, 3, 176, 179, 70, 18, 110, 24, 129, 124, 187, 77, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 168, 241, 1, 0, 0, 255, 255, 42, 107, 93, 24}, + uncompressed: makeTestUncompressedPacket(33000), + desc: "33000 bytes letter b"}, + } + + for _, test := range tests { + s := fmt.Sprintf("Test compress uncompress with %s", test.desc) + + // test uncompression only + c := newMockConn() + uncompressed := uncompressHelper(t, c, test.compressed, len(test.uncompressed)) + if bytes.Compare(uncompressed, test.uncompressed) != 0 { + t.Fatal(fmt.Sprintf("%s: uncompression failed", s)) + } + + // test compression only + c = newMockConn() + compressed := compressHelper(t, c, test.uncompressed) + if bytes.Compare(compressed, test.compressed) != 0 { + t.Fatal(fmt.Sprintf("%s: compression failed", s)) + } + } +} + +// TestRoundtrip tests two connections, where one is reading and the other is writing +func TestRoundtrip(t *testing.T) { + + tests := []struct { + uncompressed []byte + desc string + }{ + {uncompressed: []byte("a"), + desc: "a"}, + {uncompressed: []byte{0}, + desc: "0 byte"}, + {uncompressed: []byte("hello world"), + desc: "hello world"}, + {uncompressed: make([]byte, 100), + desc: "100 bytes"}, + {uncompressed: make([]byte, 32768), + desc: 
"32768 bytes"}, + {uncompressed: make([]byte, 330000), + desc: "33000 bytes"}, + {uncompressed: make([]byte, 0), + desc: "nothing"}, + {uncompressed: makeRandByteSlice(10), + desc: "10 rand bytes", + }, + {uncompressed: makeRandByteSlice(100), + desc: "100 rand bytes", + }, + {uncompressed: makeRandByteSlice(32768), + desc: "32768 rand bytes", + }, + {uncompressed: makeRandByteSlice(33000), + desc: "33000 rand bytes", + }, + } + + cSend := newMockConn() + + cReceive := newMockConn() + + for _, test := range tests { + s := fmt.Sprintf("Test roundtrip with %s", test.desc) + + uncompressed := roundtripHelper(t, cSend, cReceive, test.uncompressed) + if bytes.Compare(uncompressed, test.uncompressed) != 0 { + t.Fatal(fmt.Sprintf("%s: roundtrip failed", s)) + } + } +} diff --git a/connection.go b/connection.go index e57061412..c17e62ca8 100644 --- a/connection.go +++ b/connection.go @@ -28,18 +28,20 @@ type mysqlContext interface { } type mysqlConn struct { - buf buffer - netConn net.Conn - affectedRows uint64 - insertId uint64 - cfg *Config - maxAllowedPacket int - maxWriteSize int - writeTimeout time.Duration - flags clientFlag - status statusFlag - sequence uint8 - parseTime bool + netConn net.Conn + reader packetReader + writer io.Writer + affectedRows uint64 + insertId uint64 + cfg *Config + maxAllowedPacket int + maxWriteSize int + writeTimeout time.Duration + flags clientFlag + status statusFlag + sequence uint8 + compressionSequence uint8 + parseTime bool // for context support (Go 1.8+) watching bool @@ -50,6 +52,11 @@ type mysqlConn struct { closed atomicBool // set when conn is closed, before closech is closed } +type packetReader interface { + readNext(need int) ([]byte, error) + reuseBuffer(length int) []byte +} + // Handles parameters set in DSN after the connection is established func (mc *mysqlConn) handleParams() (err error) { for param, val := range mc.cfg.Params { @@ -190,7 +197,8 @@ func (mc *mysqlConn) interpolateParams(query string, args 
[]driver.Value) (strin return "", driver.ErrSkip } - buf := mc.buf.takeCompleteBuffer() + buf := mc.reader.reuseBuffer(-1) + if buf == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) diff --git a/connection_test.go b/connection_test.go index 65325f101..ac750b574 100644 --- a/connection_test.go +++ b/connection_test.go @@ -15,12 +15,13 @@ import ( func TestInterpolateParams(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, }, } + buf := newBuffer(nil) + mc.reader = &buf q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"}) if err != nil { @@ -35,12 +36,13 @@ func TestInterpolateParams(t *testing.T) { func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, }, } + buf := newBuffer(nil) + mc.reader = &buf q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42)}) if err != driver.ErrSkip { @@ -52,13 +54,15 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { // https://github.com/go-sql-driver/mysql/pull/490 func TestInterpolateParamsPlaceholderInString(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, }, } + buf := newBuffer(nil) + mc.reader = &buf + q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)}) // When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42` if err != driver.ErrSkip { diff --git a/driver.go b/driver.go index d42ce7a3d..ddde423da 100644 --- a/driver.go +++ b/driver.go @@ -91,12 +91,16 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { s.startWatcher() } - mc.buf = newBuffer(mc.netConn) + buf := newBuffer(mc.netConn) // Set I/O timeouts - mc.buf.timeout = mc.cfg.ReadTimeout + buf.timeout 
= mc.cfg.ReadTimeout mc.writeTimeout = mc.cfg.WriteTimeout + // packet reader and writer in handshake are never compressed + mc.reader = &buf + mc.writer = mc.netConn + // Reading Handshake Initialization Packet cipher, err := mc.readInitPacket() if err != nil { @@ -119,6 +123,11 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { return nil, err } + if mc.cfg.Compress { + mc.reader = newCompressedReader(&buf, mc) + mc.writer = newCompressedWriter(mc.writer, mc) + } + if mc.cfg.MaxAllowedPacket > 0 { mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket } else { diff --git a/dsn.go b/dsn.go index 47eab6945..82d15a8fb 100644 --- a/dsn.go +++ b/dsn.go @@ -53,6 +53,7 @@ type Config struct { AllowOldPasswords bool // Allows the old insecure password method ClientFoundRows bool // Return number of matching rows instead of rows changed ColumnsWithAlias bool // Prepend table alias to column names + Compress bool // Compress packets InterpolateParams bool // Interpolate placeholders into query string MultiStatements bool // Allow multiple statements in one query ParseTime bool // Parse time values to time.Time @@ -461,7 +462,11 @@ func parseDSNParams(cfg *Config, params string) (err error) { // Compression case "compress": - return errors.New("compression not implemented yet") + var isBool bool + cfg.Compress, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } // Enable client side placeholder substitution case "interpolateParams": diff --git a/packets.go b/packets.go index afc3fcc46..ceab20ef8 100644 --- a/packets.go +++ b/packets.go @@ -28,7 +28,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { var prevData []byte for { // read packet header - data, err := mc.buf.readNext(4) + data, err := mc.reader.readNext(4) if err != nil { if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr @@ -64,7 +64,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } // read packet body [pktLen bytes] - data, err = 
mc.buf.readNext(pktLen) + data, err = mc.reader.readNext(pktLen) if err != nil { if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr @@ -118,7 +118,7 @@ func (mc *mysqlConn) writePacket(data []byte) error { } } - n, err := mc.netConn.Write(data[:4+size]) + n, err := mc.writer.Write(data[:4+size]) if err == nil && n == 4+size { mc.sequence++ if size != maxPacketSize { @@ -258,6 +258,10 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { clientFlags |= clientFoundRows } + if mc.cfg.Compress { + clientFlags |= clientCompress + } + // To enable TLS / SSL if mc.cfg.tls != nil { clientFlags |= clientSSL @@ -279,7 +283,8 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { } // Calculate packet length and get buffer with that size - data := mc.buf.takeSmallBuffer(pktLen + 4) + data := mc.reader.reuseBuffer(pktLen + 4) + if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) @@ -322,7 +327,12 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { return err } mc.netConn = tlsConn - mc.buf.nc = tlsConn + nc := tlsConn + + newBuf := newBuffer(nc) + mc.reader = &newBuf + + mc.writer = mc.netConn } // Filler [23 bytes] (all 0x00) @@ -367,7 +377,8 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { // Calculate the packet length and add a tailing 0 pktLen := len(scrambleBuff) + 1 - data := mc.buf.takeSmallBuffer(4 + pktLen) + data := mc.reader.reuseBuffer(4 + pktLen) + if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) @@ -386,7 +397,8 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { func (mc *mysqlConn) writeClearAuthPacket() error { // Calculate the packet length and add a tailing 0 pktLen := len(mc.cfg.Passwd) + 1 - data := mc.buf.takeSmallBuffer(4 + pktLen) + data := mc.reader.reuseBuffer(4 + pktLen) + if data == nil { // can not take the buffer. 
Something must be wrong with the connection errLog.Print(ErrBusyBuffer) @@ -409,7 +421,8 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { // Calculate the packet length and add a tailing 0 pktLen := len(scrambleBuff) - data := mc.buf.takeSmallBuffer(4 + pktLen) + data := mc.reader.reuseBuffer(4 + pktLen) + if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) @@ -429,8 +442,10 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { func (mc *mysqlConn) writeCommandPacket(command byte) error { // Reset Packet Sequence mc.sequence = 0 + mc.compressionSequence = 0 + + data := mc.reader.reuseBuffer(4 + 1) - data := mc.buf.takeSmallBuffer(4 + 1) if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) @@ -447,9 +462,11 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { // Reset Packet Sequence mc.sequence = 0 + mc.compressionSequence = 0 pktLen := 1 + len(arg) - data := mc.buf.takeBuffer(pktLen + 4) + data := mc.reader.reuseBuffer(pktLen + 4) + if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) @@ -469,8 +486,10 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { // Reset Packet Sequence mc.sequence = 0 + mc.compressionSequence = 0 + + data := mc.reader.reuseBuffer(4 + 1 + 4) - data := mc.buf.takeSmallBuffer(4 + 1 + 4) if data == nil { // can not take the buffer. 
Something must be wrong with the connection errLog.Print(ErrBusyBuffer) @@ -857,7 +876,8 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { // http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { - maxLen := stmt.mc.maxAllowedPacket - 1 + mc := stmt.mc + maxLen := mc.maxAllowedPacket - 1 pktLen := maxLen // After the header (bytes 0-3) follows before the data: @@ -878,7 +898,8 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { pktLen = dataOffset + argLen } - stmt.mc.sequence = 0 + mc.sequence = 0 + mc.compressionSequence = 0 // Add command byte [1 byte] data[4] = comStmtSendLongData @@ -893,7 +914,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { data[10] = byte(paramID >> 8) // Send CMD packet - err := stmt.mc.writePacket(data[:4+pktLen]) + err := mc.writePacket(data[:4+pktLen]) if err == nil { data = data[pktLen-dataOffset:] continue @@ -903,7 +924,8 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { } // Reset Packet Sequence - stmt.mc.sequence = 0 + mc.sequence = 0 + mc.compressionSequence = 0 return nil } @@ -929,13 +951,15 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // Reset packet-sequence mc.sequence = 0 + mc.compressionSequence = 0 var data []byte if len(args) == 0 { - data = mc.buf.takeBuffer(minPktLen) + data = mc.reader.reuseBuffer(minPktLen) + } else { - data = mc.buf.takeCompleteBuffer() + data = mc.reader.reuseBuffer(-1) } if data == nil { // can not take the buffer. Something must be wrong with the connection @@ -1114,7 +1138,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // In that case we must build the data packet with the new values buffer if valuesCap != cap(paramValues) { data = append(data[:pos], paramValues...) 
- mc.buf.buf = data + + // FIXME: the packetReader interface has no way to hand the grown + // buffer back to the connection for reuse; the packet is still sent + // correctly, but the enlarged buffer is not retained for later packets } pos += len(paramValues) diff --git a/packets_test.go b/packets_test.go index 2f8207511..2f98d8c24 100644 --- a/packets_test.go +++ b/packets_test.go @@ -89,8 +89,9 @@ var _ net.Conn = new(mockConn) func TestReadPacketSingleByte(t *testing.T) { conn := new(mockConn) + buf := newBuffer(conn) mc := &mysqlConn{ - buf: newBuffer(conn), + reader: &buf, } conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} @@ -109,8 +110,9 @@ func TestReadPacketSingleByte(t *testing.T) { func TestReadPacketWrongSequenceID(t *testing.T) { conn := new(mockConn) + buf := newBuffer(conn) mc := &mysqlConn{ - buf: newBuffer(conn), + reader: &buf, } // too low sequence id @@ -125,7 +127,8 @@ func TestReadPacketWrongSequenceID(t *testing.T) { // reset conn.reads = 0 mc.sequence = 0 - mc.buf = newBuffer(conn) + newBuf := newBuffer(conn) + mc.reader = &newBuf // too high sequence id conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff} @@ -137,8 +140,9 @@ func TestReadPacketWrongSequenceID(t *testing.T) { func TestReadPacketSplit(t *testing.T) { conn := new(mockConn) + buf := newBuffer(conn) mc := &mysqlConn{ - buf: newBuffer(conn), + reader: &buf, } data := make([]byte, maxPacketSize*2+4*3) @@ -242,8 +246,9 @@ func TestReadPacketSplit(t *testing.T) { func TestReadPacketFail(t *testing.T) { conn := new(mockConn) + buf := newBuffer(conn) mc := &mysqlConn{ - buf: newBuffer(conn), + reader: &buf, closech: make(chan struct{}), } @@ -258,7 +263,8 @@ func TestReadPacketFail(t *testing.T) { // reset conn.reads = 0 mc.sequence = 0 - mc.buf = newBuffer(conn) + newBuf := newBuffer(conn) + mc.reader = &newBuf // fail to read header conn.closed = true @@ -271,7 +277,8 @@ func TestReadPacketFail(t *testing.T) { conn.closed = false conn.reads = 0 mc.sequence = 0 - mc.buf = newBuffer(conn) + newBuf = newBuffer(conn) + mc.reader = &newBuf // fail to read body 
conn.maxReads = 1