diff --git a/client/req.go b/client/req.go index dde03e711..87ad50606 100644 --- a/client/req.go +++ b/client/req.go @@ -1,5 +1,9 @@ package client +import ( + "github.com/siddontang/go-mysql/utils" +) + func (c *Conn) writeCommand(command byte) error { c.ResetSequence() @@ -16,28 +20,20 @@ func (c *Conn) writeCommandBuf(command byte, arg []byte) error { c.ResetSequence() length := len(arg) + 1 - - data := make([]byte, length+4) - + data := utils.ByteSliceGet(length + 4) data[4] = command copy(data[5:], arg) - return c.WritePacket(data) -} - -func (c *Conn) writeCommandStr(command byte, arg string) error { - c.ResetSequence() - - length := len(arg) + 1 - - data := make([]byte, length+4) + err := c.WritePacket(data) - data[4] = command + utils.ByteSlicePut(data) - copy(data[5:], arg) + return err +} - return c.WritePacket(data) +func (c *Conn) writeCommandStr(command byte, arg string) error { + return c.writeCommandBuf(command, utils.StringToByteSlice(arg)) } func (c *Conn) writeCommandUint32(command byte, arg uint32) error { diff --git a/client/resp.go b/client/resp.go index f6a445201..712f06927 100644 --- a/client/resp.go +++ b/client/resp.go @@ -1,15 +1,15 @@ package client import ( - "encoding/binary" - "bytes" "crypto/rsa" "crypto/x509" + "encoding/binary" "encoding/pem" "github.com/pingcap/errors" . "github.com/siddontang/go-mysql/mysql" + "github.com/siddontang/go-mysql/utils" "github.com/siddontang/go/hack" ) @@ -212,31 +212,25 @@ func (c *Conn) readOK() (*Result, error) { } func (c *Conn) readResult(binary bool) (*Result, error) { - data, err := c.ReadPacket() + firstPkgBuf, err := c.ReadPacketReuseMem(utils.ByteSliceGet(16)[:0]) + defer utils.ByteSlicePut(firstPkgBuf) + if err != nil { return nil, errors.Trace(err) } - if data[0] == OK_HEADER { - return c.handleOKPacket(data) - } else if data[0] == ERR_HEADER { - return nil, c.handleErrorPacket(data) - } else if data[0] == LocalInFile_HEADER { + if firstPkgBuf[0] == OK_HEADER { + return c.handleOKPacket(firstPkgBuf) + } else if firstPkgBuf[0] == ERR_HEADER { + return nil, c.handleErrorPacket(append([]byte{}, firstPkgBuf...)) + } else if firstPkgBuf[0] == LocalInFile_HEADER { return nil, ErrMalformPacket } - return c.readResultset(data, binary) + return c.readResultset(firstPkgBuf, binary) } func (c *Conn) readResultset(data []byte, binary bool) (*Result, error) { - result := &Result{ - Status: 0, - InsertId: 0, - AffectedRows: 0, - - Resultset: &Resultset{}, - } - // column count count, _, n := LengthEncodedInt(data) @@ -244,8 +238,9 @@ func (c *Conn) readResultset(data []byte, binary bool) (*Result, error) { return nil, ErrMalformPacket } - result.Fields = make([]*Field, count) - result.FieldNames = make(map[string]int, count) + result := &Result{ + Resultset: NewResultset(int(count)), + } if err := c.readResultColumns(result); err != nil { return nil, errors.Trace(err) @@ -263,10 +258,12 @@ func (c *Conn) readResultColumns(result *Result) (err error) { var data []byte for { - data, err = c.ReadPacket() + rawPkgLen := len(result.RawPkg) + result.RawPkg, err = c.ReadPacketReuseMem(result.RawPkg) if err != nil { return } + data = result.RawPkg[rawPkgLen:] // EOF Packet if c.isEOFPacket(data) { @@ -284,7 +281,10 @@ func (c *Conn) readResultColumns(result *Result) (err error) { return } - result.Fields[i], err = FieldData(data).Parse() + if result.Fields[i] == nil { + result.Fields[i] = &Field{} + } + err = result.Fields[i].Parse(data) if err != nil { return } @@ -299,11 +299,12 @@ func (c *Conn) readResultRows(result *Result, isBinary bool) (err error) { var data []byte for { - data, err = c.ReadPacket() - + rawPkgLen := len(result.RawPkg) + result.RawPkg, err = c.ReadPacketReuseMem(result.RawPkg) if err != nil { return } + data = result.RawPkg[rawPkgLen:] // EOF Packet if c.isEOFPacket(data) { @@ -324,10 +325,14 @@ func (c *Conn) readResultRows(result *Result, isBinary bool) (err error) { result.RowDatas = append(result.RowDatas, data) } - result.Values = make([][]interface{}, len(result.RowDatas)) + if cap(result.Values) < len(result.RowDatas) { + result.Values = make([][]FieldValue, len(result.RowDatas)) + } else { + result.Values = result.Values[:len(result.RowDatas)] + } for i := range result.Values { - result.Values[i], err = result.RowDatas[i].Parse(result.Fields, isBinary) + result.Values[i], err = result.RowDatas[i].Parse(result.Fields, isBinary, result.Values[i]) if err != nil { return errors.Trace(err) diff --git a/driver/dirver_test.go b/driver/driver_test.go similarity index 100% rename from driver/dirver_test.go rename to driver/driver_test.go diff --git a/mysql/field.go b/mysql/field.go index 891f00b15..84b36067e 100644 --- a/mysql/field.go +++ b/mysql/field.go @@ -2,6 +2,8 @@ package mysql import ( "encoding/binary" + + "github.com/siddontang/go-mysql/utils" ) type FieldData []byte @@ -23,9 +25,23 @@ type Field struct { DefaultValue []byte } -func (p FieldData) Parse() (f *Field, err error) { - f = new(Field) +type FieldValueType uint8 + +type FieldValue struct { + Type FieldValueType + value uint64 // Also for int64 and float64 + str []byte +} + +const ( + FieldValueTypeNull = iota + FieldValueTypeUnsigned + FieldValueTypeSigned + FieldValueTypeFloat + FieldValueTypeString +) +func (f *Field) Parse(p FieldData) (err error) { f.Data = p var n int @@ -117,6 +133,14 @@ func (p FieldData) Parse() (f *Field, err error) { return } +func (p FieldData) Parse() (f *Field, err error) { + f = new(Field) + if err = f.Parse(p); err != nil { + return nil, err + } + return f, nil +} + func (f *Field) Dump() []byte { if f == nil { f = &Field{} @@ -155,3 +179,34 @@ func (f *Field) Dump() []byte { return data } + +func (fv *FieldValue) AsUint64() uint64 { + return fv.value +} + +func (fv *FieldValue) AsInt64() int64 { + return utils.Uint64ToInt64(fv.value) +} + +func (fv *FieldValue) AsFloat64() float64 { + return utils.Uint64ToFloat64(fv.value) +} + +func (fv *FieldValue) AsString() []byte { + return fv.str +} + +func (fv *FieldValue) Value() interface{} { + switch fv.Type { + case FieldValueTypeUnsigned: + return fv.AsUint64() + case FieldValueTypeSigned: + return fv.AsInt64() + case FieldValueTypeFloat: + return fv.AsFloat64() + case FieldValueTypeString: + return fv.AsString() + default: // FieldValueTypeNull + return nil + } +} diff --git a/mysql/result.go b/mysql/result.go index d6c80e422..797a4af75 100644 --- a/mysql/result.go +++ b/mysql/result.go @@ -12,3 +12,10 @@ type Result struct { type Executer interface { Execute(query string, args ...interface{}) (*Result, error) } + +func (r *Result) Close() { + if r.Resultset != nil { + r.Resultset.returnToPool() + r.Resultset = nil + } +} diff --git a/mysql/resultset.go b/mysql/resultset.go index 16fc70ac4..2da6e3b5f 100644 --- a/mysql/resultset.go +++ b/mysql/resultset.go @@ -3,227 +3,64 @@ package mysql import ( "fmt" "strconv" + "sync" "github.com/pingcap/errors" "github.com/siddontang/go/hack" ) -type RowData []byte - -func (p RowData) Parse(f []*Field, binary bool) ([]interface{}, error) { - if binary { - return p.ParseBinary(f) - } else { - return p.ParseText(f) - } -} - -func (p RowData) ParseText(f []*Field) ([]interface{}, error) { - data := make([]interface{}, len(f)) +type Resultset struct { + Fields []*Field + FieldNames map[string]int + Values [][]FieldValue - var err error - var v []byte - var isNull bool - var pos int = 0 - var n int = 0 + RawPkg []byte - for i := range f { - v, isNull, n, err = LengthEncodedString(p[pos:]) - if err != nil { - return nil, errors.Trace(err) - } + RowDatas []RowData +} - pos += n - - if isNull { - data[i] = nil - } else { - isUnsigned := f[i].Flag&UNSIGNED_FLAG != 0 - - switch f[i].Type { - case MYSQL_TYPE_TINY, MYSQL_TYPE_SHORT, MYSQL_TYPE_INT24, - MYSQL_TYPE_LONGLONG, MYSQL_TYPE_LONG, MYSQL_TYPE_YEAR: - if isUnsigned { - data[i], err = strconv.ParseUint(string(v), 10, 64) - } else { - data[i], err = strconv.ParseInt(string(v), 10, 64) - } - case MYSQL_TYPE_FLOAT, MYSQL_TYPE_DOUBLE: - data[i], err = strconv.ParseFloat(string(v), 64) - default: - data[i] = v - } - - if err != nil { - return nil, errors.Trace(err) - } - } +var ( + resultsetPool = sync.Pool{ + New: func() interface{} { + return &Resultset{} + }, } +) - return data, nil +func NewResultset(resultsetCount int) *Resultset { + r := resultsetPool.Get().(*Resultset) + r.reset(resultsetCount) + return r } -// ParseBinary parses the binary format of data -// see https://dev.mysql.com/doc/internals/en/binary-protocol-value.html -func (p RowData) ParseBinary(f []*Field) ([]interface{}, error) { - data := make([]interface{}, len(f)) - - if p[0] != OK_HEADER { - return nil, ErrMalformPacket - } - - pos := 1 + ((len(f) + 7 + 2) >> 3) +func (r *Resultset) returnToPool() { + resultsetPool.Put(r) +} - nullBitmap := p[1:pos] +func (r *Resultset) reset(count int) { + r.RawPkg = r.RawPkg[:0] - var isNull bool - var n int - var err error - var v []byte - for i := range data { - if nullBitmap[(i+2)/8]&(1<<(uint(i+2)%8)) > 0 { - data[i] = nil - continue - } + r.Fields = r.Fields[:0] + r.Values = r.Values[:0] + r.RowDatas = r.RowDatas[:0] - isUnsigned := f[i].Flag&UNSIGNED_FLAG != 0 - - switch f[i].Type { - case MYSQL_TYPE_NULL: - data[i] = nil - continue - - case MYSQL_TYPE_TINY: - if isUnsigned { - data[i] = ParseBinaryUint8(p[pos : pos+1]) - } else { - data[i] = ParseBinaryInt8(p[pos : pos+1]) - } - pos++ - continue - - case MYSQL_TYPE_SHORT, MYSQL_TYPE_YEAR: - if isUnsigned { - data[i] = ParseBinaryUint16(p[pos : pos+2]) - } else { - data[i] = ParseBinaryInt16(p[pos : pos+2]) - } - pos += 2 - continue - - case MYSQL_TYPE_INT24, MYSQL_TYPE_LONG: - if isUnsigned { - data[i] = ParseBinaryUint32(p[pos : pos+4]) - } else { - data[i] = ParseBinaryInt32(p[pos : pos+4]) - } - pos += 4 - continue - - case MYSQL_TYPE_LONGLONG: - if isUnsigned { - data[i] = ParseBinaryUint64(p[pos : pos+8]) - } else { - data[i] = ParseBinaryInt64(p[pos : pos+8]) - } - pos += 8 - continue - - case MYSQL_TYPE_FLOAT: - data[i] = ParseBinaryFloat32(p[pos : pos+4]) - pos += 4 - continue - - case MYSQL_TYPE_DOUBLE: - data[i] = ParseBinaryFloat64(p[pos : pos+8]) - pos += 8 - continue - - case MYSQL_TYPE_DECIMAL, MYSQL_TYPE_NEWDECIMAL, MYSQL_TYPE_VARCHAR, - MYSQL_TYPE_BIT, MYSQL_TYPE_ENUM, MYSQL_TYPE_SET, MYSQL_TYPE_TINY_BLOB, - MYSQL_TYPE_MEDIUM_BLOB, MYSQL_TYPE_LONG_BLOB, MYSQL_TYPE_BLOB, - MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_STRING, MYSQL_TYPE_GEOMETRY: - v, isNull, n, err = LengthEncodedString(p[pos:]) - pos += n - if err != nil { - return nil, errors.Trace(err) - } - - if !isNull { - data[i] = v - continue - } else { - data[i] = nil - continue - } - case MYSQL_TYPE_DATE, MYSQL_TYPE_NEWDATE: - var num uint64 - num, isNull, n = LengthEncodedInt(p[pos:]) - - pos += n - - if isNull { - data[i] = nil - continue - } - - data[i], err = FormatBinaryDate(int(num), p[pos:]) - pos += int(num) - - if err != nil { - return nil, errors.Trace(err) - } - - case MYSQL_TYPE_TIMESTAMP, MYSQL_TYPE_DATETIME: - var num uint64 - num, isNull, n = LengthEncodedInt(p[pos:]) - - pos += n - - if isNull { - data[i] = nil - continue - } - - data[i], err = FormatBinaryDateTime(int(num), p[pos:]) - pos += int(num) - - if err != nil { - return nil, errors.Trace(err) - } - - case MYSQL_TYPE_TIME: - var num uint64 - num, isNull, n = LengthEncodedInt(p[pos:]) - - pos += n - - if isNull { - data[i] = nil - continue - } - - data[i], err = FormatBinaryTime(int(num), p[pos:]) - pos += int(num) - - if err != nil { - return nil, errors.Trace(err) - } - - default: - return nil, errors.Errorf("Stmt Unknown FieldType %d %s", f[i].Type, f[i].Name) + if r.FieldNames != nil { + for k := range r.FieldNames { + delete(r.FieldNames, k) } + } else { + r.FieldNames = make(map[string]int) } - return data, nil -} - -type Resultset struct { - Fields []*Field - FieldNames map[string]int - Values [][]interface{} + if count == 0 { + return + } - RowDatas []RowData + if cap(r.Fields) < count { + r.Fields = make([]*Field, count) + } else { + r.Fields = r.Fields[:count] + } } func (r *Resultset) RowNumber() int { @@ -243,7 +80,7 @@ func (r *Resultset) GetValue(row, column int) (interface{}, error) { return nil, errors.Errorf("invalid column index %d", column) } - return r.Values[row][column], nil + return r.Values[row][column].Value(), nil } func (r *Resultset) NameIndex(name string) (int, error) { diff --git a/mysql/rowdata.go b/mysql/rowdata.go new file mode 100644 index 000000000..f5a493f1b --- /dev/null +++ b/mysql/rowdata.go @@ -0,0 +1,259 @@ +package mysql + +import ( + "strconv" + + "github.com/pingcap/errors" + "github.com/siddontang/go-mysql/utils" +) + +type RowData []byte + +func (p RowData) Parse(f []*Field, binary bool, dst []FieldValue) ([]FieldValue, error) { + if binary { + return p.ParseBinary(f, dst) + } else { + return p.ParseText(f, dst) + } +} + +func (p RowData) ParseText(f []*Field, dst []FieldValue) ([]FieldValue, error) { + for len(dst) < len(f) { + dst = append(dst, FieldValue{}) + } + data := dst[:len(f)] + + var err error + var v []byte + var isNull bool + var pos int = 0 + var n int = 0 + + for i := range f { + v, isNull, n, err = LengthEncodedString(p[pos:]) + if err != nil { + return nil, errors.Trace(err) + } + + pos += n + + if isNull { + data[i].Type = FieldValueTypeNull + } else { + isUnsigned := f[i].Flag&UNSIGNED_FLAG != 0 + + switch f[i].Type { + case MYSQL_TYPE_TINY, MYSQL_TYPE_SHORT, MYSQL_TYPE_INT24, + MYSQL_TYPE_LONGLONG, MYSQL_TYPE_LONG, MYSQL_TYPE_YEAR: + if isUnsigned { + var val uint64 + data[i].Type = FieldValueTypeUnsigned + val, err = strconv.ParseUint(utils.ByteSliceToString(v), 10, 64) + data[i].value = val + } else { + var val int64 + data[i].Type = FieldValueTypeSigned + val, err = strconv.ParseInt(utils.ByteSliceToString(v), 10, 64) + data[i].value = utils.Int64ToUint64(val) + } + case MYSQL_TYPE_FLOAT, MYSQL_TYPE_DOUBLE: + var val float64 + data[i].Type = FieldValueTypeFloat + val, err = strconv.ParseFloat(utils.ByteSliceToString(v), 64) + data[i].value = utils.Float64ToUint64(val) + default: + data[i].Type = FieldValueTypeString + data[i].str = append(data[i].str[:0], v...) + } + + if err != nil { + return nil, errors.Trace(err) + } + } + } + + return data, nil +} + +// ParseBinary parses the binary format of data +// see https://dev.mysql.com/doc/internals/en/binary-protocol-value.html +func (p RowData) ParseBinary(f []*Field, dst []FieldValue) ([]FieldValue, error) { + for len(dst) < len(f) { + dst = append(dst, FieldValue{}) + } + data := dst[:len(f)] + + if p[0] != OK_HEADER { + return nil, ErrMalformPacket + } + + pos := 1 + ((len(f) + 7 + 2) >> 3) + + nullBitmap := p[1:pos] + + var isNull bool + var n int + var err error + var v []byte + for i := range data { + if nullBitmap[(i+2)/8]&(1<<(uint(i+2)%8)) > 0 { + data[i].Type = FieldValueTypeNull + continue + } + + isUnsigned := f[i].Flag&UNSIGNED_FLAG != 0 + + switch f[i].Type { + case MYSQL_TYPE_NULL: + data[i].Type = FieldValueTypeNull + continue + + case MYSQL_TYPE_TINY: + if isUnsigned { + v := ParseBinaryUint8(p[pos : pos+1]) + data[i].Type = FieldValueTypeUnsigned + data[i].value = uint64(v) + } else { + v := ParseBinaryInt8(p[pos : pos+1]) + data[i].Type = FieldValueTypeSigned + data[i].value = utils.Int64ToUint64(int64(v)) + } + pos++ + continue + + case MYSQL_TYPE_SHORT, MYSQL_TYPE_YEAR: + if isUnsigned { + v := ParseBinaryUint16(p[pos : pos+2]) + data[i].Type = FieldValueTypeUnsigned + data[i].value = uint64(v) + } else { + v := ParseBinaryInt16(p[pos : pos+2]) + data[i].Type = FieldValueTypeSigned + data[i].value = utils.Int64ToUint64(int64(v)) + } + pos += 2 + continue + + case MYSQL_TYPE_INT24, MYSQL_TYPE_LONG: + if isUnsigned { + v := ParseBinaryUint32(p[pos : pos+4]) + data[i].Type = FieldValueTypeUnsigned + data[i].value = uint64(v) + } else { + v := ParseBinaryInt32(p[pos : pos+4]) + data[i].Type = FieldValueTypeSigned + data[i].value = utils.Int64ToUint64(int64(v)) + } + pos += 4 + continue + + case MYSQL_TYPE_LONGLONG: + if isUnsigned { + v := ParseBinaryUint64(p[pos : pos+8]) + data[i].Type = FieldValueTypeUnsigned + data[i].value = v + } else { + v := ParseBinaryInt64(p[pos : pos+8]) + data[i].Type = FieldValueTypeSigned + data[i].value = utils.Int64ToUint64(v) + } + pos += 8 + continue + + case MYSQL_TYPE_FLOAT: + v := ParseBinaryFloat32(p[pos : pos+4]) + data[i].Type = FieldValueTypeFloat + data[i].value = utils.Float64ToUint64(float64(v)) + pos += 4 + continue + + case MYSQL_TYPE_DOUBLE: + v := ParseBinaryFloat64(p[pos : pos+8]) + data[i].Type = FieldValueTypeFloat + data[i].value = utils.Float64ToUint64(v) + pos += 8 + continue + + case MYSQL_TYPE_DECIMAL, MYSQL_TYPE_NEWDECIMAL, MYSQL_TYPE_VARCHAR, + MYSQL_TYPE_BIT, MYSQL_TYPE_ENUM, MYSQL_TYPE_SET, MYSQL_TYPE_TINY_BLOB, + MYSQL_TYPE_MEDIUM_BLOB, MYSQL_TYPE_LONG_BLOB, MYSQL_TYPE_BLOB, + MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_STRING, MYSQL_TYPE_GEOMETRY: + v, isNull, n, err = LengthEncodedString(p[pos:]) + pos += n + if err != nil { + return nil, errors.Trace(err) + } + + if !isNull { + data[i].Type = FieldValueTypeString + data[i].str = append(data[i].str[:0], v...) + continue + } else { + data[i].Type = FieldValueTypeNull + continue + } + + case MYSQL_TYPE_DATE, MYSQL_TYPE_NEWDATE: + var num uint64 + num, isNull, n = LengthEncodedInt(p[pos:]) + + pos += n + + if isNull { + data[i].Type = FieldValueTypeNull + continue + } + + data[i].Type = FieldValueTypeString + data[i].str, err = FormatBinaryDate(int(num), p[pos:]) + pos += int(num) + + if err != nil { + return nil, errors.Trace(err) + } + + case MYSQL_TYPE_TIMESTAMP, MYSQL_TYPE_DATETIME: + var num uint64 + num, isNull, n = LengthEncodedInt(p[pos:]) + + pos += n + + if isNull { + data[i].Type = FieldValueTypeNull + continue + } + + data[i].Type = FieldValueTypeString + data[i].str, err = FormatBinaryDateTime(int(num), p[pos:]) + pos += int(num) + + if err != nil { + return nil, errors.Trace(err) + } + + case MYSQL_TYPE_TIME: + var num uint64 + num, isNull, n = LengthEncodedInt(p[pos:]) + + pos += n + + if isNull { + data[i].Type = FieldValueTypeNull + continue + } + + data[i].Type = FieldValueTypeString + data[i].str, err = FormatBinaryTime(int(num), p[pos:]) + pos += int(num) + + if err != nil { + return nil, errors.Trace(err) + } + + default: + return nil, errors.Errorf("Stmt Unknown FieldType %d %s", f[i].Type, f[i].Name) + } + } + + return data, nil +} diff --git a/mysql/util.go b/mysql/util.go index 5ab653227..b9e0501e7 100644 --- a/mysql/util.go +++ b/mysql/util.go @@ -2,15 +2,15 @@ package mysql import ( "crypto/rand" + "crypto/rsa" "crypto/sha1" + "crypto/sha256" "encoding/binary" "fmt" "io" "runtime" "strings" - "crypto/rsa" - "crypto/sha256" "github.com/pingcap/errors" "github.com/siddontang/go/hack" ) diff --git a/packet/conn.go b/packet/conn.go index 02373ad03..13ed29ed1 100644 --- a/packet/conn.go +++ b/packet/conn.go @@ -15,6 +15,7 @@ import ( "github.com/pingcap/errors" . "github.com/siddontang/go-mysql/mysql" + "github.com/siddontang/go-mysql/utils" ) type BufPool struct { @@ -53,6 +54,10 @@ type Conn struct { br *bufio.Reader reader io.Reader + copyNBuf []byte + + header [4]byte + Sequence uint8 } @@ -64,6 +69,8 @@ func NewConn(conn net.Conn) *Conn { c.br = bufio.NewReaderSize(c, 65536) // 64kb c.reader = c.br + c.copyNBuf = make([]byte, 16*1024) + return c } @@ -74,31 +81,60 @@ func NewTLSConn(conn net.Conn) *Conn { c.bufPool = NewBufPool() c.reader = c + c.copyNBuf = make([]byte, 16*1024) + return c } func (c *Conn) ReadPacket() ([]byte, error) { + return c.ReadPacketReuseMem(nil) +} + +func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) { // Here we use `sync.Pool` to avoid allocate/destroy buffers frequently. - buf := c.bufPool.Get() - defer c.bufPool.Return(buf) + buf := utils.BytesBufferGet() + defer utils.BytesBufferPut(buf) if err := c.ReadPacketTo(buf); err != nil { return nil, errors.Trace(err) } else { - result := append([]byte{}, buf.Bytes()...) + result := append(dst, buf.Bytes()...) return result, nil } } -func (c *Conn) ReadPacketTo(w io.Writer) error { - header := []byte{0, 0, 0, 0} +func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err error) { + for n > 0 { + bcap := cap(c.copyNBuf) + if int64(bcap) > n { + bcap = int(n) + } + buf := c.copyNBuf[:bcap] + + rd, err := io.ReadAtLeast(src, buf, bcap) + n -= int64(rd) - if _, err := io.ReadFull(c.reader, header); err != nil { + if err != nil { + return written, errors.Trace(err) + } + + wr, err := dst.Write(buf) + written += int64(wr) + if err != nil { + return written, errors.Trace(err) + } + } + + return written, nil +} + +func (c *Conn) ReadPacketTo(w io.Writer) error { + if _, err := io.ReadFull(c.reader, c.header[:4]); err != nil { return errors.Wrapf(ErrBadConn, "io.ReadFull(header) failed. err %v", err) } - length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) - sequence := uint8(header[3]) + length := int(uint32(c.header[0]) | uint32(c.header[1])<<8 | uint32(c.header[2])<<16) + sequence := uint8(c.header[3]) if sequence != c.Sequence { return errors.Errorf("invalid sequence %d != %d", sequence, c.Sequence) @@ -111,7 +147,7 @@ func (c *Conn) ReadPacketTo(w io.Writer) error { buf.Grow(length) } - if n, err := io.CopyN(w, c.reader, int64(length)); err != nil { + if n, err := c.copyN(w, c.reader, int64(length)); err != nil { return errors.Wrapf(ErrBadConn, "io.CopyN failed. err %v, copied %v, expected %v", err, n, length) } else if n != int64(length) { return errors.Wrapf(ErrBadConn, "io.CopyN failed(n != int64(length)). %v bytes copied, while %v expected", n, length) diff --git a/utils/byte_slice_pool.go b/utils/byte_slice_pool.go new file mode 100644 index 000000000..be0f1a21d --- /dev/null +++ b/utils/byte_slice_pool.go @@ -0,0 +1,36 @@ +package utils + +import "sync" + +var ( + byteSlicePool = sync.Pool{ + New: func() interface{} { + return []byte{} + }, + } + byteSliceChan = make(chan []byte, 10) +) + +func ByteSliceGet(length int) (data []byte) { + select { + case data = <-byteSliceChan: + default: + data = byteSlicePool.Get().([]byte)[:0] + } + + if cap(data) < length { + data = make([]byte, length) + } else { + data = data[:length] + } + + return data +} + +func ByteSlicePut(data []byte) { + select { + case byteSliceChan <- data: + default: + byteSlicePool.Put(data) + } +} diff --git a/utils/bytes_buffer_pool.go b/utils/bytes_buffer_pool.go new file mode 100644 index 000000000..a1ca8707d --- /dev/null +++ b/utils/bytes_buffer_pool.go @@ -0,0 +1,35 @@ +package utils + +import ( + "bytes" + "sync" +) + +var ( + bytesBufferPool = sync.Pool{ + New: func() interface{} { + return &bytes.Buffer{} + }, + } + bytesBufferChan = make(chan *bytes.Buffer, 10) +) + +func BytesBufferGet() (data *bytes.Buffer) { + select { + case data = <-bytesBufferChan: + default: + data = bytesBufferPool.Get().(*bytes.Buffer) + } + + data.Reset() + + return data +} + +func BytesBufferPut(data *bytes.Buffer) { + select { + case bytesBufferChan <- data: + default: + bytesBufferPool.Put(data) + } +} diff --git a/utils/zeroalloc.go b/utils/zeroalloc.go new file mode 100644 index 000000000..ca3798c0d --- /dev/null +++ b/utils/zeroalloc.go @@ -0,0 +1,27 @@ +package utils + +import "unsafe" + +func StringToByteSlice(s string) []byte { + return *(*[]byte)(unsafe.Pointer(&s)) +} + +func ByteSliceToString(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + +func Uint64ToInt64(val uint64) int64 { + return *(*int64)(unsafe.Pointer(&val)) +} + +func Uint64ToFloat64(val uint64) float64 { + return *(*float64)(unsafe.Pointer(&val)) +} + +func Int64ToUint64(val int64) uint64 { + return *(*uint64)(unsafe.Pointer(&val)) +} + +func Float64ToUint64(val float64) uint64 { + return *(*uint64)(unsafe.Pointer(&val)) +}