Skip to content

Commit c57020b

Browse files
committed
WIP: Add support for OK packets representing EOF
Fixes: go-sql-driver#805
1 parent d0a5481 commit c57020b

File tree

2 files changed

+84
-36
lines changed

2 files changed

+84
-36
lines changed

connection.go

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -163,16 +163,16 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
163163

164164
// Read Result
165165
columnCount, err := stmt.readPrepareResultPacket()
166-
if err == nil {
167-
if stmt.paramCount > 0 {
168-
if err = mc.readUntilEOF(); err != nil {
169-
return nil, err
170-
}
171-
}
166+
if err != nil {
167+
return stmt, err
168+
}
172169

173-
if columnCount > 0 {
174-
err = mc.readUntilEOF()
175-
}
170+
if err := mc.readPackets(stmt.paramCount); err != nil {
171+
return nil, err
172+
}
173+
174+
if err := mc.readPackets(int(columnCount)); err != nil {
175+
return nil, err
176176
}
177177

178178
return stmt, err
@@ -424,11 +424,8 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
424424
rows.mc = mc
425425
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
426426

427-
if resLen > 0 {
428-
// Columns
429-
if err := mc.readUntilEOF(); err != nil {
430-
return nil, err
431-
}
427+
if err := mc.readPackets(resLen); err != nil {
428+
return nil, err
432429
}
433430

434431
dest := make([]driver.Value, resLen)

packets.go

Lines changed: 73 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -224,10 +224,15 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
224224
if len(data) > pos {
225225
// character set [1 byte]
226226
// status flags [2 bytes]
227+
pos += 1 + 2
228+
227229
// capability flags (upper 2 bytes) [2 bytes]
230+
mc.flags += clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16
231+
pos += 2
232+
228233
// length of auth-plugin-data [1 byte]
229234
// reserved (all [00]) [10 bytes]
230-
pos += 1 + 2 + 2 + 1 + 10
235+
pos += +1 + 10
231236

232237
// second part of the password cipher [mininum 13 bytes],
233238
// where len=MAX(13, length of auth-plugin-data - 8)
@@ -275,6 +280,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
275280
clientLocalFiles |
276281
clientPluginAuth |
277282
clientMultiResults |
283+
mc.flags&clientDeprecateEOF |
278284
mc.flags&clientLongFlag
279285

280286
if mc.cfg.ClientFoundRows {
@@ -599,18 +605,19 @@ func readStatus(b []byte) statusFlag {
599605
// Ok Packet
600606
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
601607
func (mc *mysqlConn) handleOkPacket(data []byte) error {
602-
var n, m int
603-
604-
// 0x00 [1 byte]
605-
608+
// 0x00 or 0xFE [1 byte]
609+
n := 1
610+
var l int
606611
// Affected rows [Length Coded Binary]
607-
mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
612+
mc.affectedRows, _, l = readLengthEncodedInteger(data[n:])
613+
n += l
608614

609615
// Insert id [Length Coded Binary]
610-
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
616+
mc.insertId, _, l = readLengthEncodedInteger(data[n:])
617+
n += l
611618

612619
// server_status [2 bytes]
613-
mc.status = readStatus(data[1+n+m : 1+n+m+2])
620+
mc.status = readStatus(data[n : n+2])
614621
if mc.status&statusMoreResultsExists != 0 {
615622
return nil
616623
}
@@ -620,19 +627,24 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
620627
return nil
621628
}
622629

630+
// isEOFPacket will return true if the data is either a EOF-Packet or OK-Packet
631+
// acting as an EOF.
632+
func isEOFPacket(data []byte) bool {
633+
return data[0] == iEOF && len(data) < 9
634+
}
635+
623636
// Read Packets as Field Packets until EOF-Packet or an Error appears
624637
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
625638
func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
626639
columns := make([]mysqlField, count)
627640

628-
for i := 0; ; i++ {
641+
for i := 0; i < count; i++ {
629642
data, err := mc.readPacket()
630643
if err != nil {
631644
return nil, err
632645
}
633646

634-
// EOF Packet
635-
if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
647+
if mc.flags&clientDeprecateEOF == 0 && isEOFPacket(data) {
636648
if i == count {
637649
return columns, nil
638650
}
@@ -718,9 +730,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
718730
// defaultVal, _, err = bytesToLengthCodedBinary(data[pos:])
719731
//}
720732
}
733+
return columns, nil
721734
}
722735

723-
// Read Packets as Field Packets until EOF-Packet or an Error appears
736+
// Read Packets as Field Packets until EOF/OK-Packet or an Error appears
724737
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
725738
func (rows *textRows) readRow(dest []driver.Value) error {
726739
mc := rows.mc
@@ -735,9 +748,15 @@ func (rows *textRows) readRow(dest []driver.Value) error {
735748
}
736749

737750
// EOF Packet
738-
if data[0] == iEOF && len(data) == 5 {
739-
// server_status [2 bytes]
740-
rows.mc.status = readStatus(data[3:])
751+
if isEOFPacket(data) {
752+
if mc.flags&clientDeprecateEOF == 0 {
753+
// server_status [2 bytes]
754+
rows.mc.status = readStatus(data[3:])
755+
} else {
756+
if err := mc.handleOkPacket(data); err != nil {
757+
return err
758+
}
759+
}
741760
rows.rs.done = true
742761
if !rows.HasNextResultSet() {
743762
rows.mc = nil
@@ -797,18 +816,44 @@ func (mc *mysqlConn) readUntilEOF() error {
797816
return err
798817
}
799818

800-
switch data[0] {
801-
case iERR:
819+
switch {
820+
case data[0] == iERR:
802821
return mc.handleErrorPacket(data)
803-
case iEOF:
804-
if len(data) == 5 {
822+
case isEOFPacket(data):
823+
if mc.flags&clientDeprecateEOF == 0 {
805824
mc.status = readStatus(data[3:])
825+
} else {
826+
return mc.handleOkPacket(data)
806827
}
807828
return nil
808829
}
809830
}
810831
}
811832

833+
func (mc *mysqlConn) readPackets(num int) error {
834+
835+
// we need to read EOF as well
836+
if mc.flags&clientDeprecateEOF == 0 {
837+
num++
838+
}
839+
840+
for i := 0; i < num; i++ {
841+
data, err := mc.readPacket()
842+
if err != nil {
843+
return err
844+
}
845+
846+
switch {
847+
case data[0] == iERR:
848+
return mc.handleErrorPacket(data)
849+
case mc.flags&clientDeprecateEOF == 0 && isEOFPacket(data):
850+
mc.status = readStatus(data[3:])
851+
return nil
852+
}
853+
}
854+
return nil
855+
}
856+
812857
/******************************************************************************
813858
* Prepared Statements *
814859
******************************************************************************/
@@ -1161,15 +1206,21 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
11611206

11621207
// packet indicator [1 byte]
11631208
if data[0] != iOK {
1164-
// EOF Packet
1165-
if data[0] == iEOF && len(data) == 5 {
1166-
rows.mc.status = readStatus(data[3:])
1209+
if isEOFPacket(data) {
1210+
if rows.mc.flags&clientDeprecateEOF == 0 {
1211+
rows.mc.status = readStatus(data[3:])
1212+
} else {
1213+
if err := rows.mc.handleOkPacket(data); err != nil {
1214+
return err
1215+
}
1216+
}
11671217
rows.rs.done = true
11681218
if !rows.HasNextResultSet() {
11691219
rows.mc = nil
11701220
}
11711221
return io.EOF
11721222
}
1223+
11731224
mc := rows.mc
11741225
rows.mc = nil
11751226

0 commit comments

Comments
 (0)