Skip to content

Commit 9542bdd

Browse files
committed
Enable Multi Results support and discard additional results
- packets.go: flag clientMultiResults, update status when receiving an EOF packet, discard additional results on readRow when EOF is reached - statement.go: currently a nil rows.mc is used as an eof, don’t set it if there are no columns to avoid that Next() waits indefinitely - rows.go: discard additional results on close and avoid panic on Columns()
1 parent a197e5d commit 9542bdd

File tree

3 files changed

+51
-4
lines changed

3 files changed

+51
-4
lines changed

packets.go

+43-2
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
214214
clientLongPassword |
215215
clientTransactions |
216216
clientLocalFiles |
217+
clientMultiResults |
217218
mc.flags&clientLongFlag
218219

219220
if mc.cfg.clientFoundRows {
@@ -470,6 +471,10 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
470471
}
471472
}
472473

474+
func readStatus(b []byte) statusFlag {
475+
return statusFlag(b[0]) | statusFlag(b[1])<<8
476+
}
477+
473478
// Ok Packet
474479
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
475480
func (mc *mysqlConn) handleOkPacket(data []byte) error {
@@ -484,7 +489,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
484489
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
485490

486491
// server_status [2 bytes]
487-
mc.status = statusFlag(data[1+n+m]) | statusFlag(data[1+n+m+1])<<8
492+
mc.status = readStatus(data[1+n+m : 1+n+m+2])
488493

489494
// warning count [2 bytes]
490495
if !mc.strict {
@@ -603,6 +608,11 @@ func (rows *textRows) readRow(dest []driver.Value) error {
603608

604609
// EOF Packet
605610
if data[0] == iEOF && len(data) == 5 {
611+
// server_status [2 bytes]
612+
rows.mc.status = readStatus(data[3:])
613+
if err := rows.mc.discardMoreResultsIfExists(); err != nil {
614+
return err
615+
}
606616
rows.mc = nil
607617
return io.EOF
608618
}
@@ -660,6 +670,10 @@ func (mc *mysqlConn) readUntilEOF() error {
660670
if err == nil && data[0] != iEOF {
661671
continue
662672
}
673+
if err == nil && data[0] == iEOF && len(data) == 5 {
674+
mc.status = readStatus(data[3:])
675+
}
676+
663677
return err // Err or EOF
664678
}
665679
}
@@ -964,6 +978,28 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
964978
return mc.writePacket(data)
965979
}
966980

981+
func (mc *mysqlConn) discardMoreResultsIfExists() error {
982+
for mc.status&statusMoreResultsExists != 0 {
983+
resLen, err := mc.readResultSetHeaderPacket()
984+
if err != nil {
985+
return err
986+
}
987+
if resLen > 0 {
988+
// columns
989+
if err := mc.readUntilEOF(); err != nil {
990+
return err
991+
}
992+
// rows
993+
if err := mc.readUntilEOF(); err != nil {
994+
return err
995+
}
996+
} else {
997+
mc.status &^= statusMoreResultsExists
998+
}
999+
}
1000+
return nil
1001+
}
1002+
9671003
// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
9681004
func (rows *binaryRows) readRow(dest []driver.Value) error {
9691005
data, err := rows.mc.readPacket()
@@ -973,11 +1009,16 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
9731009

9741010
// packet indicator [1 byte]
9751011
if data[0] != iOK {
976-
rows.mc = nil
9771012
// EOF Packet
9781013
if data[0] == iEOF && len(data) == 5 {
1014+
rows.mc.status = readStatus(data[3:])
1015+
if err := rows.mc.discardMoreResultsIfExists(); err != nil {
1016+
return err
1017+
}
1018+
rows.mc = nil
9791019
return io.EOF
9801020
}
1021+
rows.mc = nil
9811022

9821023
// Error otherwise
9831024
return rows.mc.handleErrorPacket(data)

rows.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ type emptyRows struct{}
3838

3939
func (rows *mysqlRows) Columns() []string {
4040
columns := make([]string, len(rows.columns))
41-
if rows.mc.cfg.columnsWithAlias {
41+
if rows.mc != nil && rows.mc.cfg.columnsWithAlias {
4242
for i := range columns {
4343
columns[i] = rows.columns[i].tableName + "." + rows.columns[i].name
4444
}
@@ -61,6 +61,12 @@ func (rows *mysqlRows) Close() error {
6161

6262
// Remove unread packets from stream
6363
err := mc.readUntilEOF()
64+
if err == nil {
65+
if err = mc.discardMoreResultsIfExists(); err != nil {
66+
return err
67+
}
68+
}
69+
6470
rows.mc = nil
6571
return err
6672
}

statement.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,9 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
9494
}
9595

9696
rows := new(binaryRows)
97-
rows.mc = mc
9897

9998
if resLen > 0 {
99+
rows.mc = mc
100100
// Columns
101101
// If not cached, read them and cache them
102102
if stmt.columns == nil {

0 commit comments

Comments
 (0)