diff --git a/README.md b/README.md index 5e4ca30ea..393a00088 100644 --- a/README.md +++ b/README.md @@ -237,8 +237,8 @@ err := conn.ExecuteSelectStreaming(`select id, name from table LIMIT 100500`, &r // Copy it if you need. // ... } - return false, nil -}) + return nil +}, nil) // ... ``` diff --git a/client/conn.go b/client/conn.go index 02c10cffc..7f422f4df 100644 --- a/client/conn.go +++ b/client/conn.go @@ -39,6 +39,9 @@ type Conn struct { // This function will be called for every row in resultset from ExecuteSelectStreaming. type SelectPerRowCallback func(row []FieldValue) error +// This function will be called once per result from ExecuteSelectStreaming +type SelectPerResultCallback func(result *Result) error + func getNetProto(addr string) string { proto := "tcp" if strings.Contains(addr, "/") { @@ -183,6 +186,7 @@ func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) { // ExecuteSelectStreaming will call perRowCallback for every row in resultset // WITHOUT saving any row data to Result.{Values/RawPkg/RowDatas} fields. +// When given, perResultCallback will be called once per result // // ExecuteSelectStreaming should be used only for SELECT queries with a large response resultset for memory preserving. // @@ -193,14 +197,14 @@ func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) { // // Use the row as you want. // // You must not save FieldValue.AsString() value after this callback is done. Copy it if you need. // return nil -// }) +// }, nil) // -func (c *Conn) ExecuteSelectStreaming(command string, result *Result, perRowCallback SelectPerRowCallback) error { +func (c *Conn) ExecuteSelectStreaming(command string, result *Result, perRowCallback SelectPerRowCallback, perResultCallback SelectPerResultCallback) error { if err := c.writeCommandStr(COM_QUERY, command); err != nil { return errors.Trace(err) } - return c.readResultStreaming(false, result, perRowCallback) + return c.readResultStreaming(false, result, perRowCallback, perResultCallback) } func (c *Conn) Begin() error { diff --git a/client/resp.go b/client/resp.go index 06abfdb85..e231805d1 100644 --- a/client/resp.go +++ b/client/resp.go @@ -233,7 +233,7 @@ func (c *Conn) readResult(binary bool) (*Result, error) { return c.readResultset(firstPkgBuf, binary) } -func (c *Conn) readResultStreaming(binary bool, result *Result, perRowCb SelectPerRowCallback) error { +func (c *Conn) readResultStreaming(binary bool, result *Result, perRowCb SelectPerRowCallback, perResCb SelectPerResultCallback) error { firstPkgBuf, err := c.ReadPacketReuseMem(utils.ByteSliceGet(16)[:0]) defer utils.ByteSlicePut(firstPkgBuf) @@ -267,7 +267,7 @@ func (c *Conn) readResultStreaming(binary bool, result *Result, perRowCb SelectP return ErrMalformPacket } - return c.readResultsetStreaming(firstPkgBuf, binary, result, perRowCb) + return c.readResultsetStreaming(firstPkgBuf, binary, result, perRowCb, perResCb) } func (c *Conn) readResultset(data []byte, binary bool) (*Result, error) { @@ -293,7 +293,7 @@ func (c *Conn) readResultset(data []byte, binary bool) (*Result, error) { return result, nil } -func (c *Conn) readResultsetStreaming(data []byte, binary bool, result *Result, perRowCb SelectPerRowCallback) error { +func (c *Conn) readResultsetStreaming(data []byte, binary bool, result *Result, perRowCb SelectPerRowCallback, perResCb SelectPerResultCallback) error { columnCount, _, n := LengthEncodedInt(data) if n-len(data) != 0 { @@ -307,14 +307,26 @@ func (c *Conn) readResultsetStreaming(data []byte, binary bool, result *Result, result.Reset(int(columnCount)) } + // this is a streaming resultset + result.Resultset.Streaming = true + if err := c.readResultColumns(result); err != nil { return errors.Trace(err) } + if perResCb != nil { + if err := perResCb(result); err != nil { + return err + } + } + if err := c.readResultRowsStreaming(result, binary, perRowCb); err != nil { return errors.Trace(err) } + // this resultset is done streaming + result.Resultset.StreamingDone = true + return nil } diff --git a/client/stmt.go b/client/stmt.go index e0d5f1d30..c9f4a1754 100644 --- a/client/stmt.go +++ b/client/stmt.go @@ -39,6 +39,14 @@ func (s *Stmt) Execute(args ...interface{}) (*Result, error) { return s.conn.readResult(true) } +func (s *Stmt) ExecuteSelectStreaming(result *Result, perRowCb SelectPerRowCallback, perResCb SelectPerResultCallback, args ...interface{}) error { + if err := s.write(args...); err != nil { + return errors.Trace(err) + } + + return s.conn.readResultStreaming(true, result, perRowCb, perResCb) +} + func (s *Stmt) Close() error { if err := s.conn.writeCommandUint32(COM_STMT_CLOSE, s.id); err != nil { return errors.Trace(err) diff --git a/mysql/resultset.go b/mysql/resultset.go index f244b7d06..6a4df6243 100644 --- a/mysql/resultset.go +++ b/mysql/resultset.go @@ -17,6 +17,9 @@ type Resultset struct { RawPkg []byte RowDatas []RowData + + Streaming bool + StreamingDone bool } var ( diff --git a/mysql/resultset_helper.go b/mysql/resultset_helper.go index 3c22c9c4f..0cc859d77 100644 --- a/mysql/resultset_helper.go +++ b/mysql/resultset_helper.go @@ -8,7 +8,7 @@ import ( "github.com/siddontang/go/hack" ) -func formatTextValue(value interface{}) ([]byte, error) { +func FormatTextValue(value interface{}) ([]byte, error) { switch v := value.(type) { case int8: return strconv.AppendInt(nil, int64(v), 10), nil @@ -165,7 +165,7 @@ func BuildSimpleTextResultset(names []string, values [][]interface{}) (*Resultse return nil, errors.Errorf("row types aren't consistent") } } - b, err = formatTextValue(value) + b, err = FormatTextValue(value) if err != nil { return nil, errors.Trace(err) diff --git a/server/command.go b/server/command.go index 69e0bf39c..3edc60fd4 100644 --- a/server/command.go +++ b/server/command.go @@ -44,7 +44,7 @@ func (c *Conn) HandleCommand() error { v := c.dispatch(data) - err = c.writeValue(v) + err = c.WriteValue(v) if c.Conn != nil { c.ResetSequence() diff --git a/server/resp.go b/server/resp.go index d6134b978..63c64bdea 100644 --- a/server/resp.go +++ b/server/resp.go @@ -116,6 +116,13 @@ func (c *Conn) writeAuthMoreDataFastAuth() error { } func (c *Conn) writeResultset(r *Resultset) error { + // for a streaming resultset, that handled rowdata separately in a callback + // of type SelectPerRowCallback, we can suffice by ending the stream with + // an EOF + if r.StreamingDone { + return c.writeEOF() + } + columnLen := PutLengthEncodedInt(uint64(len(r.Fields))) data := make([]byte, 4, 1024) @@ -129,6 +136,12 @@ func (c *Conn) writeResultset(r *Resultset) error { return err } + // streaming resultsets handle rowdata in a separate callback of type + // SelectPerRowCallback so we're done here + if r.Streaming { + return nil + } + for _, v := range r.RowDatas { data = data[0:4] data = append(data, v...) @@ -163,10 +176,23 @@ func (c *Conn) writeFieldList(fs []*Field, data []byte) error { return nil } +func (c *Conn) writeFieldValues(fv []FieldValue) error { + data := make([]byte, 4, 1024) + for _, v := range fv { + tv, err := FormatTextValue(v.Value()) + if err != nil { + return err + } + data = append(data, PutLengthEncodedString(tv)...) + } + + return c.WritePacket(data) +} + type noResponse struct{} type eofResponse struct{} -func (c *Conn) writeValue(value interface{}) error { +func (c *Conn) WriteValue(value interface{}) error { switch v := value.(type) { case noResponse: return nil @@ -184,6 +210,8 @@ func (c *Conn) writeValue(value interface{}) error { } case []*Field: return c.writeFieldList(v, nil) + case []FieldValue: + return c.writeFieldValues(v) case *Stmt: return c.writePrepare(v) default: