From e6c682cec452c7b20319f048fe38f1245015845f Mon Sep 17 00:00:00 2001 From: Brigitte Lamarche Date: Mon, 31 Jul 2017 16:43:52 -0400 Subject: [PATCH 01/33] packets: implemented compression protocol --- AUTHORS | 1 + benchmark_test.go | 1 + compress.go | 217 ++++++++++++++++++++++++++++++++++++++++++++ compress_test.go | 220 +++++++++++++++++++++++++++++++++++++++++++++ connection.go | 29 +++--- connection_test.go | 4 + driver.go | 9 ++ dsn.go | 7 +- packets.go | 15 +++- packets_test.go | 6 ++ 10 files changed, 492 insertions(+), 17 deletions(-) create mode 100644 compress.go create mode 100644 compress_test.go diff --git a/AUTHORS b/AUTHORS index 5526e3e90..f137bfcc7 100644 --- a/AUTHORS +++ b/AUTHORS @@ -15,6 +15,7 @@ Aaron Hopkins Achille Roussel Arne Hormann Asta Xie +B Lamarche Bulat Gaifullin Carlos Nieto Chris Moos diff --git a/benchmark_test.go b/benchmark_test.go index 7da833a2a..460553e03 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -224,6 +224,7 @@ func BenchmarkInterpolation(b *testing.B) { maxWriteSize: maxPacketSize - 1, buf: newBuffer(nil), } + mc.reader = &mc.buf args := []driver.Value{ int64(42424242), diff --git a/compress.go b/compress.go new file mode 100644 index 000000000..2349aa13b --- /dev/null +++ b/compress.go @@ -0,0 +1,217 @@ +package mysql + +import ( + "bytes" + "compress/zlib" + "io" +) + +const ( + minCompressLength = 50 +) + +type packetReader interface { + readNext(need int) ([]byte, error) +} + +type compressedReader struct { + buf packetReader + bytesBuf []byte + mc *mysqlConn +} + +type compressedWriter struct { + connWriter io.Writer + mc *mysqlConn +} + +func NewCompressedReader(buf packetReader, mc *mysqlConn) *compressedReader { + return &compressedReader{ + buf: buf, + bytesBuf: make([]byte, 0), + mc: mc, + } +} + +func NewCompressedWriter(connWriter io.Writer, mc *mysqlConn) *compressedWriter { + return &compressedWriter{ + connWriter: connWriter, + mc: mc, + } +} + +func (cr *compressedReader) readNext(need int) ([]byte, error) { + for len(cr.bytesBuf) < need { + err := cr.uncompressPacket() + if err != nil { + return nil, err + } + } + + data := make([]byte, need) + + copy(data, cr.bytesBuf[:len(data)]) + + cr.bytesBuf = cr.bytesBuf[len(data):] + + return data, nil +} + +func (cr *compressedReader) uncompressPacket() error { + header, err := cr.buf.readNext(7) // size of compressed header + + if err != nil { + return err + } + + // compressed header structure + comprLength := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) + uncompressedLength := int(uint32(header[4]) | uint32(header[5])<<8 | uint32(header[6])<<16) + compressionSequence := uint8(header[3]) + + if compressionSequence != cr.mc.compressionSequence { + return ErrPktSync + } + + cr.mc.compressionSequence++ + + comprData, err := cr.buf.readNext(comprLength) + if err != nil { + return err + } + + // if payload is uncompressed, its length will be specified as zero, and its + // true length is contained in comprLength + if uncompressedLength == 0 { + cr.bytesBuf = append(cr.bytesBuf, comprData...) + return nil + } + + // write comprData to a bytes.buffer, then read it using zlib into data + var b bytes.Buffer + b.Write(comprData) + r, err := zlib.NewReader(&b) + + if r != nil { + defer r.Close() + } + + if err != nil { + return err + } + + data := make([]byte, uncompressedLength) + lenRead := 0 + + // http://grokbase.com/t/gg/golang-nuts/146y9ppn6b/go-nuts-stream-compression-with-compress-flate + for lenRead < uncompressedLength { + + tmp := data[lenRead:] + + n, err := r.Read(tmp) + lenRead += n + + if err == io.EOF { + if lenRead < uncompressedLength { + return io.ErrUnexpectedEOF + } + break + } + + if err != nil { + return err + } + } + + cr.bytesBuf = append(cr.bytesBuf, data...) + + return nil +} + +func (cw *compressedWriter) Write(data []byte) (int, error) { + // when asked to write an empty packet, do nothing + if len(data) == 0 { + return 0, nil + } + totalBytes := len(data) + + length := len(data) - 4 + + maxPayloadLength := maxPacketSize - 4 + + for length >= maxPayloadLength { + // cut off a slice of size max payload length + dataSmall := data[:maxPayloadLength] + lenSmall := len(dataSmall) + + var b bytes.Buffer + writer := zlib.NewWriter(&b) + _, err := writer.Write(dataSmall) + writer.Close() + if err != nil { + return 0, err + } + + err = cw.writeComprPacketToNetwork(b.Bytes(), lenSmall) + if err != nil { + return 0, err + } + + length -= maxPayloadLength + data = data[maxPayloadLength:] + } + + lenSmall := len(data) + + // do not compress if packet is too small + if lenSmall < minCompressLength { + err := cw.writeComprPacketToNetwork(data, 0) + if err != nil { + return 0, err + } + + return totalBytes, nil + } + + var b bytes.Buffer + writer := zlib.NewWriter(&b) + + _, err := writer.Write(data) + writer.Close() + + if err != nil { + return 0, err + } + + err = cw.writeComprPacketToNetwork(b.Bytes(), lenSmall) + + if err != nil { + return 0, err + } + return totalBytes, nil +} + +func (cw *compressedWriter) writeComprPacketToNetwork(data []byte, uncomprLength int) error { + data = append([]byte{0, 0, 0, 0, 0, 0, 0}, data...) + + comprLength := len(data) - 7 + + // compression header + data[0] = byte(0xff & comprLength) + data[1] = byte(0xff & (comprLength >> 8)) + data[2] = byte(0xff & (comprLength >> 16)) + + data[3] = cw.mc.compressionSequence + + //this value is never greater than maxPayloadLength + data[4] = byte(0xff & uncomprLength) + data[5] = byte(0xff & (uncomprLength >> 8)) + data[6] = byte(0xff & (uncomprLength >> 16)) + + if _, err := cw.connWriter.Write(data); err != nil { + return err + } + + cw.mc.compressionSequence++ + return nil +} diff --git a/compress_test.go b/compress_test.go new file mode 100644 index 000000000..c626ff3ee --- /dev/null +++ b/compress_test.go @@ -0,0 +1,220 @@ +package mysql + +import ( + "bytes" + "crypto/rand" + "fmt" + "io" + "testing" +) + +func makeRandByteSlice(size int) []byte { + randBytes := make([]byte, size) + rand.Read(randBytes) + return randBytes +} + +func newMockConn() *mysqlConn { + newConn := &mysqlConn{} + return newConn +} + +type mockBuf struct { + reader io.Reader +} + +func newMockBuf(reader io.Reader) *mockBuf { + return &mockBuf{ + reader: reader, + } +} + +func (mb *mockBuf) readNext(need int) ([]byte, error) { + + data := make([]byte, need) + _, err := mb.reader.Read(data) + if err != nil { + return nil, err + } + return data, nil +} + +// compressHelper compresses uncompressedPacket and checks state variables +func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []byte { + // get status variables + + cs := mc.compressionSequence + + var b bytes.Buffer + connWriter := &b + + cw := NewCompressedWriter(connWriter, mc) + + n, err := cw.Write(uncompressedPacket) + + if err != nil { + t.Fatal(err.Error()) + } + + if n != len(uncompressedPacket) { + t.Fatal(fmt.Sprintf("expected to write %d bytes, wrote %d bytes", len(uncompressedPacket), n)) + } + + if len(uncompressedPacket) > 0 { + + if mc.compressionSequence != (cs + 1) { + t.Fatal(fmt.Sprintf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compressionSequence)) + } + + } else { + if mc.compressionSequence != cs { + t.Fatal(fmt.Sprintf("mc.compressionSequence updated incorrectly for case of empty write, expected %d and saw %d", cs, mc.compressionSequence)) + } + } + + return b.Bytes() +} + +// roundtripHelper compresses then uncompresses uncompressedPacket and checks state variables +func roundtripHelper(t *testing.T, cSend *mysqlConn, cReceive *mysqlConn, uncompressedPacket []byte) []byte { + compressed := compressHelper(t, cSend, uncompressedPacket) + return uncompressHelper(t, cReceive, compressed, len(uncompressedPacket)) +} + +// uncompressHelper uncompresses compressedPacket and checks state variables +func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expSize int) []byte { + // get status variables + cs := mc.compressionSequence + + // mocking out buf variable + mockConnReader := bytes.NewReader(compressedPacket) + mockBuf := newMockBuf(mockConnReader) + + cr := NewCompressedReader(mockBuf, mc) + + uncompressedPacket, err := cr.readNext(expSize) + if err != nil { + if err != io.EOF { + t.Fatal(fmt.Sprintf("non-nil/non-EOF error when reading contents: %s", err.Error())) + } + } + + if expSize > 0 { + if mc.compressionSequence != (cs + 1) { + t.Fatal(fmt.Sprintf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compressionSequence)) + } + } else { + if mc.compressionSequence != cs { + t.Fatal(fmt.Sprintf("mc.compressionSequence updated incorrectly for case of empty read, expected %d and saw %d", cs, mc.compressionSequence)) + } + } + return uncompressedPacket +} + +// TestCompressedReaderThenWriter tests reader and writer seperately. +func TestCompressedReaderThenWriter(t *testing.T) { + + makeTestUncompressedPacket := func(size int) []byte { + uncompressedHeader := make([]byte, 4) + uncompressedHeader[0] = byte(size) + uncompressedHeader[1] = byte(size >> 8) + uncompressedHeader[2] = byte(size >> 16) + + payload := make([]byte, size) + for i := range payload { + payload[i] = 'b' + } + + uncompressedPacket := append(uncompressedHeader, payload...) + return uncompressedPacket + } + + tests := []struct { + compressed []byte + uncompressed []byte + desc string + }{ + {compressed: []byte{5, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 'a'}, + uncompressed: []byte{1, 0, 0, 0, 'a'}, + desc: "a"}, + {compressed: []byte{10, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 'g', 'o', 'l', 'a', 'n', 'g'}, + uncompressed: []byte{6, 0, 0, 0, 'g', 'o', 'l', 'a', 'n', 'g'}, + desc: "golang"}, + {compressed: []byte{19, 0, 0, 0, 104, 0, 0, 120, 156, 74, 97, 96, 96, 72, 162, 3, 0, 4, 0, 0, 255, 255, 182, 165, 38, 173}, + uncompressed: makeTestUncompressedPacket(100), + desc: "100 bytes letter b"}, + {compressed: []byte{63, 0, 0, 0, 236, 128, 0, 120, 156, 236, 192, 129, 0, 0, 0, 8, 3, 176, 179, 70, 18, 110, 24, 129, 124, 187, 77, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 168, 241, 1, 0, 0, 255, 255, 42, 107, 93, 24}, + uncompressed: makeTestUncompressedPacket(33000), + desc: "33000 bytes letter b"}, + } + + for _, test := range tests { + s := fmt.Sprintf("Test compress uncompress with %s", test.desc) + + // test uncompression only + c := newMockConn() + uncompressed := uncompressHelper(t, c, test.compressed, len(test.uncompressed)) + if bytes.Compare(uncompressed, test.uncompressed) != 0 { + t.Fatal(fmt.Sprintf("%s: uncompression failed", s)) + } + + // test compression only + c = newMockConn() + compressed := compressHelper(t, c, test.uncompressed) + if bytes.Compare(compressed, test.compressed) != 0 { + t.Fatal(fmt.Sprintf("%s: compression failed", s)) + } + } +} + +// TestRoundtrip tests two connections, where one is reading and the other is writing +func TestRoundtrip(t *testing.T) { + + tests := []struct { + uncompressed []byte + desc string + }{ + {uncompressed: []byte("a"), + desc: "a"}, + {uncompressed: []byte{0}, + desc: "0 byte"}, + {uncompressed: []byte("hello world"), + desc: "hello world"}, + {uncompressed: make([]byte, 100), + desc: "100 bytes"}, + {uncompressed: make([]byte, 32768), + desc: "32768 bytes"}, + {uncompressed: make([]byte, 330000), + desc: "33000 bytes"}, + {uncompressed: make([]byte, 0), + desc: "nothing"}, + {uncompressed: makeRandByteSlice(10), + desc: "10 rand bytes", + }, + {uncompressed: makeRandByteSlice(100), + desc: "100 rand bytes", + }, + {uncompressed: makeRandByteSlice(32768), + desc: "32768 rand bytes", + }, + {uncompressed: makeRandByteSlice(33000), + desc: "33000 rand bytes", + }, + } + + cSend := newMockConn() + + cReceive := newMockConn() + + for _, test := range tests { + s := fmt.Sprintf("Test roundtrip with %s", test.desc) + //t.Run(s, func(t *testing.T) { + + uncompressed := roundtripHelper(t, cSend, cReceive, test.uncompressed) + if bytes.Compare(uncompressed, test.uncompressed) != 0 { + t.Fatal(fmt.Sprintf("%s: roundtrip failed", s)) + } + + //}) + } +} diff --git a/connection.go b/connection.go index 2630f5211..663f94b59 100644 --- a/connection.go +++ b/connection.go @@ -28,19 +28,22 @@ type mysqlContext interface { } type mysqlConn struct { - buf buffer - netConn net.Conn - affectedRows uint64 - insertId uint64 - cfg *Config - maxAllowedPacket int - maxWriteSize int - writeTimeout time.Duration - flags clientFlag - status statusFlag - sequence uint8 - parseTime bool - strict bool + buf buffer + netConn net.Conn + affectedRows uint64 + insertId uint64 + cfg *Config + maxAllowedPacket int + maxWriteSize int + writeTimeout time.Duration + flags clientFlag + status statusFlag + sequence uint8 + compressionSequence uint8 + parseTime bool + strict bool + reader packetReader + writer io.Writer // for context support (Go 1.8+) watching bool diff --git a/connection_test.go b/connection_test.go index 65325f101..187c76116 100644 --- a/connection_test.go +++ b/connection_test.go @@ -21,6 +21,7 @@ func TestInterpolateParams(t *testing.T) { InterpolateParams: true, }, } + mc.reader = &mc.buf q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"}) if err != nil { @@ -41,6 +42,7 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { InterpolateParams: true, }, } + mc.reader = &mc.buf q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42)}) if err != driver.ErrSkip { @@ -59,6 +61,8 @@ func TestInterpolateParamsPlaceholderInString(t *testing.T) { }, } + mc.reader = &mc.buf + q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)}) // When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42` if err != driver.ErrSkip { diff --git a/driver.go b/driver.go index c341b6680..691326218 100644 --- a/driver.go +++ b/driver.go @@ -94,6 +94,10 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc.buf = newBuffer(mc.netConn) + // packet reader and writer in handshake are never compressed + mc.reader = &mc.buf + mc.writer = mc.netConn + // Set I/O timeouts mc.buf.timeout = mc.cfg.ReadTimeout mc.writeTimeout = mc.cfg.WriteTimeout @@ -120,6 +124,11 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { return nil, err } + if mc.cfg.Compression { + mc.reader = NewCompressedReader(&mc.buf, mc) + mc.writer = NewCompressedWriter(mc.writer, mc) + } + if mc.cfg.MaxAllowedPacket > 0 { mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket } else { diff --git a/dsn.go b/dsn.go index ab2fdfc6a..626fde365 100644 --- a/dsn.go +++ b/dsn.go @@ -56,6 +56,7 @@ type Config struct { ParseTime bool // Parse time values to time.Time RejectReadOnly bool // Reject read-only connections Strict bool // Return warnings as errors + Compression bool // Compress packets } // FormatDSN formats the given Config into a DSN string which can be passed to @@ -445,7 +446,11 @@ func parseDSNParams(cfg *Config, params string) (err error) { // Compression case "compress": - return errors.New("compression not implemented yet") + var isBool bool + cfg.Compression, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } // Enable client side placeholder substitution case "interpolateParams": diff --git a/packets.go b/packets.go index 9715067c4..fbdeb6897 100644 --- a/packets.go +++ b/packets.go @@ -28,7 +28,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { var prevData []byte for { // read packet header - data, err := mc.buf.readNext(4) + data, err := mc.reader.readNext(4) if err != nil { if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr @@ -64,7 +64,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } // read packet body [pktLen bytes] - data, err = mc.buf.readNext(pktLen) + data, err = mc.reader.readNext(pktLen) if err != nil { if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr @@ -118,7 +118,7 @@ func (mc *mysqlConn) writePacket(data []byte) error { } } - n, err := mc.netConn.Write(data[:4+size]) + n, err := mc.writer.Write(data[:4+size]) if err == nil && n == 4+size { mc.sequence++ if size != maxPacketSize { @@ -249,6 +249,10 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { clientFlags |= clientFoundRows } + if mc.cfg.Compression { + clientFlags |= clientCompress + } + // To enable TLS / SSL if mc.cfg.tls != nil { clientFlags |= clientSSL @@ -314,6 +318,8 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { } mc.netConn = tlsConn mc.buf.nc = tlsConn + + mc.writer = mc.netConn } // Filler [23 bytes] (all 0x00) @@ -416,6 +422,7 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { func (mc *mysqlConn) writeCommandPacket(command byte) error { // Reset Packet Sequence mc.sequence = 0 + mc.compressionSequence = 0 data := mc.buf.takeSmallBuffer(4 + 1) if data == nil { @@ -434,6 +441,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { // Reset Packet Sequence mc.sequence = 0 + mc.compressionSequence = 0 pktLen := 1 + len(arg) data := mc.buf.takeBuffer(pktLen + 4) @@ -456,6 +464,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { // Reset Packet Sequence mc.sequence = 0 + mc.compressionSequence = 0 data := mc.buf.takeSmallBuffer(4 + 1 + 4) if data == nil { diff --git a/packets_test.go b/packets_test.go index 31c892d85..53752e3b8 100644 --- a/packets_test.go +++ b/packets_test.go @@ -94,6 +94,8 @@ func TestReadPacketSingleByte(t *testing.T) { buf: newBuffer(conn), } + mc.reader = &mc.buf + conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} conn.maxReads = 1 packet, err := mc.readPacket() @@ -113,6 +115,7 @@ func TestReadPacketWrongSequenceID(t *testing.T) { mc := &mysqlConn{ buf: newBuffer(conn), } + mc.reader = &mc.buf // too low sequence id conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} @@ -142,6 +145,8 @@ func TestReadPacketSplit(t *testing.T) { buf: newBuffer(conn), } + mc.reader = &mc.buf + data := make([]byte, maxPacketSize*2+4*3) const pkt2ofs = maxPacketSize + 4 const pkt3ofs = 2 * (maxPacketSize + 4) @@ -247,6 +252,7 @@ func TestReadPacketFail(t *testing.T) { buf: newBuffer(conn), closech: make(chan struct{}), } + mc.reader = &mc.buf // illegal empty (stand-alone) packet conn.data = []byte{0x00, 0x00, 0x00, 0x00} From 77f679299d7320e2ab462df5b5d9e6913a43ca47 Mon Sep 17 00:00:00 2001 From: Brigitte Lamarche Date: Wed, 16 Aug 2017 13:55:33 -0400 Subject: [PATCH 02/33] packets: implemented compression protocol CR changes --- benchmark_go18_test.go | 4 +-- benchmark_test.go | 19 +++++++++-- compress.go | 72 ++++++++++++++++++++++++++++++------------ packets.go | 3 ++ 4 files changed, 72 insertions(+), 26 deletions(-) diff --git a/benchmark_go18_test.go b/benchmark_go18_test.go index d6a7e9d6e..5522ab9cf 100644 --- a/benchmark_go18_test.go +++ b/benchmark_go18_test.go @@ -42,7 +42,7 @@ func benchmarkQueryContext(b *testing.B, db *sql.DB, p int) { } func BenchmarkQueryContext(b *testing.B) { - db := initDB(b, + db := initDB(b, false, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", `INSERT INTO foo VALUES (1, "one")`, @@ -78,7 +78,7 @@ func benchmarkExecContext(b *testing.B, db *sql.DB, p int) { } func BenchmarkExecContext(b *testing.B) { - db := initDB(b, + db := initDB(b, false, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", `INSERT INTO foo VALUES (1, "one")`, diff --git a/benchmark_test.go b/benchmark_test.go index 460553e03..2d690906f 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -43,9 +43,13 @@ func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt { return stmt } -func initDB(b *testing.B, queries ...string) *sql.DB { +func initDB(b *testing.B, useCompression bool, queries ...string) *sql.DB { tb := (*TB)(b) - db := tb.checkDB(sql.Open("mysql", dsn)) + comprStr := "" + if useCompression { + comprStr = "&compress=1" + } + db := tb.checkDB(sql.Open("mysql", dsn+comprStr)) for _, query := range queries { if _, err := db.Exec(query); err != nil { if w, ok := err.(MySQLWarnings); ok { @@ -61,10 +65,19 @@ func initDB(b *testing.B, queries ...string) *sql.DB { const concurrencyLevel = 10 func BenchmarkQuery(b *testing.B) { + benchmarkQueryHelper(b, false) +} + +func BenchmarkQueryCompression(b *testing.B) { + benchmarkQueryHelper(b, true) +} + +func benchmarkQueryHelper(b *testing.B, compr bool) { + tb := (*TB)(b) b.StopTimer() b.ReportAllocs() - db := initDB(b, + db := initDB(b, compr, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", `INSERT INTO foo VALUES (1, "one")`, diff --git a/compress.go b/compress.go index 2349aa13b..14a131f4d 100644 --- a/compress.go +++ b/compress.go @@ -18,6 +18,8 @@ type compressedReader struct { buf packetReader bytesBuf []byte mc *mysqlConn + br *bytes.Reader + zr io.ReadCloser } type compressedWriter struct { @@ -48,12 +50,8 @@ func (cr *compressedReader) readNext(need int) ([]byte, error) { } } - data := make([]byte, need) - - copy(data, cr.bytesBuf[:len(data)]) - - cr.bytesBuf = cr.bytesBuf[len(data):] - + data := cr.bytesBuf[:need] + cr.bytesBuf = cr.bytesBuf[need:] return data, nil } @@ -88,27 +86,43 @@ func (cr *compressedReader) uncompressPacket() error { } // write comprData to a bytes.buffer, then read it using zlib into data - var b bytes.Buffer - b.Write(comprData) - r, err := zlib.NewReader(&b) + if cr.br == nil { + cr.br = bytes.NewReader(comprData) + } else { + cr.br.Reset(comprData) + } + + resetter, ok := cr.zr.(zlib.Resetter) - if r != nil { - defer r.Close() + if ok { + err := resetter.Reset(cr.br, []byte{}) + if err != nil { + return err + } + } else { + cr.zr, err = zlib.NewReader(cr.br) + if err != nil { + return err + } } - if err != nil { - return err + defer cr.zr.Close() + + //use existing capacity in bytesBuf if possible + offset := len(cr.bytesBuf) + if cap(cr.bytesBuf)-offset < uncompressedLength { + old := cr.bytesBuf + cr.bytesBuf = make([]byte, offset, offset+uncompressedLength) + copy(cr.bytesBuf, old) } - data := make([]byte, uncompressedLength) + data := cr.bytesBuf[offset : offset+uncompressedLength] + lenRead := 0 // http://grokbase.com/t/gg/golang-nuts/146y9ppn6b/go-nuts-stream-compression-with-compress-flate for lenRead < uncompressedLength { - - tmp := data[lenRead:] - - n, err := r.Read(tmp) + n, err := cr.zr.Read(data[lenRead:]) lenRead += n if err == io.EOF { @@ -152,7 +166,15 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { return 0, err } - err = cw.writeComprPacketToNetwork(b.Bytes(), lenSmall) + // if compression expands the payload, do not compress + useData := b.Bytes() + + if len(useData) > len(dataSmall) { + useData = dataSmall + lenSmall = 0 + } + + err = cw.writeComprPacketToNetwork(useData, lenSmall) if err != nil { return 0, err } @@ -163,7 +185,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { lenSmall := len(data) - // do not compress if packet is too small + // do not attempt compression if packet is too small if lenSmall < minCompressLength { err := cw.writeComprPacketToNetwork(data, 0) if err != nil { @@ -183,7 +205,15 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { return 0, err } - err = cw.writeComprPacketToNetwork(b.Bytes(), lenSmall) + // if compression expands the payload, do not compress + useData := b.Bytes() + + if len(useData) > len(data) { + useData = data + lenSmall = 0 + } + + err = cw.writeComprPacketToNetwork(useData, lenSmall) if err != nil { return 0, err diff --git a/packets.go b/packets.go index fbdeb6897..f8ff6a298 100644 --- a/packets.go +++ b/packets.go @@ -881,6 +881,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { } stmt.mc.sequence = 0 + stmt.mc.compressionSequence = 0 // Add command byte [1 byte] data[4] = comStmtSendLongData @@ -906,6 +907,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { // Reset Packet Sequence stmt.mc.sequence = 0 + stmt.mc.compressionSequence = 0 return nil } @@ -925,6 +927,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // Reset packet-sequence mc.sequence = 0 + mc.compressionSequence = 0 var data []byte From a0cf94b33baca6fd00a0d761192b6f2b4d8fd103 Mon Sep 17 00:00:00 2001 From: Brigitte Lamarche Date: Wed, 16 Aug 2017 14:32:26 -0400 Subject: [PATCH 03/33] packets: implemented compression protocol: remove bytes.Reset for backwards compatibility --- compress.go | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/compress.go b/compress.go index 14a131f4d..17bf1560b 100644 --- a/compress.go +++ b/compress.go @@ -18,13 +18,13 @@ type compressedReader struct { buf packetReader bytesBuf []byte mc *mysqlConn - br *bytes.Reader zr io.ReadCloser } type compressedWriter struct { connWriter io.Writer mc *mysqlConn + header []byte } func NewCompressedReader(buf packetReader, mc *mysqlConn) *compressedReader { @@ -39,6 +39,7 @@ func NewCompressedWriter(connWriter io.Writer, mc *mysqlConn) *compressedWriter return &compressedWriter{ connWriter: connWriter, mc: mc, + header: []byte{0, 0, 0, 0, 0, 0, 0}, } } @@ -86,21 +87,17 @@ func (cr *compressedReader) uncompressPacket() error { } // write comprData to a bytes.buffer, then read it using zlib into data - if cr.br == nil { - cr.br = bytes.NewReader(comprData) - } else { - cr.br.Reset(comprData) - } + br := bytes.NewReader(comprData) resetter, ok := cr.zr.(zlib.Resetter) if ok { - err := resetter.Reset(cr.br, []byte{}) + err := resetter.Reset(br, []byte{}) if err != nil { return err } } else { - cr.zr, err = zlib.NewReader(cr.br) + cr.zr, err = zlib.NewReader(br) if err != nil { return err } @@ -222,7 +219,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { } func (cw *compressedWriter) writeComprPacketToNetwork(data []byte, uncomprLength int) error { - data = append([]byte{0, 0, 0, 0, 0, 0, 0}, data...) + data = append(cw.header, data...) comprLength := len(data) - 7 From d0ea1a418dc8ab516f245ad96b63157e852d8c9d Mon Sep 17 00:00:00 2001 From: Brigitte Lamarche Date: Fri, 18 Aug 2017 16:24:53 -0400 Subject: [PATCH 04/33] reading working --- compress.go | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/compress.go b/compress.go index 17bf1560b..2ccd8970c 100644 --- a/compress.go +++ b/compress.go @@ -10,6 +10,10 @@ const ( minCompressLength = 50 ) +var ( + blankHeader = []byte{0, 0, 0, 0, 0, 0, 0} +) + type packetReader interface { readNext(need int) ([]byte, error) } @@ -24,7 +28,7 @@ type compressedReader struct { type compressedWriter struct { connWriter io.Writer mc *mysqlConn - header []byte + zw *zlib.Writer } func NewCompressedReader(buf packetReader, mc *mysqlConn) *compressedReader { @@ -39,7 +43,7 @@ func NewCompressedWriter(connWriter io.Writer, mc *mysqlConn) *compressedWriter return &compressedWriter{ connWriter: connWriter, mc: mc, - header: []byte{0, 0, 0, 0, 0, 0, 0}, + zw: zlib.NewWriter(new(bytes.Buffer)), } } @@ -89,18 +93,14 @@ func (cr *compressedReader) uncompressPacket() error { // write comprData to a bytes.buffer, then read it using zlib into data br := bytes.NewReader(comprData) - resetter, ok := cr.zr.(zlib.Resetter) - - if ok { - err := resetter.Reset(br, []byte{}) - if err != nil { - return err - } - } else { + if cr.zr == nil { cr.zr, err = zlib.NewReader(br) - if err != nil { - return err - } + } else { + err = cr.zr.(zlib.Resetter).Reset(br, nil) + } + + if err != nil { + return err } defer cr.zr.Close() @@ -219,7 +219,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { } func (cw *compressedWriter) writeComprPacketToNetwork(data []byte, uncomprLength int) error { - data = append(cw.header, data...) + data = append(blankHeader, data...) comprLength := len(data) - 7 From 477c9f844736475945470522395fe6811cca52bf Mon Sep 17 00:00:00 2001 From: Brigitte Lamarche Date: Fri, 18 Aug 2017 16:36:28 -0400 Subject: [PATCH 05/33] writerly changes --- compress.go | 64 ++++++++++++++++++++++++++--------------------------- 1 file changed, 31 insertions(+), 33 deletions(-) diff --git a/compress.go b/compress.go index 2ccd8970c..96f431139 100644 --- a/compress.go +++ b/compress.go @@ -144,6 +144,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { if len(data) == 0 { return 0, nil } + totalBytes := len(data) length := len(data) - 4 @@ -151,27 +152,26 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { maxPayloadLength := maxPacketSize - 4 for length >= maxPayloadLength { - // cut off a slice of size max payload length - dataSmall := data[:maxPayloadLength] - lenSmall := len(dataSmall) - - var b bytes.Buffer - writer := zlib.NewWriter(&b) - _, err := writer.Write(dataSmall) - writer.Close() + payload := data[:maxPayloadLength] + payloadLen := len(payload) + + bytesBuf := &bytes.Buffer{} + bytesBuf.Write(blankHeader) + cw.zw.Reset(bytesBuf) + _, err := cw.zw.Write(payload) if err != nil { return 0, err } + cw.zw.Close() // if compression expands the payload, do not compress - useData := b.Bytes() - - if len(useData) > len(dataSmall) { - useData = dataSmall - lenSmall = 0 + compressedPayload := bytesBuf.Bytes() + if len(compressedPayload) > maxPayloadLength { + compressedPayload = append(blankHeader, payload...) + payloadLen = 0 } - err = cw.writeComprPacketToNetwork(useData, lenSmall) + err = cw.writeToNetwork(compressedPayload, payloadLen) if err != nil { return 0, err } @@ -180,46 +180,44 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { data = data[maxPayloadLength:] } - lenSmall := len(data) + payloadLen := len(data) // do not attempt compression if packet is too small - if lenSmall < minCompressLength { - err := cw.writeComprPacketToNetwork(data, 0) + if payloadLen < minCompressLength { + err := cw.writeToNetwork(append(blankHeader, data...), 0) if err != nil { return 0, err } - return totalBytes, nil } - var b bytes.Buffer - writer := zlib.NewWriter(&b) - - _, err := writer.Write(data) - writer.Close() - + bytesBuf := &bytes.Buffer{} + bytesBuf.Write(blankHeader) + cw.zw.Reset(bytesBuf) + _, err := cw.zw.Write(data) if err != nil { return 0, err } + cw.zw.Close() - // if compression expands the payload, do not compress - useData := b.Bytes() + compressedPayload := bytesBuf.Bytes() - if len(useData) > len(data) { - useData = data - lenSmall = 0 + if len(compressedPayload) > len(data) { + compressedPayload = append(blankHeader, data...) + payloadLen = 0 } - err = cw.writeComprPacketToNetwork(useData, lenSmall) - + // add header and send over the wire + err = cw.writeToNetwork(compressedPayload, payloadLen) if err != nil { return 0, err } + return totalBytes, nil + } -func (cw *compressedWriter) writeComprPacketToNetwork(data []byte, uncomprLength int) error { - data = append(blankHeader, data...) +func (cw *compressedWriter) writeToNetwork(data []byte, uncomprLength int) error { comprLength := len(data) - 7 From 996ed2d17131da2ef41e46fbdbd34186d47cb700 Mon Sep 17 00:00:00 2001 From: Brigitte Lamarche Date: Sun, 8 Oct 2017 14:00:11 -0400 Subject: [PATCH 06/33] PR 649: adding compression (second code review) --- compress.go | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/compress.go b/compress.go index 96f431139..3425c72fa 100644 --- a/compress.go +++ b/compress.go @@ -10,9 +10,6 @@ const ( minCompressLength = 50 ) -var ( - blankHeader = []byte{0, 0, 0, 0, 0, 0, 0} -) type packetReader interface { readNext(need int) ([]byte, error) @@ -146,17 +143,16 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { } totalBytes := len(data) - length := len(data) - 4 - maxPayloadLength := maxPacketSize - 4 + blankHeader := make([]byte, 7) for length >= maxPayloadLength { payload := data[:maxPayloadLength] payloadLen := len(payload) bytesBuf := &bytes.Buffer{} - bytesBuf.Write(blankHeader) + bytesBuf.Write(blankHeader) cw.zw.Reset(bytesBuf) _, err := cw.zw.Write(payload) if err != nil { @@ -167,7 +163,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { // if compression expands the payload, do not compress compressedPayload := bytesBuf.Bytes() if len(compressedPayload) > maxPayloadLength { - compressedPayload = append(blankHeader, payload...) + compressedPayload = append(blankHeader, payload...) payloadLen = 0 } @@ -184,7 +180,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { // do not attempt compression if packet is too small if payloadLen < minCompressLength { - err := cw.writeToNetwork(append(blankHeader, data...), 0) + err := cw.writeToNetwork(append(blankHeader, data...), 0) if err != nil { return 0, err } @@ -203,7 +199,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { compressedPayload := bytesBuf.Bytes() if len(compressedPayload) > len(data) { - compressedPayload = append(blankHeader, data...) + compressedPayload = append(blankHeader, data...) payloadLen = 0 } From f74faedaa752df4ed5f57082f67fa80b60b0f62e Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Thu, 12 Oct 2017 10:31:18 +0200 Subject: [PATCH 07/33] do not query max_allowed_packet by default (#680) --- README.md | 4 ++-- const.go | 3 ++- driver_test.go | 2 +- dsn.go | 3 ++- dsn_test.go | 50 ++++++++++++++++++++++++-------------------------- 5 files changed, 31 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index b5882e6c8..6a306bb30 100644 --- a/README.md +++ b/README.md @@ -232,10 +232,10 @@ Please keep in mind, that param values must be [url.QueryEscape](https://golang. ##### `maxAllowedPacket` ``` Type: decimal number -Default: 0 +Default: 4194304 ``` -Max packet size allowed in bytes. Use `maxAllowedPacket=0` to automatically fetch the `max_allowed_packet` variable from server. +Max packet size allowed in bytes. The default value is 4 MiB and should be adjusted to match the server settings. `maxAllowedPacket=0` can be used to automatically fetch the `max_allowed_packet` variable from server *on every connection*. ##### `multiStatements` diff --git a/const.go b/const.go index 88cfff3fd..2570b23fe 100644 --- a/const.go +++ b/const.go @@ -9,7 +9,8 @@ package mysql const ( - minProtocolVersion byte = 10 + defaultMaxAllowedPacket = 4 << 20 // 4 MiB + minProtocolVersion = 10 maxPacketSize = 1<<24 - 1 timeFormat = "2006-01-02 15:04:05.999999" ) diff --git a/driver_test.go b/driver_test.go index bc0386a09..27b067dff 100644 --- a/driver_test.go +++ b/driver_test.go @@ -964,7 +964,7 @@ func TestUint64(t *testing.T) { } func TestLongData(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTests(t, dsn+"&maxAllowedPacket=0", func(dbt *DBTest) { var maxAllowedPacketSize int err := dbt.db.QueryRow("select @@max_allowed_packet").Scan(&maxAllowedPacketSize) if err != nil { diff --git a/dsn.go b/dsn.go index 5ebd1d9f7..e3ead3ce5 100644 --- a/dsn.go +++ b/dsn.go @@ -65,6 +65,7 @@ func NewConfig() *Config { return &Config{ Collation: defaultCollation, Loc: time.UTC, + MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, } } @@ -275,7 +276,7 @@ func (cfg *Config) FormatDSN() string { buf.WriteString(cfg.WriteTimeout.String()) } - if cfg.MaxAllowedPacket > 0 { + if cfg.MaxAllowedPacket != defaultMaxAllowedPacket { if hasParam { buf.WriteString("&maxAllowedPacket=") } else { diff --git a/dsn_test.go b/dsn_test.go index af28da351..07b223f6b 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -22,55 +22,55 @@ var testDSNs = []struct { out *Config }{{ "username:password@protocol(address)/dbname?param=value", - &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true}, + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, }, { "username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", - &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true, ColumnsWithAlias: true}, + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, ColumnsWithAlias: true}, }, { "username:password@protocol(address)/dbname?param=value&columnsWithAlias=true&multiStatements=true", - &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true, ColumnsWithAlias: true, MultiStatements: true}, + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, ColumnsWithAlias: true, MultiStatements: true}, }, { "user@unix(/path/to/socket)/dbname?charset=utf8", - &Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true}, + &Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, }, { "user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true, TLSConfig: "true"}, + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "true"}, }, { "user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true, TLSConfig: "skip-verify"}, + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "skip-verify"}, }, { "user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216", &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, ClientFoundRows: true, MaxAllowedPacket: 16777216}, }, { - "user:password@/dbname?allowNativePasswords=false", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: false}, + "user:password@/dbname?allowNativePasswords=false&maxAllowedPacket=0", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false}, }, { "user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", - &Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.Local, AllowNativePasswords: true}, + &Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, }, { "/dbname", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, }, { "@/", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, }, { "/", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, }, { "", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, }, { "user:p@/ssword@/", - &Config{User: "user", Passwd: "p@/ssword", Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true}, + &Config{User: "user", Passwd: "p@/ssword", Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, }, { "unix/?arg=%2Fsome%2Fpath.ext", - &Config{Net: "unix", Addr: "/tmp/mysql.sock", Params: map[string]string{"arg": "/some/path.ext"}, Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true}, + &Config{Net: "unix", Addr: "/tmp/mysql.sock", Params: map[string]string{"arg": "/some/path.ext"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, }, { "tcp(127.0.0.1)/dbname", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, }, { "tcp(de:ad:be:ef::ca:fe)/dbname", - &Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, AllowNativePasswords: true}, + &Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, }, } @@ -233,16 +233,14 @@ func TestDSNUnsafeCollation(t *testing.T) { func TestParamsAreSorted(t *testing.T) { expected := "/dbname?interpolateParams=true&foobar=baz&quux=loo" - dsn := &Config{ - DBName: "dbname", - InterpolateParams: true, - AllowNativePasswords: true, - Params: map[string]string{ - "quux": "loo", - "foobar": "baz", - }, + cfg := NewConfig() + cfg.DBName = "dbname" + cfg.InterpolateParams = true + cfg.Params = map[string]string{ + "quux": "loo", + "foobar": "baz", } - actual := dsn.FormatDSN() + actual := cfg.FormatDSN() if actual != expected { t.Errorf("generic Config.Params were not sorted: want %#v, got %#v", expected, actual) } From b3a093e1ccb62923917ed9c6dc2374ae041ff330 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Mon, 16 Oct 2017 22:44:03 +0200 Subject: [PATCH 08/33] packets: do not call function on nulled value (#678) --- packets.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packets.go b/packets.go index ff6b1394d..40b7f1115 100644 --- a/packets.go +++ b/packets.go @@ -1155,10 +1155,11 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { } return io.EOF } + mc := rows.mc rows.mc = nil // Error otherwise - return rows.mc.handleErrorPacket(data) + return mc.handleErrorPacket(data) } // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes] From 5eaa5ff08a12e4bf29321fdcc92afd1c4d21e3f7 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Tue, 17 Oct 2017 14:45:56 +0200 Subject: [PATCH 09/33] ColumnType interfaces (#667) * rows: implement driver.RowsColumnTypeScanType Implementation for time.Time not yet complete! * rows: implement driver.RowsColumnTypeNullable * rows: move fields related code to fields.go * fields: use NullTime for nullable datetime fields * fields: make fieldType its own type * rows: implement driver.RowsColumnTypeDatabaseTypeName * fields: fix copyright year * rows: compile time interface implementation checks * rows: move tests to versioned driver test files * rows: cache parseTime in resultSet instead of mysqlConn * fields: fix string and time types * rows: implement ColumnTypeLength * rows: implement ColumnTypePrecisionScale * rows: fix ColumnTypeNullable * rows: ColumnTypes tests part1 * rows: use keyed composite literals in ColumnTypes tests * rows: ColumnTypes tests part2 * rows: always use NullTime as ScanType for datetime * rows: avoid errors through rounding of time values * rows: remove parseTime cache * fields: remove unused scanTypes * rows: fix ColumnTypePrecisionScale implementation * fields: sort types alphabetical * rows: remove ColumnTypeLength implementation for now * README: document ColumnType Support --- README.md | 14 ++- connection.go | 1 + const.go | 6 +- driver_go18_test.go | 220 ++++++++++++++++++++++++++++++++++++++++++++ driver_test.go | 6 ++ fields.go | 140 ++++++++++++++++++++++++++++ packets.go | 23 +++-- rows.go | 51 ++++++++-- 8 files changed, 436 insertions(+), 25 deletions(-) create mode 100644 fields.go diff --git a/README.md b/README.md index 6a306bb30..f6eb0b0d2 100644 --- a/README.md +++ b/README.md @@ -16,10 +16,11 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac * [Parameters](#parameters) * [Examples](#examples) * [Connection pool and timeouts](#connection-pool-and-timeouts) + * [context.Context Support](#contextcontext-support) + * [ColumnType Support](#columntype-support) * [LOAD DATA LOCAL INFILE support](#load-data-local-infile-support) * [time.Time support](#timetime-support) * [Unicode support](#unicode-support) - * [context.Context Support](#contextcontext-support) * [Testing / Development](#testing--development) * [License](#license) @@ -400,6 +401,13 @@ user:password@/ ### Connection pool and timeouts The connection pool is managed by Go's database/sql package. For details on how to configure the size of the pool and how long connections stay in the pool see `*DB.SetMaxOpenConns`, `*DB.SetMaxIdleConns`, and `*DB.SetConnMaxLifetime` in the [database/sql documentation](https://golang.org/pkg/database/sql/). The read, write, and dial timeouts for each individual connection are configured with the DSN parameters [`readTimeout`](#readtimeout), [`writeTimeout`](#writetimeout), and [`timeout`](#timeout), respectively. +## `ColumnType` Support +This driver supports the [`ColumnType` interface](https://golang.org/pkg/database/sql/#ColumnType) introduced in Go 1.8, with the exception of [`ColumnType.Length()`](https://golang.org/pkg/database/sql/#ColumnType.Length), which is currently not supported. + +## `context.Context` Support +Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts. +See [context support in the database/sql package](https://golang.org/doc/go1.8#database_sql) for more details. + ### `LOAD DATA LOCAL INFILE` support For this feature you need direct access to the package. Therefore you must change the import path (no `_`): @@ -433,10 +441,6 @@ Version 1.0 of the driver recommended adding `&charset=utf8` (alias for `SET NAM See http://dev.mysql.com/doc/refman/5.7/en/charset-unicode.html for more details on MySQL's Unicode support. -## `context.Context` Support -Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts. -See [context support in the database/sql package](https://golang.org/doc/go1.8#database_sql) for more details. - ## Testing / Development To run the driver tests you may need to adjust the configuration. See the [Testing Wiki-Page](https://github.com/go-sql-driver/mysql/wiki/Testing "Testing") for details. diff --git a/connection.go b/connection.go index b31d63d7e..3a30c46a9 100644 --- a/connection.go +++ b/connection.go @@ -406,6 +406,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) return nil, err } } + // Columns rows.rs.columns, err = mc.readColumns(resLen) return rows, err diff --git a/const.go b/const.go index 2570b23fe..4a19ca523 100644 --- a/const.go +++ b/const.go @@ -88,8 +88,10 @@ const ( ) // https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType +type fieldType byte + const ( - fieldTypeDecimal byte = iota + fieldTypeDecimal fieldType = iota fieldTypeTiny fieldTypeShort fieldTypeLong @@ -108,7 +110,7 @@ const ( fieldTypeBit ) const ( - fieldTypeJSON byte = iota + 0xf5 + fieldTypeJSON fieldType = iota + 0xf5 fieldTypeNewDecimal fieldTypeEnum fieldTypeSet diff --git a/driver_go18_test.go b/driver_go18_test.go index 4962838f2..953adeb8a 100644 --- a/driver_go18_test.go +++ b/driver_go18_test.go @@ -15,6 +15,7 @@ import ( "database/sql" "database/sql/driver" "fmt" + "math" "reflect" "testing" "time" @@ -35,6 +36,22 @@ var ( _ driver.StmtQueryContext = &mysqlStmt{} ) +// Ensure that all the driver interfaces are implemented +var ( + // _ driver.RowsColumnTypeLength = &binaryRows{} + // _ driver.RowsColumnTypeLength = &textRows{} + _ driver.RowsColumnTypeDatabaseTypeName = &binaryRows{} + _ driver.RowsColumnTypeDatabaseTypeName = &textRows{} + _ driver.RowsColumnTypeNullable = &binaryRows{} + _ driver.RowsColumnTypeNullable = &textRows{} + _ driver.RowsColumnTypePrecisionScale = &binaryRows{} + _ driver.RowsColumnTypePrecisionScale = &textRows{} + _ driver.RowsColumnTypeScanType = &binaryRows{} + _ driver.RowsColumnTypeScanType = &textRows{} + _ driver.RowsNextResultSet = &binaryRows{} + _ driver.RowsNextResultSet = &textRows{} +) + func TestMultiResultSet(t *testing.T) { type result struct { values [][]int @@ -558,3 +575,206 @@ func TestContextBeginReadOnly(t *testing.T) { } }) } + +func TestRowsColumnTypes(t *testing.T) { + niNULL := sql.NullInt64{Int64: 0, Valid: false} + ni0 := sql.NullInt64{Int64: 0, Valid: true} + ni1 := sql.NullInt64{Int64: 1, Valid: true} + ni42 := sql.NullInt64{Int64: 42, Valid: true} + nfNULL := sql.NullFloat64{Float64: 0.0, Valid: false} + nf0 := sql.NullFloat64{Float64: 0.0, Valid: true} + nf1337 := sql.NullFloat64{Float64: 13.37, Valid: true} + nt0 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 0, time.UTC), Valid: true} + nt1 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 100000000, time.UTC), Valid: true} + nt2 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 110000000, time.UTC), Valid: true} + nt6 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 111111000, time.UTC), Valid: true} + rbNULL := sql.RawBytes(nil) + rb0 := sql.RawBytes("0") + rb42 := sql.RawBytes("42") + rbTest := sql.RawBytes("Test") + + var columns = []struct { + name string + fieldType string // type used when creating table schema + databaseTypeName string // actual type used by MySQL + scanType reflect.Type + nullable bool + precision int64 // 0 if not ok + scale int64 + valuesIn [3]string + valuesOut [3]interface{} + }{ + {"boolnull", "BOOL", "TINYINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "true", "0"}, [3]interface{}{niNULL, ni1, ni0}}, + {"bool", "BOOL NOT NULL", "TINYINT", scanTypeInt8, false, 0, 0, [3]string{"1", "0", "FALSE"}, [3]interface{}{int8(1), int8(0), int8(0)}}, + {"intnull", "INTEGER", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, + {"smallint", "SMALLINT NOT NULL", "SMALLINT", scanTypeInt16, false, 0, 0, [3]string{"0", "-32768", "32767"}, [3]interface{}{int16(0), int16(-32768), int16(32767)}}, + {"smallintnull", "SMALLINT", "SMALLINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, + {"int3null", "INT(3)", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, + {"int7", "INT(7) NOT NULL", "INT", scanTypeInt32, false, 0, 0, [3]string{"0", "-1337", "42"}, [3]interface{}{int32(0), int32(-1337), int32(42)}}, + {"bigint", "BIGINT NOT NULL", "BIGINT", scanTypeInt64, false, 0, 0, [3]string{"0", "65535", "-42"}, [3]interface{}{int64(0), int64(65535), int64(-42)}}, + {"bigintnull", "BIGINT", "BIGINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "1", "42"}, [3]interface{}{niNULL, ni1, ni42}}, + {"tinyuint", "TINYINT UNSIGNED NOT NULL", "TINYINT", scanTypeUint8, false, 0, 0, [3]string{"0", "255", "42"}, [3]interface{}{uint8(0), uint8(255), uint8(42)}}, + {"smalluint", "SMALLINT UNSIGNED NOT NULL", "SMALLINT", scanTypeUint16, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint16(0), uint16(65535), uint16(42)}}, + {"biguint", "BIGINT UNSIGNED NOT NULL", "BIGINT", scanTypeUint64, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint64(0), uint64(65535), uint64(42)}}, + {"uint13", "INT(13) UNSIGNED NOT NULL", "INT", scanTypeUint32, false, 0, 0, [3]string{"0", "1337", "42"}, [3]interface{}{uint32(0), uint32(1337), uint32(42)}}, + {"float", "FLOAT NOT NULL", "FLOAT", scanTypeFloat32, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]interface{}{float32(0), float32(42), float32(13.37)}}, + {"floatnull", "FLOAT", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, + {"float74null", "FLOAT(7,4)", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, 4, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, + {"double", "DOUBLE NOT NULL", "DOUBLE", scanTypeFloat64, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]interface{}{float64(0), float64(42), float64(13.37)}}, + {"doublenull", "DOUBLE", "DOUBLE", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, + {"decimal1", "DECIMAL(10,6) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 10, 6, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{sql.RawBytes("0.000000"), sql.RawBytes("13.370000"), sql.RawBytes("1234.123456")}}, + {"decimal1null", "DECIMAL(10,6)", "DECIMAL", scanTypeRawBytes, true, 10, 6, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{sql.RawBytes("0.000000"), rbNULL, sql.RawBytes("1234.123456")}}, + {"decimal2", "DECIMAL(8,4) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 8, 4, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{sql.RawBytes("0.0000"), sql.RawBytes("13.3700"), sql.RawBytes("1234.1235")}}, + {"decimal2null", "DECIMAL(8,4)", "DECIMAL", scanTypeRawBytes, true, 8, 4, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{sql.RawBytes("0.0000"), rbNULL, sql.RawBytes("1234.1235")}}, + {"decimal3", "DECIMAL(5,0) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 5, 0, [3]string{"0", "13.37", "-12345.123456"}, [3]interface{}{rb0, sql.RawBytes("13"), sql.RawBytes("-12345")}}, + {"decimal3null", "DECIMAL(5,0)", "DECIMAL", scanTypeRawBytes, true, 5, 0, [3]string{"0", "NULL", "-12345.123456"}, [3]interface{}{rb0, rbNULL, sql.RawBytes("-12345")}}, + {"char25null", "CHAR(25)", "CHAR", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"varchar42", "VARCHAR(42) NOT NULL", "VARCHAR", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"textnull", "TEXT", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"longtext", "LONGTEXT NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"datetime", "DATETIME", "DATETIME", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt0, nt0}}, + {"datetime2", "DATETIME(2)", "DATETIME", scanTypeNullTime, true, 2, 2, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt2}}, + {"datetime6", "DATETIME(6)", "DATETIME", scanTypeNullTime, true, 6, 6, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt6}}, + } + + schema := "" + values1 := "" + values2 := "" + values3 := "" + for _, column := range columns { + schema += fmt.Sprintf("`%s` %s, ", column.name, column.fieldType) + values1 += column.valuesIn[0] + ", " + values2 += column.valuesIn[1] + ", " + values3 += column.valuesIn[2] + ", " + } + schema = schema[:len(schema)-2] + values1 = values1[:len(values1)-2] + values2 = values2[:len(values2)-2] + values3 = values3[:len(values3)-2] + + dsns := []string{ + dsn + "&parseTime=true", + dsn + "&parseTime=false", + } + for _, testdsn := range dsns { + runTests(t, testdsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (" + schema + ")") + dbt.mustExec("INSERT INTO test VALUES (" + values1 + "), (" + values2 + "), (" + values3 + ")") + + rows, err := dbt.db.Query("SELECT * FROM test") + if err != nil { + t.Fatalf("Query: %v", err) + } + + tt, err := rows.ColumnTypes() + if err != nil { + t.Fatalf("ColumnTypes: %v", err) + } + + if len(tt) != len(columns) { + t.Fatalf("unexpected number of columns: expected %d, got %d", len(columns), len(tt)) + } + + types := make([]reflect.Type, len(tt)) + for i, tp := range tt { + column := columns[i] + + // Name + name := tp.Name() + if name != column.name { + t.Errorf("column name mismatch %s != %s", name, column.name) + continue + } + + // DatabaseTypeName + databaseTypeName := tp.DatabaseTypeName() + if databaseTypeName != column.databaseTypeName { + t.Errorf("databasetypename name mismatch for column %q: %s != %s", name, databaseTypeName, column.databaseTypeName) + continue + } + + // ScanType + scanType := tp.ScanType() + if scanType != column.scanType { + if scanType == nil { + t.Errorf("scantype is null for column %q", name) + } else { + t.Errorf("scantype mismatch for column %q: %s != %s", name, scanType.Name(), column.scanType.Name()) + } + continue + } + types[i] = scanType + + // Nullable + nullable, ok := tp.Nullable() + if !ok { + t.Errorf("nullable not ok %q", name) + continue + } + if nullable != column.nullable { + t.Errorf("nullable mismatch for column %q: %t != %t", name, nullable, column.nullable) + } + + // Length + // length, ok := tp.Length() + // if length != column.length { + // if !ok { + // t.Errorf("length not ok for column %q", name) + // } else { + // t.Errorf("length mismatch for column %q: %d != %d", name, length, column.length) + // } + // continue + // } + + // Precision and Scale + precision, scale, ok := tp.DecimalSize() + if precision != column.precision { + if !ok { + t.Errorf("precision not ok for column %q", name) + } else { + t.Errorf("precision mismatch for column %q: %d != %d", name, precision, column.precision) + } + continue + } + if scale != column.scale { + if !ok { + t.Errorf("scale not ok for column %q", name) + } else { + t.Errorf("scale mismatch for column %q: %d != %d", name, scale, column.scale) + } + continue + } + } + + values := make([]interface{}, len(tt)) + for i := range values { + values[i] = reflect.New(types[i]).Interface() + } + i := 0 + for rows.Next() { + err = rows.Scan(values...) + if err != nil { + t.Fatalf("failed to scan values in %v", err) + } + for j := range values { + value := reflect.ValueOf(values[j]).Elem().Interface() + if !reflect.DeepEqual(value, columns[j].valuesOut[i]) { + if columns[j].scanType == scanTypeRawBytes { + t.Errorf("row %d, column %d: %v != %v", i, j, string(value.(sql.RawBytes)), string(columns[j].valuesOut[i].(sql.RawBytes))) + } else { + t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i]) + } + } + } + i++ + } + if i != 3 { + t.Errorf("expected 3 rows, got %d", i) + } + + if err := rows.Close(); err != nil { + t.Errorf("error closing rows: %s", err) + } + }) + } +} diff --git a/driver_test.go b/driver_test.go index 27b067dff..53e70dab7 100644 --- a/driver_test.go +++ b/driver_test.go @@ -27,6 +27,12 @@ import ( "time" ) +// Ensure that all the driver interfaces are implemented +var ( + _ driver.Rows = &binaryRows{} + _ driver.Rows = &textRows{} +) + var ( user string pass string diff --git a/fields.go b/fields.go new file mode 100644 index 000000000..cded986d2 --- /dev/null +++ b/fields.go @@ -0,0 +1,140 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "database/sql" + "reflect" +) + +var typeDatabaseName = map[fieldType]string{ + fieldTypeBit: "BIT", + fieldTypeBLOB: "BLOB", + fieldTypeDate: "DATE", + fieldTypeDateTime: "DATETIME", + fieldTypeDecimal: "DECIMAL", + fieldTypeDouble: "DOUBLE", + fieldTypeEnum: "ENUM", + fieldTypeFloat: "FLOAT", + fieldTypeGeometry: "GEOMETRY", + fieldTypeInt24: "MEDIUMINT", + fieldTypeJSON: "JSON", + fieldTypeLong: "INT", + fieldTypeLongBLOB: "LONGBLOB", + fieldTypeLongLong: "BIGINT", + fieldTypeMediumBLOB: "MEDIUMBLOB", + fieldTypeNewDate: "DATE", + fieldTypeNewDecimal: "DECIMAL", + fieldTypeNULL: "NULL", + fieldTypeSet: "SET", + fieldTypeShort: "SMALLINT", + fieldTypeString: "CHAR", + fieldTypeTime: "TIME", + fieldTypeTimestamp: "TIMESTAMP", + fieldTypeTiny: "TINYINT", + fieldTypeTinyBLOB: "TINYBLOB", + fieldTypeVarChar: "VARCHAR", + fieldTypeVarString: "VARCHAR", + fieldTypeYear: "YEAR", +} + +var ( + scanTypeFloat32 = reflect.TypeOf(float32(0)) + scanTypeFloat64 = reflect.TypeOf(float64(0)) + scanTypeInt8 = reflect.TypeOf(int8(0)) + scanTypeInt16 = reflect.TypeOf(int16(0)) + scanTypeInt32 = reflect.TypeOf(int32(0)) + scanTypeInt64 = reflect.TypeOf(int64(0)) + scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) + scanTypeNullInt = reflect.TypeOf(sql.NullInt64{}) + scanTypeNullTime = reflect.TypeOf(NullTime{}) + scanTypeUint8 = reflect.TypeOf(uint8(0)) + scanTypeUint16 = reflect.TypeOf(uint16(0)) + scanTypeUint32 = reflect.TypeOf(uint32(0)) + scanTypeUint64 = reflect.TypeOf(uint64(0)) + scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{}) + scanTypeUnknown = reflect.TypeOf(new(interface{})) +) + +type mysqlField struct { + tableName string + name string + length uint32 + flags fieldFlag + fieldType fieldType + decimals byte +} + +func (mf *mysqlField) scanType() reflect.Type { + switch mf.fieldType { + case fieldTypeTiny: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint8 + } + return scanTypeInt8 + } + return scanTypeNullInt + + case fieldTypeShort, fieldTypeYear: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint16 + } + return scanTypeInt16 + } + return scanTypeNullInt + + case fieldTypeInt24, fieldTypeLong: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint32 + } + return scanTypeInt32 + } + return scanTypeNullInt + + case fieldTypeLongLong: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint64 + } + return scanTypeInt64 + } + return scanTypeNullInt + + case fieldTypeFloat: + if mf.flags&flagNotNULL != 0 { + return scanTypeFloat32 + } + return scanTypeNullFloat + + case fieldTypeDouble: + if mf.flags&flagNotNULL != 0 { + return scanTypeFloat64 + } + return scanTypeNullFloat + + case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, + fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, + fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, + fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON, + fieldTypeTime: + return scanTypeRawBytes + + case fieldTypeDate, fieldTypeNewDate, + fieldTypeTimestamp, fieldTypeDateTime: + // NullTime is always returned for more consistent behavior as it can + // handle both cases of parseTime regardless if the field is nullable. + return scanTypeNullTime + + default: + return scanTypeUnknown + } +} diff --git a/packets.go b/packets.go index 40b7f1115..97afd0abc 100644 --- a/packets.go +++ b/packets.go @@ -708,11 +708,14 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { // Filler [uint8] // Charset [charset, collation uint8] + pos += n + 1 + 2 + // Length [uint32] - pos += n + 1 + 2 + 4 + columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4]) + pos += 4 // Field type [uint8] - columns[i].fieldType = data[pos] + columns[i].fieldType = fieldType(data[pos]) pos++ // Flags [uint16] @@ -992,7 +995,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // build NULL-bitmap if arg == nil { nullMask[i/8] |= 1 << (uint(i) & 7) - paramTypes[i+i] = fieldTypeNULL + paramTypes[i+i] = byte(fieldTypeNULL) paramTypes[i+i+1] = 0x00 continue } @@ -1000,7 +1003,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // cache types and values switch v := arg.(type) { case int64: - paramTypes[i+i] = fieldTypeLongLong + paramTypes[i+i] = byte(fieldTypeLongLong) paramTypes[i+i+1] = 0x00 if cap(paramValues)-len(paramValues)-8 >= 0 { @@ -1016,7 +1019,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } case float64: - paramTypes[i+i] = fieldTypeDouble + paramTypes[i+i] = byte(fieldTypeDouble) paramTypes[i+i+1] = 0x00 if cap(paramValues)-len(paramValues)-8 >= 0 { @@ -1032,7 +1035,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } case bool: - paramTypes[i+i] = fieldTypeTiny + paramTypes[i+i] = byte(fieldTypeTiny) paramTypes[i+i+1] = 0x00 if v { @@ -1044,7 +1047,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { case []byte: // Common case (non-nil value) first if v != nil { - paramTypes[i+i] = fieldTypeString + paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { @@ -1062,11 +1065,11 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // Handle []byte(nil) as a NULL value nullMask[i/8] |= 1 << (uint(i) & 7) - paramTypes[i+i] = fieldTypeNULL + paramTypes[i+i] = byte(fieldTypeNULL) paramTypes[i+i+1] = 0x00 case string: - paramTypes[i+i] = fieldTypeString + paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { @@ -1081,7 +1084,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } case time.Time: - paramTypes[i+i] = fieldTypeString + paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 var a [64]byte diff --git a/rows.go b/rows.go index c7f5ee26c..18f41693e 100644 --- a/rows.go +++ b/rows.go @@ -11,16 +11,10 @@ package mysql import ( "database/sql/driver" "io" + "math" + "reflect" ) -type mysqlField struct { - tableName string - name string - flags fieldFlag - fieldType byte - decimals byte -} - type resultSet struct { columns []mysqlField columnNames []string @@ -65,6 +59,47 @@ func (rows *mysqlRows) Columns() []string { return columns } +func (rows *mysqlRows) ColumnTypeDatabaseTypeName(i int) string { + if name, ok := typeDatabaseName[rows.rs.columns[i].fieldType]; ok { + return name + } + return "" +} + +// func (rows *mysqlRows) ColumnTypeLength(i int) (length int64, ok bool) { +// return int64(rows.rs.columns[i].length), true +// } + +func (rows *mysqlRows) ColumnTypeNullable(i int) (nullable, ok bool) { + return rows.rs.columns[i].flags&flagNotNULL == 0, true +} + +func (rows *mysqlRows) ColumnTypePrecisionScale(i int) (int64, int64, bool) { + column := rows.rs.columns[i] + decimals := int64(column.decimals) + + switch column.fieldType { + case fieldTypeDecimal, fieldTypeNewDecimal: + if decimals > 0 { + return int64(column.length) - 2, decimals, true + } + return int64(column.length) - 1, decimals, true + case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeTime: + return decimals, decimals, true + case fieldTypeFloat, fieldTypeDouble: + if decimals == 0x1f { + return math.MaxInt64, math.MaxInt64, true + } + return math.MaxInt64, decimals, true + } + + return 0, 0, false +} + +func (rows *mysqlRows) ColumnTypeScanType(i int) reflect.Type { + return rows.rs.columns[i].scanType() +} + func (rows *mysqlRows) Close() (err error) { if f := rows.finish; f != nil { f() From ee460286f5d798ae53974561cc2e6827a78dd0ef Mon Sep 17 00:00:00 2001 From: Jeffrey Charles Date: Tue, 17 Oct 2017 13:10:23 -0400 Subject: [PATCH 10/33] Add Aurora errno to rejectReadOnly check (#634) AWS Aurora returns a 1290 after failing over requiring the connection to be closed and opened again to be able to perform writes. --- AUTHORS | 1 + README.md | 7 ++++++- packets.go | 3 ++- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/AUTHORS b/AUTHORS index 0a936b1f6..c98ef9dbd 100644 --- a/AUTHORS +++ b/AUTHORS @@ -35,6 +35,7 @@ INADA Naoki Jacek Szwec James Harr Jeff Hodges +Jeffrey Charles Jian Zhen Joshua Prunier Julien Lefevre diff --git a/README.md b/README.md index f6eb0b0d2..d24aaa0f0 100644 --- a/README.md +++ b/README.md @@ -279,7 +279,7 @@ Default: false ``` -`rejectreadOnly=true` causes the driver to reject read-only connections. This +`rejectReadOnly=true` causes the driver to reject read-only connections. This is for a possible race condition during an automatic failover, where the mysql client gets connected to a read-only replica after the failover. @@ -294,6 +294,11 @@ If you are not relying on read-only transactions to reject writes that aren't supposed to happen, setting this on some MySQL providers (such as AWS Aurora) is safer for failovers. +Note that ERROR 1290 can be returned for a `read-only` server and this option will +cause a retry for that error. However the same error number is used for some +other cases. You should ensure your application will never cause an ERROR 1290 +except for `read-only` mode when enabling this option. + ##### `timeout` diff --git a/packets.go b/packets.go index 97afd0abc..7bd2dd309 100644 --- a/packets.go +++ b/packets.go @@ -580,7 +580,8 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error { errno := binary.LittleEndian.Uint16(data[1:3]) // 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION - if errno == 1792 && mc.cfg.RejectReadOnly { + // 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover) + if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly { // Oops; we are connected to a read-only connection, and won't be able // to issue any write statements. Since RejectReadOnly is configured, // we throw away this connection hoping this one would have write From 93aed7307deff9a0a6dc64c80b7862c29bc67c8d Mon Sep 17 00:00:00 2001 From: Jeff Hodges Date: Tue, 17 Oct 2017 11:16:16 -0700 Subject: [PATCH 11/33] allow successful TravisCI runs in forks (#639) Most forks won't be in goveralls and so this command in travis.yml was, previously, failing and causing the build to fail. Now, it doesn't! --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index fa0b2c933..6369281e8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -90,4 +90,5 @@ script: - go test -v -covermode=count -coverprofile=coverage.out - go vet ./... - test -z "$(gofmt -d -s . | tee /dev/stderr)" +after_script: - $HOME/gopath/bin/goveralls -coverprofile=coverage.out -service=travis-ci From 4f10ee537a00db3ae88fa835f5e687a71639fe76 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Sun, 12 Nov 2017 22:30:34 +0100 Subject: [PATCH 12/33] Drop support for Go 1.6 and lower (#696) * Drop support for Go 1.6 and lower * Remove cloneTLSConfig for legacy Go versions --- .travis.yml | 2 -- README.md | 2 +- utils_legacy.go | 18 ------------------ 3 files changed, 1 insertion(+), 21 deletions(-) delete mode 100644 utils_legacy.go diff --git a/.travis.yml b/.travis.yml index 6369281e8..64b06a70c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,6 @@ sudo: false language: go go: - - 1.5 - - 1.6 - 1.7 - 1.8 - 1.9 diff --git a/README.md b/README.md index d24aaa0f0..299198d53 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac * Optional placeholder interpolation ## Requirements - * Go 1.5 or higher + * Go 1.7 or higher. We aim to support the 3 latest versions of Go. * MySQL (4.1+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+) --------------------------------------- diff --git a/utils_legacy.go b/utils_legacy.go deleted file mode 100644 index a03b10de2..000000000 --- a/utils_legacy.go +++ /dev/null @@ -1,18 +0,0 @@ -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package -// -// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. - -// +build !go1.7 - -package mysql - -import "crypto/tls" - -func cloneTLSConfig(c *tls.Config) *tls.Config { - clone := *c - return &clone -} From 59b0f90fea7003118587750d3590ebfb0cfc3d4f Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Tue, 14 Nov 2017 09:18:14 +0100 Subject: [PATCH 13/33] Make gofmt happy (#704) --- driver_test.go | 1 - dsn.go | 1 - utils.go | 2 +- 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/driver_test.go b/driver_test.go index 53e70dab7..f6965b191 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1375,7 +1375,6 @@ func TestTimezoneConversion(t *testing.T) { // Regression test for timezone handling tzTest := func(dbt *DBTest) { - // Create table dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)") diff --git a/dsn.go b/dsn.go index e3ead3ce5..418bc86b9 100644 --- a/dsn.go +++ b/dsn.go @@ -399,7 +399,6 @@ func parseDSNParams(cfg *Config, params string) (err error) { // cfg params switch value := param[1]; param[0] { - // Disable INFILE whitelist / enable all files case "allowAllFiles": var isBool bool diff --git a/utils.go b/utils.go index 82da83099..a92a4029b 100644 --- a/utils.go +++ b/utils.go @@ -566,8 +566,8 @@ func readLengthEncodedInteger(b []byte) (uint64, bool, int) { if len(b) == 0 { return 0, true, 1 } - switch b[0] { + switch b[0] { // 251: NULL case 0xfb: return 0, true, 1 From 3fbf53ab2f434b1d97ce2ce9395251793ad1d3f7 Mon Sep 17 00:00:00 2001 From: Daniel Montoya Date: Wed, 15 Nov 2017 16:37:47 -0600 Subject: [PATCH 14/33] Added support for custom string types in ConvertValue. (#623) * Added support for custom string types. * Add author name * Added license header * Added a newline to force a commit. * Remove newline. --- AUTHORS | 1 + statement.go | 2 ++ statement_test.go | 21 +++++++++++++++++++++ 3 files changed, 24 insertions(+) create mode 100644 statement_test.go diff --git a/AUTHORS b/AUTHORS index c98ef9dbd..780561a98 100644 --- a/AUTHORS +++ b/AUTHORS @@ -19,6 +19,7 @@ B Lamarche Bulat Gaifullin Carlos Nieto Chris Moos +Daniel Montoya Daniel Nichter Daniël van Eeden Dave Protasowski diff --git a/statement.go b/statement.go index ae223507f..628174b64 100644 --- a/statement.go +++ b/statement.go @@ -157,6 +157,8 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { return int64(u64), nil case reflect.Float32, reflect.Float64: return rv.Float(), nil + case reflect.String: + return rv.String(), nil } return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) } diff --git a/statement_test.go b/statement_test.go new file mode 100644 index 000000000..8de4a8b26 --- /dev/null +++ b/statement_test.go @@ -0,0 +1,21 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import "testing" + +type customString string + +func TestConvertValueCustomTypes(t *testing.T) { + var cstr customString = "string" + c := converter{} + if _, err := c.ConvertValue(cstr); err != nil { + t.Errorf("custom string type should be valid") + } +} From f9c6a2cea1651d4e197b8034fafa768bbd44223f Mon Sep 17 00:00:00 2001 From: Justin Li Date: Thu, 16 Nov 2017 02:25:03 -0500 Subject: [PATCH 15/33] Implement NamedValueChecker for mysqlConn (#690) * Also add conversions for additional types in ConvertValue ref https://github.com/golang/go/commit/d7c0de98a96893e5608358f7578c85be7ba12b25 --- AUTHORS | 1 + connection_go18.go | 5 ++ connection_go18_test.go | 30 ++++++++++ statement.go | 8 +++ statement_test.go | 119 +++++++++++++++++++++++++++++++++++++--- 5 files changed, 156 insertions(+), 7 deletions(-) create mode 100644 connection_go18_test.go diff --git a/AUTHORS b/AUTHORS index 780561a98..95d14a076 100644 --- a/AUTHORS +++ b/AUTHORS @@ -41,6 +41,7 @@ Jian Zhen Joshua Prunier Julien Lefevre Julien Schmidt +Justin Li Justin Nuß Kamil Dziedzic Kevin Malachowski diff --git a/connection_go18.go b/connection_go18.go index 48a9cca64..1306b70b7 100644 --- a/connection_go18.go +++ b/connection_go18.go @@ -195,3 +195,8 @@ func (mc *mysqlConn) startWatcher() { } }() } + +func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { + nv.Value, err = converter{}.ConvertValue(nv.Value) + return +} diff --git a/connection_go18_test.go b/connection_go18_test.go new file mode 100644 index 000000000..2719ab3b7 --- /dev/null +++ b/connection_go18_test.go @@ -0,0 +1,30 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build go1.8 + +package mysql + +import ( + "database/sql/driver" + "testing" +) + +func TestCheckNamedValue(t *testing.T) { + value := driver.NamedValue{Value: ^uint64(0)} + x := &mysqlConn{} + err := x.CheckNamedValue(&value) + + if err != nil { + t.Fatal("uint64 high-bit not convertible", err) + } + + if value.Value != "18446744073709551615" { + t.Fatalf("uint64 high-bit not converted, got %#v %T", value.Value, value.Value) + } +} diff --git a/statement.go b/statement.go index 628174b64..4870a307c 100644 --- a/statement.go +++ b/statement.go @@ -157,6 +157,14 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { return int64(u64), nil case reflect.Float32, reflect.Float64: return rv.Float(), nil + case reflect.Bool: + return rv.Bool(), nil + case reflect.Slice: + ek := rv.Type().Elem().Kind() + if ek == reflect.Uint8 { + return rv.Bytes(), nil + } + return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek) case reflect.String: return rv.String(), nil } diff --git a/statement_test.go b/statement_test.go index 8de4a8b26..98a6c1933 100644 --- a/statement_test.go +++ b/statement_test.go @@ -8,14 +8,119 @@ package mysql -import "testing" +import ( + "bytes" + "testing" +) -type customString string +func TestConvertDerivedString(t *testing.T) { + type derived string -func TestConvertValueCustomTypes(t *testing.T) { - var cstr customString = "string" - c := converter{} - if _, err := c.ConvertValue(cstr); err != nil { - t.Errorf("custom string type should be valid") + output, err := converter{}.ConvertValue(derived("value")) + if err != nil { + t.Fatal("Derived string type not convertible", err) + } + + if output != "value" { + t.Fatalf("Derived string type not converted, got %#v %T", output, output) + } +} + +func TestConvertDerivedByteSlice(t *testing.T) { + type derived []uint8 + + output, err := converter{}.ConvertValue(derived("value")) + if err != nil { + t.Fatal("Byte slice not convertible", err) + } + + if bytes.Compare(output.([]byte), []byte("value")) != 0 { + t.Fatalf("Byte slice not converted, got %#v %T", output, output) + } +} + +func TestConvertDerivedUnsupportedSlice(t *testing.T) { + type derived []int + + _, err := converter{}.ConvertValue(derived{1}) + if err == nil || err.Error() != "unsupported type mysql.derived, a slice of int" { + t.Fatal("Unexpected error", err) + } +} + +func TestConvertDerivedBool(t *testing.T) { + type derived bool + + output, err := converter{}.ConvertValue(derived(true)) + if err != nil { + t.Fatal("Derived bool type not convertible", err) + } + + if output != true { + t.Fatalf("Derived bool type not converted, got %#v %T", output, output) + } +} + +func TestConvertPointer(t *testing.T) { + str := "value" + + output, err := converter{}.ConvertValue(&str) + if err != nil { + t.Fatal("Pointer type not convertible", err) + } + + if output != "value" { + t.Fatalf("Pointer type not converted, got %#v %T", output, output) + } +} + +func TestConvertSignedIntegers(t *testing.T) { + values := []interface{}{ + int8(-42), + int16(-42), + int32(-42), + int64(-42), + int(-42), + } + + for _, value := range values { + output, err := converter{}.ConvertValue(value) + if err != nil { + t.Fatalf("%T type not convertible %s", value, err) + } + + if output != int64(-42) { + t.Fatalf("%T type not converted, got %#v %T", value, output, output) + } + } +} + +func TestConvertUnsignedIntegers(t *testing.T) { + values := []interface{}{ + uint8(42), + uint16(42), + uint32(42), + uint64(42), + uint(42), + } + + for _, value := range values { + output, err := converter{}.ConvertValue(value) + if err != nil { + t.Fatalf("%T type not convertible %s", value, err) + } + + if output != int64(42) { + t.Fatalf("%T type not converted, got %#v %T", value, output, output) + } + } + + output, err := converter{}.ConvertValue(^uint64(0)) + if err != nil { + t.Fatal("uint64 high-bit not convertible", err) + } + + if output != "18446744073709551615" { + t.Fatalf("uint64 high-bit not converted, got %#v %T", output, output) } } From 6046bf014ffba46d3753cfc64d1a3c9656318d8f Mon Sep 17 00:00:00 2001 From: Dave Stubbs Date: Thu, 16 Nov 2017 16:10:24 +0000 Subject: [PATCH 16/33] Fix Valuers by returning driver.ErrSkip if couldn't convert type internally (#709) Fixes #708 --- AUTHORS | 1 + connection_go18.go | 6 +++++- driver_test.go | 30 ++++++++++++++++++++++++++++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/AUTHORS b/AUTHORS index 95d14a076..66a4ad202 100644 --- a/AUTHORS +++ b/AUTHORS @@ -72,6 +72,7 @@ Zhenye Xie # Organizations Barracuda Networks, Inc. +Counting Ltd. Google Inc. Keybase Inc. Pivotal Inc. diff --git a/connection_go18.go b/connection_go18.go index 1306b70b7..65cc63ef2 100644 --- a/connection_go18.go +++ b/connection_go18.go @@ -197,6 +197,10 @@ func (mc *mysqlConn) startWatcher() { } func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { - nv.Value, err = converter{}.ConvertValue(nv.Value) + value, err := converter{}.ConvertValue(nv.Value) + if err != nil { + return driver.ErrSkip + } + nv.Value = value return } diff --git a/driver_test.go b/driver_test.go index f6965b191..392e752a3 100644 --- a/driver_test.go +++ b/driver_test.go @@ -499,6 +499,36 @@ func TestString(t *testing.T) { }) } +type testValuer struct { + value string +} + +func (tv testValuer) Value() (driver.Value, error) { + return tv.value, nil +} + +func TestValuer(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + in := testValuer{"a_value"} + var out string + var rows *sql.Rows + + dbt.mustExec("CREATE TABLE test (value VARCHAR(255)) CHARACTER SET utf8") + dbt.mustExec("INSERT INTO test VALUES (?)", in) + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in.value != out { + dbt.Errorf("Valuer: %v != %s", in, out) + } + } else { + dbt.Errorf("Valuer: no data") + } + + dbt.mustExec("DROP TABLE IF EXISTS test") + }) +} + type timeTests struct { dbtype string tlayout string From 385673a27ccb40f4a14746623da6e849c80eb079 Mon Sep 17 00:00:00 2001 From: Linh Tran Tuan Date: Fri, 17 Nov 2017 14:23:23 +0700 Subject: [PATCH 17/33] statement: Fix conversion of Valuer (#710) Updates #709 Fixes #706 --- AUTHORS | 1 + connection_go18.go | 6 +----- driver_test.go | 49 ++++++++++++++++++++++++++++++++++++++++++++++ statement.go | 6 ++++++ 4 files changed, 57 insertions(+), 5 deletions(-) diff --git a/AUTHORS b/AUTHORS index 66a4ad202..3fc9ece3a 100644 --- a/AUTHORS +++ b/AUTHORS @@ -47,6 +47,7 @@ Kamil Dziedzic Kevin Malachowski Lennart Rudolph Leonardo YongUk Kim +Linh Tran Tuan Lion Yang Luca Looz Lucas Liu diff --git a/connection_go18.go b/connection_go18.go index 65cc63ef2..1306b70b7 100644 --- a/connection_go18.go +++ b/connection_go18.go @@ -197,10 +197,6 @@ func (mc *mysqlConn) startWatcher() { } func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { - value, err := converter{}.ConvertValue(nv.Value) - if err != nil { - return driver.ErrSkip - } - nv.Value = value + nv.Value, err = converter{}.ConvertValue(nv.Value) return } diff --git a/driver_test.go b/driver_test.go index 392e752a3..224a24c53 100644 --- a/driver_test.go +++ b/driver_test.go @@ -529,6 +529,55 @@ func TestValuer(t *testing.T) { }) } +type testValuerWithValidation struct { + value string +} + +func (tv testValuerWithValidation) Value() (driver.Value, error) { + if len(tv.value) == 0 { + return nil, fmt.Errorf("Invalid string valuer. Value must not be empty") + } + + return tv.value, nil +} + +func TestValuerWithValidation(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + in := testValuerWithValidation{"a_value"} + var out string + var rows *sql.Rows + + dbt.mustExec("CREATE TABLE testValuer (value VARCHAR(255)) CHARACTER SET utf8") + dbt.mustExec("INSERT INTO testValuer VALUES (?)", in) + + rows = dbt.mustQuery("SELECT value FROM testValuer") + defer rows.Close() + + if rows.Next() { + rows.Scan(&out) + if in.value != out { + dbt.Errorf("Valuer: %v != %s", in, out) + } + } else { + dbt.Errorf("Valuer: no data") + } + + if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", testValuerWithValidation{""}); err == nil { + dbt.Errorf("Failed to check valuer error") + } + + if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", nil); err != nil { + dbt.Errorf("Failed to check nil") + } + + if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", map[string]bool{}); err == nil { + dbt.Errorf("Failed to check not valuer") + } + + dbt.mustExec("DROP TABLE IF EXISTS testValuer") + }) +} + type timeTests struct { dbtype string tlayout string diff --git a/statement.go b/statement.go index 4870a307c..98e57bcd8 100644 --- a/statement.go +++ b/statement.go @@ -137,6 +137,12 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { return v, nil } + if v != nil { + if valuer, ok := v.(driver.Valuer); ok { + return valuer.Value() + } + } + rv := reflect.ValueOf(v) switch rv.Kind() { case reflect.Ptr: From 9031984e2b7bab392eb005c4c014ea66af892796 Mon Sep 17 00:00:00 2001 From: "Robert R. Russell" Date: Fri, 17 Nov 2017 05:51:24 -0600 Subject: [PATCH 18/33] Fixed imports for appengine/cloudsql (#700) * Fixed broken import for appengine/cloudsql appengine.go import path of appengine/cloudsql has changed to google.golang.org/appengine/cloudsql - Fixed. * Added my name to the AUTHORS --- AUTHORS | 1 + appengine.go | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/AUTHORS b/AUTHORS index 3fc9ece3a..9988284ef 100644 --- a/AUTHORS +++ b/AUTHORS @@ -61,6 +61,7 @@ Paul Bonser Peter Schultz Rebecca Chin Runrioter Wung +Robert Russell Shuode Li Soroush Pour Stan Putrya diff --git a/appengine.go b/appengine.go index 565614eef..be41f2ee6 100644 --- a/appengine.go +++ b/appengine.go @@ -11,7 +11,7 @@ package mysql import ( - "appengine/cloudsql" + "google.golang.org/appengine/cloudsql" ) func init() { From 6992fad9c49d4e7df3c8346949e44684fa826146 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Mon, 4 Dec 2017 09:43:26 +0900 Subject: [PATCH 19/33] Fix tls=true didn't work with host without port (#718) Fixes #717 --- dsn.go | 20 +++++++++----------- dsn_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 11 deletions(-) diff --git a/dsn.go b/dsn.go index 418bc86b9..fa50ad3c0 100644 --- a/dsn.go +++ b/dsn.go @@ -95,6 +95,15 @@ func (cfg *Config) normalize() error { cfg.Addr = ensureHavePort(cfg.Addr) } + if cfg.tls != nil { + if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify { + host, _, err := net.SplitHostPort(cfg.Addr) + if err == nil { + cfg.tls.ServerName = host + } + } + } + return nil } @@ -526,10 +535,6 @@ func parseDSNParams(cfg *Config, params string) (err error) { if boolValue { cfg.TLSConfig = "true" cfg.tls = &tls.Config{} - host, _, err := net.SplitHostPort(cfg.Addr) - if err == nil { - cfg.tls.ServerName = host - } } else { cfg.TLSConfig = "false" } @@ -543,13 +548,6 @@ func parseDSNParams(cfg *Config, params string) (err error) { } if tlsConfig := getTLSConfigClone(name); tlsConfig != nil { - if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify { - host, _, err := net.SplitHostPort(cfg.Addr) - if err == nil { - tlsConfig.ServerName = host - } - } - cfg.TLSConfig = name cfg.tls = tlsConfig } else { diff --git a/dsn_test.go b/dsn_test.go index 07b223f6b..7507d1201 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -177,6 +177,34 @@ func TestDSNWithCustomTLS(t *testing.T) { DeregisterTLSConfig("utils_test") } +func TestDSNTLSConfig(t *testing.T) { + expectedServerName := "example.com" + dsn := "tcp(example.com:1234)/?tls=true" + + cfg, err := ParseDSN(dsn) + if err != nil { + t.Error(err.Error()) + } + if cfg.tls == nil { + t.Error("cfg.tls should not be nil") + } + if cfg.tls.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName) + } + + dsn = "tcp(example.com)/?tls=true" + cfg, err = ParseDSN(dsn) + if err != nil { + t.Error(err.Error()) + } + if cfg.tls == nil { + t.Error("cfg.tls should not be nil") + } + if cfg.tls.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.tls.ServerName) + } +} + func TestDSNWithCustomTLSQueryEscape(t *testing.T) { const configKey = "&%!:" dsn := "User:password@tcp(localhost:5555)/dbname?tls=" + url.QueryEscape(configKey) From 386f84bcc4e8e23703cc5b822191df934c85fd73 Mon Sep 17 00:00:00 2001 From: Kieron Woodhouse Date: Wed, 10 Jan 2018 11:31:24 +0000 Subject: [PATCH 20/33] Differentiate between BINARY and CHAR (#724) * Differentiate between BINARY and CHAR When looking up the database type name, we now check the character set for the following field types: * CHAR * VARCHAR * BLOB * TINYBLOB * MEDIUMBLOB * LONGBLOB If the character set is 63 (which is the binary pseudo character set), we return the binary names, which are (respectively): * BINARY * VARBINARY * BLOB * TINYBLOB * MEDIUMBLOB * LONGBLOB If any other character set is in use, we return the text names, which are (again, respectively): * CHAR * VARCHAR * TEXT * TINYTEXT * MEDIUMTEXT * LONGTEXT To facilitate this, mysqlField has been extended to include a uint8 field for character set, which is read from the appropriate packet. Column type tests have been updated to ensure coverage of binary and text types. * Increase test coverage for column types --- AUTHORS | 2 + collations.go | 1 + driver_go18_test.go | 22 ++++++++- fields.go | 112 ++++++++++++++++++++++++++++++++------------ packets.go | 6 ++- rows.go | 5 +- 6 files changed, 112 insertions(+), 36 deletions(-) diff --git a/AUTHORS b/AUTHORS index 9988284ef..5d84a6eb1 100644 --- a/AUTHORS +++ b/AUTHORS @@ -45,6 +45,7 @@ Justin Li Justin Nuß Kamil Dziedzic Kevin Malachowski +Kieron Woodhouse Lennart Rudolph Leonardo YongUk Kim Linh Tran Tuan @@ -76,6 +77,7 @@ Zhenye Xie Barracuda Networks, Inc. Counting Ltd. Google Inc. +InfoSum Ltd. Keybase Inc. Pivotal Inc. Stripe Inc. diff --git a/collations.go b/collations.go index 82079cfb9..136c9e4d1 100644 --- a/collations.go +++ b/collations.go @@ -9,6 +9,7 @@ package mysql const defaultCollation = "utf8_general_ci" +const binaryCollation = "binary" // A list of available collations mapped to the internal ID. // To update this map use the following MySQL query: diff --git a/driver_go18_test.go b/driver_go18_test.go index 953adeb8a..e461455dd 100644 --- a/driver_go18_test.go +++ b/driver_go18_test.go @@ -588,10 +588,16 @@ func TestRowsColumnTypes(t *testing.T) { nt1 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 100000000, time.UTC), Valid: true} nt2 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 110000000, time.UTC), Valid: true} nt6 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 111111000, time.UTC), Valid: true} + nd1 := NullTime{Time: time.Date(2006, 01, 02, 0, 0, 0, 0, time.UTC), Valid: true} + nd2 := NullTime{Time: time.Date(2006, 03, 04, 0, 0, 0, 0, time.UTC), Valid: true} + ndNULL := NullTime{Time: time.Time{}, Valid: false} rbNULL := sql.RawBytes(nil) rb0 := sql.RawBytes("0") rb42 := sql.RawBytes("42") rbTest := sql.RawBytes("Test") + rb0pad4 := sql.RawBytes("0\x00\x00\x00") // BINARY right-pads values with 0x00 + rbx0 := sql.RawBytes("\x00") + rbx42 := sql.RawBytes("\x42") var columns = []struct { name string @@ -604,6 +610,7 @@ func TestRowsColumnTypes(t *testing.T) { valuesIn [3]string valuesOut [3]interface{} }{ + {"bit8null", "BIT(8)", "BIT", scanTypeRawBytes, true, 0, 0, [3]string{"0x0", "NULL", "0x42"}, [3]interface{}{rbx0, rbNULL, rbx42}}, {"boolnull", "BOOL", "TINYINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "true", "0"}, [3]interface{}{niNULL, ni1, ni0}}, {"bool", "BOOL NOT NULL", "TINYINT", scanTypeInt8, false, 0, 0, [3]string{"1", "0", "FALSE"}, [3]interface{}{int8(1), int8(0), int8(0)}}, {"intnull", "INTEGER", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, @@ -611,6 +618,7 @@ func TestRowsColumnTypes(t *testing.T) { {"smallintnull", "SMALLINT", "SMALLINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, {"int3null", "INT(3)", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, {"int7", "INT(7) NOT NULL", "INT", scanTypeInt32, false, 0, 0, [3]string{"0", "-1337", "42"}, [3]interface{}{int32(0), int32(-1337), int32(42)}}, + {"mediumintnull", "MEDIUMINT", "MEDIUMINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "42", "NULL"}, [3]interface{}{ni0, ni42, niNULL}}, {"bigint", "BIGINT NOT NULL", "BIGINT", scanTypeInt64, false, 0, 0, [3]string{"0", "65535", "-42"}, [3]interface{}{int64(0), int64(65535), int64(-42)}}, {"bigintnull", "BIGINT", "BIGINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "1", "42"}, [3]interface{}{niNULL, ni1, ni42}}, {"tinyuint", "TINYINT UNSIGNED NOT NULL", "TINYINT", scanTypeUint8, false, 0, 0, [3]string{"0", "255", "42"}, [3]interface{}{uint8(0), uint8(255), uint8(42)}}, @@ -630,11 +638,21 @@ func TestRowsColumnTypes(t *testing.T) { {"decimal3null", "DECIMAL(5,0)", "DECIMAL", scanTypeRawBytes, true, 5, 0, [3]string{"0", "NULL", "-12345.123456"}, [3]interface{}{rb0, rbNULL, sql.RawBytes("-12345")}}, {"char25null", "CHAR(25)", "CHAR", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, {"varchar42", "VARCHAR(42) NOT NULL", "VARCHAR", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, - {"textnull", "TEXT", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, - {"longtext", "LONGTEXT NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"binary4null", "BINARY(4)", "BINARY", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0pad4, rbNULL, rbTest}}, + {"varbinary42", "VARBINARY(42) NOT NULL", "VARBINARY", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"tinyblobnull", "TINYBLOB", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"tinytextnull", "TINYTEXT", "TEXT", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"blobnull", "BLOB", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"textnull", "TEXT", "TEXT", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"mediumblob", "MEDIUMBLOB NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"mediumtext", "MEDIUMTEXT NOT NULL", "TEXT", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"longblob", "LONGBLOB NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"longtext", "LONGTEXT NOT NULL", "TEXT", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, {"datetime", "DATETIME", "DATETIME", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt0, nt0}}, {"datetime2", "DATETIME(2)", "DATETIME", scanTypeNullTime, true, 2, 2, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt2}}, {"datetime6", "DATETIME(6)", "DATETIME", scanTypeNullTime, true, 6, 6, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt6}}, + {"date", "DATE", "DATE", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02'", "NULL", "'2006-03-04'"}, [3]interface{}{nd1, ndNULL, nd2}}, + {"year", "YEAR NOT NULL", "YEAR", scanTypeUint16, false, 0, 0, [3]string{"2006", "2000", "1994"}, [3]interface{}{uint16(2006), uint16(2000), uint16(1994)}}, } schema := "" diff --git a/fields.go b/fields.go index cded986d2..e1e2ece4b 100644 --- a/fields.go +++ b/fields.go @@ -13,35 +13,88 @@ import ( "reflect" ) -var typeDatabaseName = map[fieldType]string{ - fieldTypeBit: "BIT", - fieldTypeBLOB: "BLOB", - fieldTypeDate: "DATE", - fieldTypeDateTime: "DATETIME", - fieldTypeDecimal: "DECIMAL", - fieldTypeDouble: "DOUBLE", - fieldTypeEnum: "ENUM", - fieldTypeFloat: "FLOAT", - fieldTypeGeometry: "GEOMETRY", - fieldTypeInt24: "MEDIUMINT", - fieldTypeJSON: "JSON", - fieldTypeLong: "INT", - fieldTypeLongBLOB: "LONGBLOB", - fieldTypeLongLong: "BIGINT", - fieldTypeMediumBLOB: "MEDIUMBLOB", - fieldTypeNewDate: "DATE", - fieldTypeNewDecimal: "DECIMAL", - fieldTypeNULL: "NULL", - fieldTypeSet: "SET", - fieldTypeShort: "SMALLINT", - fieldTypeString: "CHAR", - fieldTypeTime: "TIME", - fieldTypeTimestamp: "TIMESTAMP", - fieldTypeTiny: "TINYINT", - fieldTypeTinyBLOB: "TINYBLOB", - fieldTypeVarChar: "VARCHAR", - fieldTypeVarString: "VARCHAR", - fieldTypeYear: "YEAR", +func (mf *mysqlField) typeDatabaseName() string { + switch mf.fieldType { + case fieldTypeBit: + return "BIT" + case fieldTypeBLOB: + if mf.charSet != collations[binaryCollation] { + return "TEXT" + } + return "BLOB" + case fieldTypeDate: + return "DATE" + case fieldTypeDateTime: + return "DATETIME" + case fieldTypeDecimal: + return "DECIMAL" + case fieldTypeDouble: + return "DOUBLE" + case fieldTypeEnum: + return "ENUM" + case fieldTypeFloat: + return "FLOAT" + case fieldTypeGeometry: + return "GEOMETRY" + case fieldTypeInt24: + return "MEDIUMINT" + case fieldTypeJSON: + return "JSON" + case fieldTypeLong: + return "INT" + case fieldTypeLongBLOB: + if mf.charSet != collations[binaryCollation] { + return "LONGTEXT" + } + return "LONGBLOB" + case fieldTypeLongLong: + return "BIGINT" + case fieldTypeMediumBLOB: + if mf.charSet != collations[binaryCollation] { + return "MEDIUMTEXT" + } + return "MEDIUMBLOB" + case fieldTypeNewDate: + return "DATE" + case fieldTypeNewDecimal: + return "DECIMAL" + case fieldTypeNULL: + return "NULL" + case fieldTypeSet: + return "SET" + case fieldTypeShort: + return "SMALLINT" + case fieldTypeString: + if mf.charSet == collations[binaryCollation] { + return "BINARY" + } + return "CHAR" + case fieldTypeTime: + return "TIME" + case fieldTypeTimestamp: + return "TIMESTAMP" + case fieldTypeTiny: + return "TINYINT" + case fieldTypeTinyBLOB: + if mf.charSet != collations[binaryCollation] { + return "TINYTEXT" + } + return "TINYBLOB" + case fieldTypeVarChar: + if mf.charSet == collations[binaryCollation] { + return "VARBINARY" + } + return "VARCHAR" + case fieldTypeVarString: + if mf.charSet == collations[binaryCollation] { + return "VARBINARY" + } + return "VARCHAR" + case fieldTypeYear: + return "YEAR" + default: + return "" + } } var ( @@ -69,6 +122,7 @@ type mysqlField struct { flags fieldFlag fieldType fieldType decimals byte + charSet uint8 } func (mf *mysqlField) scanType() reflect.Type { diff --git a/packets.go b/packets.go index 7bd2dd309..36ce691c5 100644 --- a/packets.go +++ b/packets.go @@ -706,10 +706,14 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { if err != nil { return nil, err } + pos += n // Filler [uint8] + pos++ + // Charset [charset, collation uint8] - pos += n + 1 + 2 + columns[i].charSet = data[pos] + pos += 2 // Length [uint32] columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4]) diff --git a/rows.go b/rows.go index 18f41693e..d3b1e2822 100644 --- a/rows.go +++ b/rows.go @@ -60,10 +60,7 @@ func (rows *mysqlRows) Columns() []string { } func (rows *mysqlRows) ColumnTypeDatabaseTypeName(i int) string { - if name, ok := typeDatabaseName[rows.rs.columns[i].fieldType]; ok { - return name - } - return "" + return rows.rs.columns[i].typeDatabaseName() } // func (rows *mysqlRows) ColumnTypeLength(i int) (length int64, ok bool) { From f853432d62faa9c9265c1454f7b1bae07f4b2a6e Mon Sep 17 00:00:00 2001 From: Alexey Palazhchenko Date: Wed, 10 Jan 2018 13:44:24 +0200 Subject: [PATCH 21/33] Test with latest Go patch versions (#693) --- .travis.yml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/.travis.yml b/.travis.yml index 64b06a70c..e922f9187 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,10 +1,10 @@ sudo: false language: go go: - - 1.7 - - 1.8 - - 1.9 - - tip + - 1.7.x + - 1.8.x + - 1.9.x + - master before_install: - go get golang.org/x/tools/cmd/cover @@ -21,7 +21,7 @@ matrix: - env: DB=MYSQL57 sudo: required dist: trusty - go: 1.9 + go: 1.9.x services: - docker before_install: @@ -43,7 +43,7 @@ matrix: - env: DB=MARIA55 sudo: required dist: trusty - go: 1.9 + go: 1.9.x services: - docker before_install: @@ -65,7 +65,7 @@ matrix: - env: DB=MARIA10_1 sudo: required dist: trusty - go: 1.9 + go: 1.9.x services: - docker before_install: From d1a8b86f7fef0f773e55ca15defb2347be22a106 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Sun, 14 Jan 2018 05:07:44 +0900 Subject: [PATCH 22/33] Fix prepared statement (#734) * Fix prepared statement When there are many args and maxAllowedPacket is not enough, writeExecutePacket() attempted to use STMT_LONG_DATA even for 0byte string. But writeCommandLongData() doesn't support 0byte data. So it caused to send malfold packet. This commit loosen threshold for using STMT_LONG_DATA. * Change minimum size of LONG_DATA to 64byte * Add test which reproduce issue 730 * TestPreparedManyCols test only numParams = 65535 case * s/as possible// --- driver_test.go | 17 ++++++++++++++--- packets.go | 10 ++++++++-- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/driver_test.go b/driver_test.go index 224a24c53..7877aa979 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1669,8 +1669,9 @@ func TestStmtMultiRows(t *testing.T) { // Regression test for // * more than 32 NULL parameters (issue 209) // * more parameters than fit into the buffer (issue 201) +// * parameters * 64 > max_allowed_packet (issue 734) func TestPreparedManyCols(t *testing.T) { - const numParams = defaultBufSize + numParams := 65535 runTests(t, dsn, func(dbt *DBTest) { query := "SELECT ?" + strings.Repeat(",?", numParams-1) stmt, err := dbt.db.Prepare(query) @@ -1678,15 +1679,25 @@ func TestPreparedManyCols(t *testing.T) { dbt.Fatal(err) } defer stmt.Close() + // create more parameters than fit into the buffer // which will take nil-values params := make([]interface{}, numParams) rows, err := stmt.Query(params...) if err != nil { - stmt.Close() dbt.Fatal(err) } - defer rows.Close() + rows.Close() + + // Create 0byte string which we can't send via STMT_LONG_DATA. + for i := 0; i < numParams; i++ { + params[i] = "" + } + rows, err = stmt.Query(params...) + if err != nil { + dbt.Fatal(err) + } + rows.Close() }) } diff --git a/packets.go b/packets.go index 36ce691c5..e6d8e4e88 100644 --- a/packets.go +++ b/packets.go @@ -927,6 +927,12 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { const minPktLen = 4 + 1 + 4 + 1 + 4 mc := stmt.mc + // Determine threshould dynamically to avoid packet size shortage. + longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1) + if longDataSize < 64 { + longDataSize = 64 + } + // Reset packet-sequence mc.sequence = 0 mc.compressionSequence = 0 @@ -1055,7 +1061,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 - if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { + if len(v) < longDataSize { paramValues = appendLengthEncodedInteger(paramValues, uint64(len(v)), ) @@ -1077,7 +1083,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 - if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { + if len(v) < longDataSize { paramValues = appendLengthEncodedInteger(paramValues, uint64(len(v)), ) From 31679208840e1a2a05db26638c7bcda4ff362bf1 Mon Sep 17 00:00:00 2001 From: Reed Allman Date: Wed, 24 Jan 2018 21:47:45 -0800 Subject: [PATCH 23/33] driver.ErrBadConn when init packet read fails (#736) Thank you! --- AUTHORS | 1 + packets.go | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/AUTHORS b/AUTHORS index 5d84a6eb1..d9144aece 100644 --- a/AUTHORS +++ b/AUTHORS @@ -61,6 +61,7 @@ oscarzhao Paul Bonser Peter Schultz Rebecca Chin +Reed Allman Runrioter Wung Robert Russell Shuode Li diff --git a/packets.go b/packets.go index e6d8e4e88..2e9cb4984 100644 --- a/packets.go +++ b/packets.go @@ -157,6 +157,11 @@ func (mc *mysqlConn) writePacket(data []byte) error { func (mc *mysqlConn) readInitPacket() ([]byte, error) { data, err := mc.readPacket() if err != nil { + // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since + // in connection initialization we don't risk retrying non-idempotent actions. + if err == ErrInvalidConn { + return nil, driver.ErrBadConn + } return nil, err } From fb33a2cb2ede88e3222b62f272909216a1121e5b Mon Sep 17 00:00:00 2001 From: Brigitte Lamarche Date: Mon, 31 Jul 2017 16:43:52 -0400 Subject: [PATCH 24/33] packets: implemented compression protocol --- compress.go | 3 ++- dsn.go | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/compress.go b/compress.go index 3425c72fa..56b07362b 100644 --- a/compress.go +++ b/compress.go @@ -10,7 +10,6 @@ const ( minCompressLength = 50 ) - type packetReader interface { readNext(need int) ([]byte, error) } @@ -117,6 +116,7 @@ func (cr *compressedReader) uncompressPacket() error { // http://grokbase.com/t/gg/golang-nuts/146y9ppn6b/go-nuts-stream-compression-with-compress-flate for lenRead < uncompressedLength { n, err := cr.zr.Read(data[lenRead:]) + lenRead += n if err == io.EOF { @@ -168,6 +168,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { } err = cw.writeToNetwork(compressedPayload, payloadLen) + if err != nil { return 0, err } diff --git a/dsn.go b/dsn.go index fa50ad3c0..92f137daa 100644 --- a/dsn.go +++ b/dsn.go @@ -53,6 +53,7 @@ type Config struct { AllowOldPasswords bool // Allows the old insecure password method ClientFoundRows bool // Return number of matching rows instead of rows changed ColumnsWithAlias bool // Prepend table alias to column names + Compression bool // Compress packets InterpolateParams bool // Interpolate placeholders into query string MultiStatements bool // Allow multiple statements in one query ParseTime bool // Parse time values to time.Time From f1746058a298b62cc541c3475df4ca4e6569d4c2 Mon Sep 17 00:00:00 2001 From: Brigitte Lamarche Date: Wed, 16 Aug 2017 13:55:33 -0400 Subject: [PATCH 25/33] packets: implemented compression protocol CR changes --- compress.go | 1 - 1 file changed, 1 deletion(-) diff --git a/compress.go b/compress.go index 56b07362b..9339e0eca 100644 --- a/compress.go +++ b/compress.go @@ -116,7 +116,6 @@ func (cr *compressedReader) uncompressPacket() error { // http://grokbase.com/t/gg/golang-nuts/146y9ppn6b/go-nuts-stream-compression-with-compress-flate for lenRead < uncompressedLength { n, err := cr.zr.Read(data[lenRead:]) - lenRead += n if err == io.EOF { From dbd1e2befc161ab78bf0aa36e8e8bbd2e05e9ad7 Mon Sep 17 00:00:00 2001 From: Brigitte Lamarche Date: Fri, 23 Mar 2018 11:22:20 -0400 Subject: [PATCH 26/33] third code review changes --- compress.go | 12 ++++-------- connection.go | 8 ++++++-- driver.go | 2 +- dsn.go | 4 ++-- packets.go | 2 +- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/compress.go b/compress.go index 9339e0eca..74b293a58 100644 --- a/compress.go +++ b/compress.go @@ -10,10 +10,6 @@ const ( minCompressLength = 50 ) -type packetReader interface { - readNext(need int) ([]byte, error) -} - type compressedReader struct { buf packetReader bytesBuf []byte @@ -151,7 +147,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { payloadLen := len(payload) bytesBuf := &bytes.Buffer{} - bytesBuf.Write(blankHeader) + bytesBuf.Write(blankHeader) cw.zw.Reset(bytesBuf) _, err := cw.zw.Write(payload) if err != nil { @@ -162,7 +158,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { // if compression expands the payload, do not compress compressedPayload := bytesBuf.Bytes() if len(compressedPayload) > maxPayloadLength { - compressedPayload = append(blankHeader, payload...) + compressedPayload = append(blankHeader, payload...) payloadLen = 0 } @@ -180,7 +176,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { // do not attempt compression if packet is too small if payloadLen < minCompressLength { - err := cw.writeToNetwork(append(blankHeader, data...), 0) + err := cw.writeToNetwork(append(blankHeader, data...), 0) if err != nil { return 0, err } @@ -199,7 +195,7 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { compressedPayload := bytesBuf.Bytes() if len(compressedPayload) > len(data) { - compressedPayload = append(blankHeader, data...) + compressedPayload = append(blankHeader, data...) payloadLen = 0 } diff --git a/connection.go b/connection.go index 3a30c46a9..cc802fa42 100644 --- a/connection.go +++ b/connection.go @@ -30,6 +30,8 @@ type mysqlContext interface { type mysqlConn struct { buf buffer netConn net.Conn + reader packetReader + writer io.Writer affectedRows uint64 insertId uint64 cfg *Config @@ -41,8 +43,6 @@ type mysqlConn struct { sequence uint8 compressionSequence uint8 parseTime bool - reader packetReader - writer io.Writer // for context support (Go 1.8+) watching bool @@ -53,6 +53,10 @@ type mysqlConn struct { closed atomicBool // set when conn is closed, before closech is closed } +type packetReader interface { + readNext(need int) ([]byte, error) +} + // Handles parameters set in DSN after the connection is established func (mc *mysqlConn) handleParams() (err error) { for param, val := range mc.cfg.Params { diff --git a/driver.go b/driver.go index 86d38f70d..636ea1fb3 100644 --- a/driver.go +++ b/driver.go @@ -123,7 +123,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { return nil, err } - if mc.cfg.Compression { + if mc.cfg.Compress { mc.reader = NewCompressedReader(&mc.buf, mc) mc.writer = NewCompressedWriter(mc.writer, mc) } diff --git a/dsn.go b/dsn.go index 92f137daa..b7e9c5495 100644 --- a/dsn.go +++ b/dsn.go @@ -53,7 +53,7 @@ type Config struct { AllowOldPasswords bool // Allows the old insecure password method ClientFoundRows bool // Return number of matching rows instead of rows changed ColumnsWithAlias bool // Prepend table alias to column names - Compression bool // Compress packets + Compress bool // Compress packets InterpolateParams bool // Interpolate placeholders into query string MultiStatements bool // Allow multiple statements in one query ParseTime bool // Parse time values to time.Time @@ -464,7 +464,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { // Compression case "compress": var isBool bool - cfg.Compression, isBool = readBool(value) + cfg.Compress, isBool = readBool(value) if !isBool { return errors.New("invalid bool value: " + value) } diff --git a/packets.go b/packets.go index 2e9cb4984..0303c426d 100644 --- a/packets.go +++ b/packets.go @@ -258,7 +258,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { clientFlags |= clientFoundRows } - if mc.cfg.Compression { + if mc.cfg.Compress { clientFlags |= clientCompress } From 3e12e32d9970baaaa6316d83e9e934e72e4564eb Mon Sep 17 00:00:00 2001 From: Brigitte Lamarche Date: Fri, 23 Mar 2018 11:58:56 -0400 Subject: [PATCH 27/33] PR 649: minor cleanup --- compress.go | 4 ++-- compress_test.go | 3 --- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/compress.go b/compress.go index 74b293a58..719e3625d 100644 --- a/compress.go +++ b/compress.go @@ -97,7 +97,7 @@ func (cr *compressedReader) uncompressPacket() error { defer cr.zr.Close() - //use existing capacity in bytesBuf if possible + // use existing capacity in bytesBuf if possible offset := len(cr.bytesBuf) if cap(cr.bytesBuf)-offset < uncompressedLength { old := cr.bytesBuf @@ -220,7 +220,7 @@ func (cw *compressedWriter) writeToNetwork(data []byte, uncomprLength int) error data[3] = cw.mc.compressionSequence - //this value is never greater than maxPayloadLength + // this value is never greater than maxPayloadLength data[4] = byte(0xff & uncomprLength) data[5] = byte(0xff & (uncomprLength >> 8)) data[6] = byte(0xff & (uncomprLength >> 16)) diff --git a/compress_test.go b/compress_test.go index c626ff3ee..d497ed56d 100644 --- a/compress_test.go +++ b/compress_test.go @@ -208,13 +208,10 @@ func TestRoundtrip(t *testing.T) { for _, test := range tests { s := fmt.Sprintf("Test roundtrip with %s", test.desc) - //t.Run(s, func(t *testing.T) { uncompressed := roundtripHelper(t, cSend, cReceive, test.uncompressed) if bytes.Compare(uncompressed, test.uncompressed) != 0 { t.Fatal(fmt.Sprintf("%s: roundtrip failed", s)) } - - //}) } } From 60bdaec793f67557d38e805047066911e9d8cab5 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Mon, 26 Mar 2018 19:01:55 +0900 Subject: [PATCH 28/33] Sort AUTHORS --- AUTHORS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/AUTHORS b/AUTHORS index 6274cd5e1..291c0f2fd 100644 --- a/AUTHORS +++ b/AUTHORS @@ -64,8 +64,8 @@ Paul Bonser Peter Schultz Rebecca Chin Reed Allman -Runrioter Wung Robert Russell +Runrioter Wung Shuode Li Soroush Pour Stan Putrya From 422ab6f48ea938b28641d9cb23a67e51190f19bb Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Mon, 26 Mar 2018 19:28:02 +0900 Subject: [PATCH 29/33] Update dsn.go --- dsn.go | 1 - 1 file changed, 1 deletion(-) diff --git a/dsn.go b/dsn.go index b7e9c5495..82d15a8fb 100644 --- a/dsn.go +++ b/dsn.go @@ -58,7 +58,6 @@ type Config struct { MultiStatements bool // Allow multiple statements in one query ParseTime bool // Parse time values to time.Time RejectReadOnly bool // Reject read-only connections - Compression bool // Compress packets } // NewConfig creates a new Config and sets default values. From 26ea544c317565f3de751c07a7cfd7a020a03552 Mon Sep 17 00:00:00 2001 From: Brigitte Lamarche Date: Mon, 23 Jul 2018 12:25:13 -0400 Subject: [PATCH 30/33] cr4 changes --- compress.go | 6 ++---- compress_test.go | 4 ++-- driver.go | 4 ++-- packets.go | 13 +++++++------ 4 files changed, 13 insertions(+), 14 deletions(-) diff --git a/compress.go b/compress.go index 719e3625d..eaa23fdd5 100644 --- a/compress.go +++ b/compress.go @@ -23,7 +23,7 @@ type compressedWriter struct { zw *zlib.Writer } -func NewCompressedReader(buf packetReader, mc *mysqlConn) *compressedReader { +func newCompressedReader(buf packetReader, mc *mysqlConn) *compressedReader { return &compressedReader{ buf: buf, bytesBuf: make([]byte, 0), @@ -31,7 +31,7 @@ func NewCompressedReader(buf packetReader, mc *mysqlConn) *compressedReader { } } -func NewCompressedWriter(connWriter io.Writer, mc *mysqlConn) *compressedWriter { +func newCompressedWriter(connWriter io.Writer, mc *mysqlConn) *compressedWriter { return &compressedWriter{ connWriter: connWriter, mc: mc, @@ -206,11 +206,9 @@ func (cw *compressedWriter) Write(data []byte) (int, error) { } return totalBytes, nil - } func (cw *compressedWriter) writeToNetwork(data []byte, uncomprLength int) error { - comprLength := len(data) - 7 // compression header diff --git a/compress_test.go b/compress_test.go index d497ed56d..ce442ed37 100644 --- a/compress_test.go +++ b/compress_test.go @@ -48,7 +48,7 @@ func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []by var b bytes.Buffer connWriter := &b - cw := NewCompressedWriter(connWriter, mc) + cw := newCompressedWriter(connWriter, mc) n, err := cw.Write(uncompressedPacket) @@ -90,7 +90,7 @@ func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expS mockConnReader := bytes.NewReader(compressedPacket) mockBuf := newMockBuf(mockConnReader) - cr := NewCompressedReader(mockBuf, mc) + cr := newCompressedReader(mockBuf, mc) uncompressedPacket, err := cr.readNext(expSize) if err != nil { diff --git a/driver.go b/driver.go index 636ea1fb3..35f104d9f 100644 --- a/driver.go +++ b/driver.go @@ -124,8 +124,8 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { } if mc.cfg.Compress { - mc.reader = NewCompressedReader(&mc.buf, mc) - mc.writer = NewCompressedWriter(mc.writer, mc) + mc.reader = newCompressedReader(&mc.buf, mc) + mc.writer = newCompressedWriter(mc.writer, mc) } if mc.cfg.MaxAllowedPacket > 0 { diff --git a/packets.go b/packets.go index 0303c426d..bb69621be 100644 --- a/packets.go +++ b/packets.go @@ -866,7 +866,8 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { // http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { - maxLen := stmt.mc.maxAllowedPacket - 1 + mc := stmt.mc + maxLen := mc.maxAllowedPacket - 1 pktLen := maxLen // After the header (bytes 0-3) follows before the data: @@ -887,8 +888,8 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { pktLen = dataOffset + argLen } - stmt.mc.sequence = 0 - stmt.mc.compressionSequence = 0 + mc.sequence = 0 + mc.compressionSequence = 0 // Add command byte [1 byte] data[4] = comStmtSendLongData @@ -903,7 +904,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { data[10] = byte(paramID >> 8) // Send CMD packet - err := stmt.mc.writePacket(data[:4+pktLen]) + err := mc.writePacket(data[:4+pktLen]) if err == nil { data = data[pktLen-dataOffset:] continue @@ -913,8 +914,8 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { } // Reset Packet Sequence - stmt.mc.sequence = 0 - stmt.mc.compressionSequence = 0 + mc.sequence = 0 + mc.compressionSequence = 0 return nil } From 3e559a8bfeeadb8378e12adc51d67e5fe9cef339 Mon Sep 17 00:00:00 2001 From: Brigitte Lamarche Date: Wed, 12 Sep 2018 10:45:45 -0400 Subject: [PATCH 31/33] saving work with SimpleReader present --- benchmark_test.go | 5 +++-- buffer.go | 14 ++++++++++---- compress.go | 19 +++++++++++++++++-- connection.go | 19 +++++++++++++++++-- connection_test.go | 12 ++++++------ driver.go | 13 +++++++------ packets.go | 38 +++++++++++++++++++++++++------------- packets_test.go | 27 ++++++++++++++------------- 8 files changed, 99 insertions(+), 48 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index 15bdd9fa1..f661a0997 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -231,9 +231,10 @@ func BenchmarkInterpolation(b *testing.B) { }, maxAllowedPacket: maxPacketSize, maxWriteSize: maxPacketSize - 1, - buf: newBuffer(nil), } - mc.reader = &mc.buf + + buf := newBuffer(nil) + mc.reader = newSimpleReader(&buf) args := []driver.Value{ int64(42424242), diff --git a/buffer.go b/buffer.go index 2001feacd..2020397d7 100644 --- a/buffer.go +++ b/buffer.go @@ -21,7 +21,7 @@ const defaultBufSize = 4096 // In other words, we can't write and read simultaneously on the same connection. // The buffer is similar to bufio.Reader / Writer but zero-copy-ish // Also highly optimized for this particular use case. -type buffer struct { +type buffer struct { //PROBLEM: figure this all out better buf []byte nc net.Conn idx int @@ -49,7 +49,7 @@ func (b *buffer) fill(need int) error { // grow buffer if necessary // TODO: let the buffer shrink again at some point // Maybe keep the org buf slice and swap back? - if need > len(b.buf) { + if need > len(b.buf) { //look up what len and cap mean again! // Round up to the next multiple of the default size newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize) copy(newBuf, b.buf) @@ -92,6 +92,10 @@ func (b *buffer) fill(need int) error { // returns next N bytes from buffer. // The returned slice is only guaranteed to be valid until the next read func (b *buffer) readNext(need int) ([]byte, error) { + if need == -1 { + return b.takeCompleteBuffer() + } + if b.length < need { // refill if err := b.fill(need); err != nil { @@ -110,7 +114,7 @@ func (b *buffer) readNext(need int) ([]byte, error) { // Otherwise a bigger buffer is made. // Only one buffer (total) can be used at a time. func (b *buffer) takeBuffer(length int) []byte { - if b.length > 0 { + if b.length > 0 { //assume its empty return nil } @@ -126,15 +130,17 @@ func (b *buffer) takeBuffer(length int) []byte { return make([]byte, length) } +/* // shortcut which can be used if the requested buffer is guaranteed to be // smaller than defaultBufSize // Only one buffer (total) can be used at a time. func (b *buffer) takeSmallBuffer(length int) []byte { - if b.length == 0 { + if b.length == 0 { //assume its empty return b.buf[:length] } return nil } +*/ // takeCompleteBuffer returns the complete existing buffer. // This can be used if the necessary buffer size is unknown. diff --git a/compress.go b/compress.go index eaa23fdd5..e4708b63b 100644 --- a/compress.go +++ b/compress.go @@ -11,19 +11,24 @@ const ( ) type compressedReader struct { - buf packetReader + buf *buffer //packetReader bytesBuf []byte mc *mysqlConn zr io.ReadCloser } + +type simpleReader struct { + buf *buffer //packetReader +} + type compressedWriter struct { connWriter io.Writer mc *mysqlConn zw *zlib.Writer } -func newCompressedReader(buf packetReader, mc *mysqlConn) *compressedReader { +func newCompressedReader(buf *buffer, mc *mysqlConn) *compressedReader { return &compressedReader{ buf: buf, bytesBuf: make([]byte, 0), @@ -31,6 +36,12 @@ func newCompressedReader(buf packetReader, mc *mysqlConn) *compressedReader { } } +func newSimpleReader(buf *buffer) *simpleReader { + return &simpleReader{ + buf: buf, + } +} + func newCompressedWriter(connWriter io.Writer, mc *mysqlConn) *compressedWriter { return &compressedWriter{ connWriter: connWriter, @@ -52,6 +63,10 @@ func (cr *compressedReader) readNext(need int) ([]byte, error) { return data, nil } +func (sr *simpleReader) readNext(need int) ([]byte, error) { + return sr.buf.readNext(need) +} + func (cr *compressedReader) uncompressPacket() error { header, err := cr.buf.readNext(7) // size of compressed header diff --git a/connection.go b/connection.go index cc802fa42..339f59e26 100644 --- a/connection.go +++ b/connection.go @@ -28,7 +28,6 @@ type mysqlContext interface { } type mysqlConn struct { - buf buffer netConn net.Conn reader packetReader writer io.Writer @@ -57,6 +56,18 @@ type packetReader interface { readNext(need int) ([]byte, error) } +/* +type packetReadCloser interface{ + Read(n int) ([]byte, error) + Close() error // PROBLEM: is there a way to do this? +} + +type packetWriteCloser interface{ + Write([]byte) (int, error) + Close() error +} +*/ + // Handles parameters set in DSN after the connection is established func (mc *mysqlConn) handleParams() (err error) { for param, val := range mc.cfg.Params { @@ -197,7 +208,11 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin return "", driver.ErrSkip } - buf := mc.buf.takeCompleteBuffer() + //https://stackoverflow.com/questions/29684609/how-to-check-if-an-object-has-a-particular-method + + //reader has &buf which is a packetreader but also always a buffer + buf, _ := mc.reader.readNext(-1) //PROBLEM uncompressed so this works, what if compressed + if buf == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) diff --git a/connection_test.go b/connection_test.go index 187c76116..b85e48260 100644 --- a/connection_test.go +++ b/connection_test.go @@ -15,13 +15,13 @@ import ( func TestInterpolateParams(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, }, } - mc.reader = &mc.buf + buf := newBuffer(nil) + mc.reader = newSimpleReader(&buf) q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"}) if err != nil { @@ -36,13 +36,13 @@ func TestInterpolateParams(t *testing.T) { func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, }, } - mc.reader = &mc.buf + buf := newBuffer(nil) + mc.reader = newSimpleReader(&buf) q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42)}) if err != driver.ErrSkip { @@ -54,14 +54,14 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { // https://github.com/go-sql-driver/mysql/pull/490 func TestInterpolateParamsPlaceholderInString(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, }, } - mc.reader = &mc.buf + buf := newBuffer(nil) + mc.reader = newSimpleReader(&buf) q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)}) // When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42` diff --git a/driver.go b/driver.go index 35f104d9f..083542d94 100644 --- a/driver.go +++ b/driver.go @@ -91,15 +91,16 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { s.startWatcher() } - mc.buf = newBuffer(mc.netConn) + buf := newBuffer(mc.netConn) + + // Set I/O timeouts + buf.timeout = mc.cfg.ReadTimeout + mc.writeTimeout = mc.cfg.WriteTimeout // packet reader and writer in handshake are never compressed - mc.reader = &mc.buf + mc.reader = newSimpleReader(&buf) mc.writer = mc.netConn - // Set I/O timeouts - mc.buf.timeout = mc.cfg.ReadTimeout - mc.writeTimeout = mc.cfg.WriteTimeout // Reading Handshake Initialization Packet cipher, err := mc.readInitPacket() @@ -124,7 +125,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { } if mc.cfg.Compress { - mc.reader = newCompressedReader(&mc.buf, mc) + mc.reader = newCompressedReader(&buf, mc) mc.writer = newCompressedWriter(mc.writer, mc) } diff --git a/packets.go b/packets.go index bb69621be..f5edc88df 100644 --- a/packets.go +++ b/packets.go @@ -28,7 +28,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { var prevData []byte for { // read packet header - data, err := mc.reader.readNext(4) + data, err := mc.reader.readNext(4) if err != nil { if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr @@ -64,7 +64,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } // read packet body [pktLen bytes] - data, err = mc.reader.readNext(pktLen) + data, err = mc.reader.readNext(pktLen) if err != nil { if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr @@ -283,7 +283,8 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { } // Calculate packet length and get buffer with that size - data := mc.buf.takeSmallBuffer(pktLen + 4) + data, _ := mc.reader.readNext(pktLen + 4) + if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) @@ -326,8 +327,12 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { return err } mc.netConn = tlsConn - mc.buf.nc = tlsConn + nc := tlsConn + // make newBuffer with tls conn, clean slate bc handshake + newBuf := newBuffer(nc) + mc.reader = newSimpleReader(&newBuf) + mc.writer = mc.netConn } @@ -373,7 +378,8 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { // Calculate the packet length and add a tailing 0 pktLen := len(scrambleBuff) + 1 - data := mc.buf.takeSmallBuffer(4 + pktLen) + data, _ := mc.reader.readNext(4 + pktLen) + if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) @@ -392,7 +398,8 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { func (mc *mysqlConn) writeClearAuthPacket() error { // Calculate the packet length and add a tailing 0 pktLen := len(mc.cfg.Passwd) + 1 - data := mc.buf.takeSmallBuffer(4 + pktLen) + data, _ := mc.reader.readNext(4 + pktLen) + if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) @@ -415,7 +422,8 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { // Calculate the packet length and add a tailing 0 pktLen := len(scrambleBuff) - data := mc.buf.takeSmallBuffer(4 + pktLen) + data, _ := mc.reader.readNext(4 + pktLen) + if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) @@ -437,7 +445,8 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { mc.sequence = 0 mc.compressionSequence = 0 - data := mc.buf.takeSmallBuffer(4 + 1) + data, _ := mc.reader.readNext(4+1) + if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) @@ -457,7 +466,8 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { mc.compressionSequence = 0 pktLen := 1 + len(arg) - data := mc.buf.takeBuffer(pktLen + 4) + data, _ := mc.reader.readNext(pktLen + 4) + if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) @@ -479,7 +489,8 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { mc.sequence = 0 mc.compressionSequence = 0 - data := mc.buf.takeSmallBuffer(4 + 1 + 4) + data, _ := mc.reader.readNext(4 + 1 + 4) + if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) @@ -946,9 +957,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { var data []byte if len(args) == 0 { - data = mc.buf.takeBuffer(minPktLen) + data, _ = mc.reader.readNext(minPktLen) } else { - data = mc.buf.takeCompleteBuffer() + data, _ = mc.reader.readNext(-1) //how does this work out with compressed? } if data == nil { // can not take the buffer. Something must be wrong with the connection @@ -1127,7 +1138,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // In that case we must build the data packet with the new values buffer if valuesCap != cap(paramValues) { data = append(data[:pos], paramValues...) - mc.buf.buf = data + readerBuffer := mc.reader.getBuffer() //PROBLEM: what the fuck is going on here?? + readerBuffer.buf = data } pos += len(paramValues) diff --git a/packets_test.go b/packets_test.go index 8f403835a..6d1759c41 100644 --- a/packets_test.go +++ b/packets_test.go @@ -89,12 +89,11 @@ var _ net.Conn = new(mockConn) func TestReadPacketSingleByte(t *testing.T) { conn := new(mockConn) + buf := newBuffer(conn) mc := &mysqlConn{ - buf: newBuffer(conn), + reader: newSimpleReader(&buf), } - mc.reader = &mc.buf - conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} conn.maxReads = 1 packet, err := mc.readPacket() @@ -111,10 +110,10 @@ func TestReadPacketSingleByte(t *testing.T) { func TestReadPacketWrongSequenceID(t *testing.T) { conn := new(mockConn) + buf:= newBuffer(conn) mc := &mysqlConn{ - buf: newBuffer(conn), + reader: newSimpleReader(&buf), } - mc.reader = &mc.buf // too low sequence id conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} @@ -128,7 +127,8 @@ func TestReadPacketWrongSequenceID(t *testing.T) { // reset conn.reads = 0 mc.sequence = 0 - mc.buf = newBuffer(conn) + newBuf := newBuffer(conn) + mc.reader = newSimpleReader(&newBuf) // too high sequence id conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff} @@ -140,12 +140,11 @@ func TestReadPacketWrongSequenceID(t *testing.T) { func TestReadPacketSplit(t *testing.T) { conn := new(mockConn) + buf:= newBuffer(conn) mc := &mysqlConn{ - buf: newBuffer(conn), + reader : newSimpleReader(&buf), } - mc.reader = &mc.buf - data := make([]byte, maxPacketSize*2+4*3) const pkt2ofs = maxPacketSize + 4 const pkt3ofs = 2 * (maxPacketSize + 4) @@ -247,11 +246,11 @@ func TestReadPacketSplit(t *testing.T) { func TestReadPacketFail(t *testing.T) { conn := new(mockConn) + buf := newBuffer(conn) mc := &mysqlConn{ - buf: newBuffer(conn), + reader: newSimpleReader(&buf), closech: make(chan struct{}), } - mc.reader = &mc.buf // illegal empty (stand-alone) packet conn.data = []byte{0x00, 0x00, 0x00, 0x00} @@ -264,7 +263,8 @@ func TestReadPacketFail(t *testing.T) { // reset conn.reads = 0 mc.sequence = 0 - mc.buf = newBuffer(conn) + newBuf := newBuffer(conn) + mc.reader = newSimpleReader(&newBuf) // fail to read header conn.closed = true @@ -277,7 +277,8 @@ func TestReadPacketFail(t *testing.T) { conn.closed = false conn.reads = 0 mc.sequence = 0 - mc.buf = newBuffer(conn) + newBuf = newBuffer(conn) + mc.reader = newSimpleReader(&newBuf) // fail to read body conn.maxReads = 1 From 6ceaef67f5e5a28fd265619e93a5348060e2470d Mon Sep 17 00:00:00 2001 From: Brigitte Lamarche Date: Wed, 26 Sep 2018 17:18:08 -0400 Subject: [PATCH 32/33] removed buf from mysqlConn --- benchmark_test.go | 6 +++--- buffer.go | 28 ++++++++-------------------- compress.go | 19 ++++--------------- compress_test.go | 4 ++++ connection.go | 18 ++---------------- connection_test.go | 6 +++--- driver.go | 3 +-- packets.go | 46 ++++++++++++++++++++++++---------------------- packets_test.go | 18 +++++++++--------- 9 files changed, 58 insertions(+), 90 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index f661a0997..498627dee 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -205,7 +205,7 @@ func BenchmarkRoundtripBin(b *testing.B) { length = max } test := sample[0:length] - rows := tb.checkRows(stmt.Query(test)) + rows := tb.checkRows(stmt.Query(test)) //run benchmark tests to test that bit of code if !rows.Next() { rows.Close() b.Fatalf("crashed") @@ -232,9 +232,9 @@ func BenchmarkInterpolation(b *testing.B) { maxAllowedPacket: maxPacketSize, maxWriteSize: maxPacketSize - 1, } - + buf := newBuffer(nil) - mc.reader = newSimpleReader(&buf) + mc.reader = &buf args := []driver.Value{ int64(42424242), diff --git a/buffer.go b/buffer.go index 2020397d7..82ffe5197 100644 --- a/buffer.go +++ b/buffer.go @@ -21,7 +21,7 @@ const defaultBufSize = 4096 // In other words, we can't write and read simultaneously on the same connection. // The buffer is similar to bufio.Reader / Writer but zero-copy-ish // Also highly optimized for this particular use case. -type buffer struct { //PROBLEM: figure this all out better +type buffer struct { buf []byte nc net.Conn idx int @@ -49,7 +49,7 @@ func (b *buffer) fill(need int) error { // grow buffer if necessary // TODO: let the buffer shrink again at some point // Maybe keep the org buf slice and swap back? - if need > len(b.buf) { //look up what len and cap mean again! + if need > len(b.buf) { // Round up to the next multiple of the default size newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize) copy(newBuf, b.buf) @@ -92,10 +92,6 @@ func (b *buffer) fill(need int) error { // returns next N bytes from buffer. // The returned slice is only guaranteed to be valid until the next read func (b *buffer) readNext(need int) ([]byte, error) { - if need == -1 { - return b.takeCompleteBuffer() - } - if b.length < need { // refill if err := b.fill(need); err != nil { @@ -113,8 +109,12 @@ func (b *buffer) readNext(need int) ([]byte, error) { // If possible, a slice from the existing buffer is returned. // Otherwise a bigger buffer is made. // Only one buffer (total) can be used at a time. -func (b *buffer) takeBuffer(length int) []byte { - if b.length > 0 { //assume its empty +func (b *buffer) reuseBuffer(length int) []byte { + if length == -1 { + return b.takeCompleteBuffer() + } + + if b.length > 0 { return nil } @@ -130,18 +130,6 @@ func (b *buffer) takeBuffer(length int) []byte { return make([]byte, length) } -/* -// shortcut which can be used if the requested buffer is guaranteed to be -// smaller than defaultBufSize -// Only one buffer (total) can be used at a time. -func (b *buffer) takeSmallBuffer(length int) []byte { - if b.length == 0 { //assume its empty - return b.buf[:length] - } - return nil -} -*/ - // takeCompleteBuffer returns the complete existing buffer. // This can be used if the necessary buffer size is unknown. // Only one buffer (total) can be used at a time. diff --git a/compress.go b/compress.go index e4708b63b..6c45ebeeb 100644 --- a/compress.go +++ b/compress.go @@ -11,24 +11,19 @@ const ( ) type compressedReader struct { - buf *buffer //packetReader + buf packetReader bytesBuf []byte mc *mysqlConn zr io.ReadCloser } - -type simpleReader struct { - buf *buffer //packetReader -} - type compressedWriter struct { connWriter io.Writer mc *mysqlConn zw *zlib.Writer } -func newCompressedReader(buf *buffer, mc *mysqlConn) *compressedReader { +func newCompressedReader(buf packetReader, mc *mysqlConn) *compressedReader { return &compressedReader{ buf: buf, bytesBuf: make([]byte, 0), @@ -36,12 +31,6 @@ func newCompressedReader(buf *buffer, mc *mysqlConn) *compressedReader { } } -func newSimpleReader(buf *buffer) *simpleReader { - return &simpleReader{ - buf: buf, - } -} - func newCompressedWriter(connWriter io.Writer, mc *mysqlConn) *compressedWriter { return &compressedWriter{ connWriter: connWriter, @@ -63,8 +52,8 @@ func (cr *compressedReader) readNext(need int) ([]byte, error) { return data, nil } -func (sr *simpleReader) readNext(need int) ([]byte, error) { - return sr.buf.readNext(need) +func (cr *compressedReader) reuseBuffer(length int) []byte { + return cr.buf.reuseBuffer(length) } func (cr *compressedReader) uncompressPacket() error { diff --git a/compress_test.go b/compress_test.go index ce442ed37..0e98599c3 100644 --- a/compress_test.go +++ b/compress_test.go @@ -39,6 +39,10 @@ func (mb *mockBuf) readNext(need int) ([]byte, error) { return data, nil } +func (mb *mockBuf) reuseBuffer(length int) []byte { + return make([]byte, length) //just give them a new buffer +} + // compressHelper compresses uncompressedPacket and checks state variables func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []byte { // get status variables diff --git a/connection.go b/connection.go index 339f59e26..c17e62ca8 100644 --- a/connection.go +++ b/connection.go @@ -54,20 +54,9 @@ type mysqlConn struct { type packetReader interface { readNext(need int) ([]byte, error) + reuseBuffer(length int) []byte } -/* -type packetReadCloser interface{ - Read(n int) ([]byte, error) - Close() error // PROBLEM: is there a way to do this? -} - -type packetWriteCloser interface{ - Write([]byte) (int, error) - Close() error -} -*/ - // Handles parameters set in DSN after the connection is established func (mc *mysqlConn) handleParams() (err error) { for param, val := range mc.cfg.Params { @@ -208,10 +197,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin return "", driver.ErrSkip } - //https://stackoverflow.com/questions/29684609/how-to-check-if-an-object-has-a-particular-method - - //reader has &buf which is a packetreader but also always a buffer - buf, _ := mc.reader.readNext(-1) //PROBLEM uncompressed so this works, what if compressed + buf := mc.reader.reuseBuffer(-1) if buf == nil { // can not take the buffer. Something must be wrong with the connection diff --git a/connection_test.go b/connection_test.go index b85e48260..ac750b574 100644 --- a/connection_test.go +++ b/connection_test.go @@ -21,7 +21,7 @@ func TestInterpolateParams(t *testing.T) { }, } buf := newBuffer(nil) - mc.reader = newSimpleReader(&buf) + mc.reader = &buf q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"}) if err != nil { @@ -42,7 +42,7 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { }, } buf := newBuffer(nil) - mc.reader = newSimpleReader(&buf) + mc.reader = &buf q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42)}) if err != driver.ErrSkip { @@ -61,7 +61,7 @@ func TestInterpolateParamsPlaceholderInString(t *testing.T) { } buf := newBuffer(nil) - mc.reader = newSimpleReader(&buf) + mc.reader = &buf q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)}) // When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42` diff --git a/driver.go b/driver.go index 083542d94..ddde423da 100644 --- a/driver.go +++ b/driver.go @@ -98,10 +98,9 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc.writeTimeout = mc.cfg.WriteTimeout // packet reader and writer in handshake are never compressed - mc.reader = newSimpleReader(&buf) + mc.reader = &buf mc.writer = mc.netConn - // Reading Handshake Initialization Packet cipher, err := mc.readInitPacket() if err != nil { diff --git a/packets.go b/packets.go index f5edc88df..ceab20ef8 100644 --- a/packets.go +++ b/packets.go @@ -28,7 +28,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { var prevData []byte for { // read packet header - data, err := mc.reader.readNext(4) + data, err := mc.reader.readNext(4) if err != nil { if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr @@ -64,7 +64,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } // read packet body [pktLen bytes] - data, err = mc.reader.readNext(pktLen) + data, err = mc.reader.readNext(pktLen) if err != nil { if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr @@ -283,8 +283,8 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { } // Calculate packet length and get buffer with that size - data, _ := mc.reader.readNext(pktLen + 4) - + data := mc.reader.reuseBuffer(pktLen + 4) + if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) @@ -329,10 +329,9 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { mc.netConn = tlsConn nc := tlsConn - // make newBuffer with tls conn, clean slate bc handshake newBuf := newBuffer(nc) - mc.reader = newSimpleReader(&newBuf) - + mc.reader = &newBuf + mc.writer = mc.netConn } @@ -378,7 +377,7 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { // Calculate the packet length and add a tailing 0 pktLen := len(scrambleBuff) + 1 - data, _ := mc.reader.readNext(4 + pktLen) + data := mc.reader.reuseBuffer(4 + pktLen) if data == nil { // can not take the buffer. Something must be wrong with the connection @@ -398,8 +397,8 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { func (mc *mysqlConn) writeClearAuthPacket() error { // Calculate the packet length and add a tailing 0 pktLen := len(mc.cfg.Passwd) + 1 - data, _ := mc.reader.readNext(4 + pktLen) - + data := mc.reader.reuseBuffer(4 + pktLen) + if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) @@ -422,8 +421,8 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { // Calculate the packet length and add a tailing 0 pktLen := len(scrambleBuff) - data, _ := mc.reader.readNext(4 + pktLen) - + data := mc.reader.reuseBuffer(4 + pktLen) + if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) @@ -445,8 +444,8 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { mc.sequence = 0 mc.compressionSequence = 0 - data, _ := mc.reader.readNext(4+1) - + data := mc.reader.reuseBuffer(4 + 1) + if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) @@ -466,8 +465,8 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { mc.compressionSequence = 0 pktLen := 1 + len(arg) - data, _ := mc.reader.readNext(pktLen + 4) - + data := mc.reader.reuseBuffer(pktLen + 4) + if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) @@ -489,8 +488,8 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { mc.sequence = 0 mc.compressionSequence = 0 - data, _ := mc.reader.readNext(4 + 1 + 4) - + data := mc.reader.reuseBuffer(4 + 1 + 4) + if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) @@ -957,9 +956,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { var data []byte if len(args) == 0 { - data, _ = mc.reader.readNext(minPktLen) + data = mc.reader.reuseBuffer(minPktLen) + } else { - data, _ = mc.reader.readNext(-1) //how does this work out with compressed? + data = mc.reader.reuseBuffer(-1) } if data == nil { // can not take the buffer. Something must be wrong with the connection @@ -1138,8 +1138,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // In that case we must build the data packet with the new values buffer if valuesCap != cap(paramValues) { data = append(data[:pos], paramValues...) - readerBuffer := mc.reader.getBuffer() //PROBLEM: what the fuck is going on here?? - readerBuffer.buf = data + + bufBuf := mc.reader.reuseBuffer(-1) + bufBuf = data + fmt.Println(bufBuf) //dont know how to make it compile w/o some op here on bufBuf } pos += len(paramValues) diff --git a/packets_test.go b/packets_test.go index 6d1759c41..2f98d8c24 100644 --- a/packets_test.go +++ b/packets_test.go @@ -91,7 +91,7 @@ func TestReadPacketSingleByte(t *testing.T) { conn := new(mockConn) buf := newBuffer(conn) mc := &mysqlConn{ - reader: newSimpleReader(&buf), + reader: &buf, } conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} @@ -110,9 +110,9 @@ func TestReadPacketSingleByte(t *testing.T) { func TestReadPacketWrongSequenceID(t *testing.T) { conn := new(mockConn) - buf:= newBuffer(conn) + buf := newBuffer(conn) mc := &mysqlConn{ - reader: newSimpleReader(&buf), + reader: &buf, } // too low sequence id @@ -128,7 +128,7 @@ func TestReadPacketWrongSequenceID(t *testing.T) { conn.reads = 0 mc.sequence = 0 newBuf := newBuffer(conn) - mc.reader = newSimpleReader(&newBuf) + mc.reader = &newBuf // too high sequence id conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff} @@ -140,9 +140,9 @@ func TestReadPacketWrongSequenceID(t *testing.T) { func TestReadPacketSplit(t *testing.T) { conn := new(mockConn) - buf:= newBuffer(conn) + buf := newBuffer(conn) mc := &mysqlConn{ - reader : newSimpleReader(&buf), + reader: &buf, } data := make([]byte, maxPacketSize*2+4*3) @@ -248,7 +248,7 @@ func TestReadPacketFail(t *testing.T) { conn := new(mockConn) buf := newBuffer(conn) mc := &mysqlConn{ - reader: newSimpleReader(&buf), + reader: &buf, closech: make(chan struct{}), } @@ -264,7 +264,7 @@ func TestReadPacketFail(t *testing.T) { conn.reads = 0 mc.sequence = 0 newBuf := newBuffer(conn) - mc.reader = newSimpleReader(&newBuf) + mc.reader = &newBuf // fail to read header conn.closed = true @@ -278,7 +278,7 @@ func TestReadPacketFail(t *testing.T) { conn.reads = 0 mc.sequence = 0 newBuf = newBuffer(conn) - mc.reader = newSimpleReader(&newBuf) + mc.reader = &newBuf // fail to read body conn.maxReads = 1 From f617170b6964683d9a60f78e4ef8b1f2658c3dc3 Mon Sep 17 00:00:00 2001 From: Brigitte Lamarche Date: Mon, 8 Oct 2018 10:10:31 -0400 Subject: [PATCH 33/33] removed comment --- benchmark_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmark_test.go b/benchmark_test.go index 498627dee..af7b3971c 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -205,7 +205,7 @@ func BenchmarkRoundtripBin(b *testing.B) { length = max } test := sample[0:length] - rows := tb.checkRows(stmt.Query(test)) //run benchmark tests to test that bit of code + rows := tb.checkRows(stmt.Query(test)) if !rows.Next() { rows.Close() b.Fatalf("crashed")