Skip to content

Commit 41b8fbe

Browse files
committed
add client/Conn.ExecuteSelectStreaming() for yet more memory allocation optimization
1 parent 0407fd0 commit 41b8fbe

File tree

2 files changed

+113
-0
lines changed

2 files changed

+113
-0
lines changed

client/conn.go

+25
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ type Conn struct {
3333
connectionID uint32
3434
}
3535

36+
// This function will be called for every row in resultset from ExecuteSelectStreaming.
37+
type SelectPerRowCallback func(row []FieldValue) error
38+
3639
func getNetProto(addr string) string {
3740
proto := "tcp"
3841
if strings.Contains(addr, "/") {
@@ -165,6 +168,28 @@ func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) {
165168
}
166169
}
167170

171+
// ExecuteSelectStreaming will call perRowCallback for every row in resultset
172+
// WITHOUT saving any row data to Result.{Values/RawPkg/RowDatas} fields.
173+
//
174+
// ExecuteSelectStreaming should be used only for SELECT queries with a large response resultset for memory preserving.
175+
//
176+
// Example:
177+
//
178+
// var result mysql.Result
179+
// conn.ExecuteSelectStreaming(`SELECT ... LIMIT 100500`, &result, func(row []mysql.FieldValue) error {
180+
// // Use the row as you want.
181+
// // You must not save FieldValue.AsString() value after this callback is done. Copy it if you need.
182+
// return nil
183+
// })
184+
//
185+
func (c *Conn) ExecuteSelectStreaming(command string, result *Result, perRowCallback SelectPerRowCallback) error {
186+
if err := c.writeCommandStr(COM_QUERY, command); err != nil {
187+
return errors.Trace(err)
188+
}
189+
190+
return c.readResultStreaming(false, result, perRowCallback)
191+
}
192+
168193
func (c *Conn) Begin() error {
169194
_, err := c.exec("BEGIN")
170195
return errors.Trace(err)

client/resp.go

+88
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,25 @@ func (c *Conn) readResult(binary bool) (*Result, error) {
233233
return c.readResultset(firstPkgBuf, binary)
234234
}
235235

236+
func (c *Conn) readResultStreaming(binary bool, result *Result, perRowCb SelectPerRowCallback) error {
237+
firstPkgBuf, err := c.ReadPacketReuseMem(utils.ByteSliceGet(16)[:0])
238+
defer utils.ByteSlicePut(firstPkgBuf)
239+
240+
if err != nil {
241+
return errors.Trace(err)
242+
}
243+
244+
if firstPkgBuf[0] == OK_HEADER {
245+
return ErrMalformPacket // Streaming allowed only for SELECT queries
246+
} else if firstPkgBuf[0] == ERR_HEADER {
247+
return c.handleErrorPacket(append([]byte{}, firstPkgBuf...))
248+
} else if firstPkgBuf[0] == LocalInFile_HEADER {
249+
return ErrMalformPacket
250+
}
251+
252+
return c.readResultsetStreaming(firstPkgBuf, binary, result, perRowCb)
253+
}
254+
236255
func (c *Conn) readResultset(data []byte, binary bool) (*Result, error) {
237256
// column count
238257
count, _, n := LengthEncodedInt(data)
@@ -256,6 +275,31 @@ func (c *Conn) readResultset(data []byte, binary bool) (*Result, error) {
256275
return result, nil
257276
}
258277

278+
func (c *Conn) readResultsetStreaming(data []byte, binary bool, result *Result, perRowCb SelectPerRowCallback) error {
279+
columnCount, _, n := LengthEncodedInt(data)
280+
281+
if n-len(data) != 0 {
282+
return ErrMalformPacket
283+
}
284+
285+
if result.Resultset == nil {
286+
result.Resultset = NewResultset(int(columnCount))
287+
} else {
288+
// Reuse memory if can
289+
result.Reset(int(columnCount))
290+
}
291+
292+
if err := c.readResultColumns(result); err != nil {
293+
return errors.Trace(err)
294+
}
295+
296+
if err := c.readResultRowsStreaming(result, binary, perRowCb); err != nil {
297+
return errors.Trace(err)
298+
}
299+
300+
return nil
301+
}
302+
259303
func (c *Conn) readResultColumns(result *Result) (err error) {
260304
var i int = 0
261305
var data []byte
@@ -344,3 +388,47 @@ func (c *Conn) readResultRows(result *Result, isBinary bool) (err error) {
344388

345389
return nil
346390
}
391+
392+
func (c *Conn) readResultRowsStreaming(result *Result, isBinary bool, perRowCb SelectPerRowCallback) (err error) {
393+
var (
394+
data []byte
395+
row []FieldValue
396+
)
397+
398+
for {
399+
data, err = c.ReadPacketReuseMem(data[:0])
400+
if err != nil {
401+
return
402+
}
403+
404+
// EOF Packet
405+
if c.isEOFPacket(data) {
406+
if c.capability&CLIENT_PROTOCOL_41 > 0 {
407+
// result.Warnings = binary.LittleEndian.Uint16(data[1:])
408+
// todo add strict_mode, warning will be treat as error
409+
result.Status = binary.LittleEndian.Uint16(data[3:])
410+
c.status = result.Status
411+
}
412+
413+
break
414+
}
415+
416+
if data[0] == ERR_HEADER {
417+
return c.handleErrorPacket(data)
418+
}
419+
420+
// Parse this row
421+
row, err = RowData(data).Parse(result.Fields, isBinary, row)
422+
if err != nil {
423+
return errors.Trace(err)
424+
}
425+
426+
// Send the row to "userland" code
427+
err = perRowCb(row)
428+
if err != nil {
429+
return errors.Trace(err)
430+
}
431+
}
432+
433+
return nil
434+
}

0 commit comments

Comments
 (0)