From c34acf9e4552ed8b8ebea4adebc2bd9466f01dd8 Mon Sep 17 00:00:00 2001 From: Brandon Bennett Date: Wed, 1 May 2019 10:04:44 -0600 Subject: [PATCH 1/2] Add support for OK packets representing EOF Fixes: #805 --- connection.go | 25 ++++++-------- packets.go | 95 +++++++++++++++++++++++++++++++++++++++------------ 2 files changed, 84 insertions(+), 36 deletions(-) diff --git a/connection.go b/connection.go index 90aec6439..425cc5952 100644 --- a/connection.go +++ b/connection.go @@ -180,16 +180,16 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { // Read Result columnCount, err := stmt.readPrepareResultPacket() - if err == nil { - if stmt.paramCount > 0 { - if err = mc.readUntilEOF(); err != nil { - return nil, err - } - } + if err != nil { + return stmt, err + } - if columnCount > 0 { - err = mc.readUntilEOF() - } + if err := mc.readPackets(stmt.paramCount); err != nil { + return nil, err + } + + if err := mc.readPackets(int(columnCount)); err != nil { + return nil, err } return stmt, err @@ -415,11 +415,8 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { rows.mc = mc rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}} - if resLen > 0 { - // Columns - if err := mc.readUntilEOF(); err != nil { - return nil, err - } + if err := mc.readPackets(resLen); err != nil { + return nil, err } dest := make([]driver.Value, resLen) diff --git a/packets.go b/packets.go index 6664e5ae5..422278fe1 100644 --- a/packets.go +++ b/packets.go @@ -235,10 +235,15 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro if len(data) > pos { // character set [1 byte] // status flags [2 bytes] + pos += 1 + 2 + // capability flags (upper 2 bytes) [2 bytes] + mc.flags += clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16 + pos += 2 + // length of auth-plugin-data [1 byte] // reserved (all [00]) [10 bytes] - pos += 1 + 2 + 2 + 1 + 10 + pos += +1 + 10 // second part of the password cipher [mininum 13 bytes], // where len=MAX(13, length of auth-plugin-data - 8) @@ -286,6 +291,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string clientLocalFiles | clientPluginAuth | clientMultiResults | + mc.flags&clientDeprecateEOF | mc.flags&clientLongFlag if mc.cfg.ClientFoundRows { @@ -610,18 +616,19 @@ func readStatus(b []byte) statusFlag { // Ok Packet // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet func (mc *mysqlConn) handleOkPacket(data []byte) error { - var n, m int - - // 0x00 [1 byte] - + // 0x00 or 0xFE [1 byte] + n := 1 + var l int // Affected rows [Length Coded Binary] - mc.affectedRows, _, n = readLengthEncodedInteger(data[1:]) + mc.affectedRows, _, l = readLengthEncodedInteger(data[n:]) + n += l // Insert id [Length Coded Binary] - mc.insertId, _, m = readLengthEncodedInteger(data[1+n:]) + mc.insertId, _, l = readLengthEncodedInteger(data[n:]) + n += l // server_status [2 bytes] - mc.status = readStatus(data[1+n+m : 1+n+m+2]) + mc.status = readStatus(data[n : n+2]) if mc.status&statusMoreResultsExists != 0 { return nil } @@ -631,19 +638,24 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { return nil } +// isEOFPacket will return true if the data is either a EOF-Packet or OK-Packet +// acting as an EOF. +func isEOFPacket(data []byte) bool { + return data[0] == iEOF && len(data) < 9 +} + // Read Packets as Field Packets until EOF-Packet or an Error appears // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41 func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { columns := make([]mysqlField, count) - for i := 0; ; i++ { + for i := 0; i < count; i++ { data, err := mc.readPacket() if err != nil { return nil, err } - // EOF Packet - if data[0] == iEOF && (len(data) == 5 || len(data) == 1) { + if mc.flags&clientDeprecateEOF == 0 && isEOFPacket(data) { if i == count { return columns, nil } @@ -729,9 +741,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:]) //} } + return columns, nil } -// Read Packets as Field Packets until EOF-Packet or an Error appears +// Read Packets as Field Packets until EOF/OK-Packet or an Error appears // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow func (rows *textRows) readRow(dest []driver.Value) error { mc := rows.mc @@ -746,9 +759,15 @@ func (rows *textRows) readRow(dest []driver.Value) error { } // EOF Packet - if data[0] == iEOF && len(data) == 5 { - // server_status [2 bytes] - rows.mc.status = readStatus(data[3:]) + if isEOFPacket(data) { + if mc.flags&clientDeprecateEOF == 0 { + // server_status [2 bytes] + rows.mc.status = readStatus(data[3:]) + } else { + if err := mc.handleOkPacket(data); err != nil { + return err + } + } rows.rs.done = true if !rows.HasNextResultSet() { rows.mc = nil @@ -808,18 +827,44 @@ func (mc *mysqlConn) readUntilEOF() error { return err } - switch data[0] { - case iERR: + switch { + case data[0] == iERR: return mc.handleErrorPacket(data) - case iEOF: - if len(data) == 5 { + case isEOFPacket(data): + if mc.flags&clientDeprecateEOF == 0 { mc.status = readStatus(data[3:]) + } else { + return mc.handleOkPacket(data) } return nil } } } +func (mc *mysqlConn) readPackets(num int) error { + + // we need to read EOF as well + if mc.flags&clientDeprecateEOF == 0 { + num++ + } + + for i := 0; i < num; i++ { + data, err := mc.readPacket() + if err != nil { + return err + } + + switch { + case data[0] == iERR: + return mc.handleErrorPacket(data) + case mc.flags&clientDeprecateEOF == 0 && isEOFPacket(data): + mc.status = readStatus(data[3:]) + return nil + } + } + return nil +} + /****************************************************************************** * Prepared Statements * ******************************************************************************/ @@ -1178,15 +1223,21 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // packet indicator [1 byte] if data[0] != iOK { - // EOF Packet - if data[0] == iEOF && len(data) == 5 { - rows.mc.status = readStatus(data[3:]) + if isEOFPacket(data) { + if rows.mc.flags&clientDeprecateEOF == 0 { + rows.mc.status = readStatus(data[3:]) + } else { + if err := rows.mc.handleOkPacket(data); err != nil { + return err + } + } rows.rs.done = true if !rows.HasNextResultSet() { rows.mc = nil } return io.EOF } + mc := rows.mc rows.mc = nil From 667a05a6dcef2df39d18a92a025065819bdb718f Mon Sep 17 00:00:00 2001 From: Tzu-Chiao Yeh Date: Sat, 5 Sep 2020 09:58:16 +0800 Subject: [PATCH 2/2] Ensure backward compatibility with legacy EOF format --- connection.go | 33 +++++++++++++++++++---------- packets.go | 57 +++++++++++++++++++++++++-------------------------- rows.go | 1 - statement.go | 3 +-- 4 files changed, 51 insertions(+), 43 deletions(-) diff --git a/connection.go b/connection.go index 425cc5952..39f24b15d 100644 --- a/connection.go +++ b/connection.go @@ -180,16 +180,24 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { // Read Result columnCount, err := stmt.readPrepareResultPacket() - if err != nil { - return stmt, err - } - - if err := mc.readPackets(stmt.paramCount); err != nil { - return nil, err - } + if err == nil { + if stmt.paramCount > 0 { + // FIXME - seems like a bug in MySQL (or it's intended). + // There's no EOF return after parameters. + // However, this behavior isn't consistent to Maria DB. + if mc.flags&clientDeprecateEOF == 0 { + if err = mc.readUntilEOF(); err != nil { + return nil, err + } + } + if err = mc.readExactPackets(stmt.paramCount); err != nil { + return nil, err + } + } - if err := mc.readPackets(int(columnCount)); err != nil { - return nil, err + if columnCount > 0 { + err = mc.readUntilEOF() + } } return stmt, err @@ -415,8 +423,11 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { rows.mc = mc rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}} - if err := mc.readPackets(resLen); err != nil { - return nil, err + if resLen > 0 { + // Columns + if err := mc.readUntilEOF(); err != nil { + return nil, err + } } dest := make([]driver.Value, resLen) diff --git a/packets.go b/packets.go index 422278fe1..b3d3b4850 100644 --- a/packets.go +++ b/packets.go @@ -238,12 +238,12 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro pos += 1 + 2 // capability flags (upper 2 bytes) [2 bytes] - mc.flags += clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16 + mc.flags |= clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16 pos += 2 // length of auth-plugin-data [1 byte] // reserved (all [00]) [10 bytes] - pos += +1 + 10 + pos += 1 + 10 // second part of the password cipher [mininum 13 bytes], // where len=MAX(13, length of auth-plugin-data - 8) @@ -614,7 +614,7 @@ func readStatus(b []byte) statusFlag { } // Ok Packet -// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet +// https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html func (mc *mysqlConn) handleOkPacket(data []byte) error { // 0x00 or 0xFE [1 byte] n := 1 @@ -640,8 +640,12 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { // isEOFPacket will return true if the data is either a EOF-Packet or OK-Packet // acting as an EOF. -func isEOFPacket(data []byte) bool { - return data[0] == iEOF && len(data) < 9 +func (mc *mysqlConn) isEOFPacket(data []byte) bool { + // Legacy EOF packet + if data[0] == iEOF && (len(data) == 5 || len(data) == 1) && mc.flags&clientDeprecateEOF == 0 { + return true + } + return data[0] == iEOF && len(data) < 9 && mc.flags&clientDeprecateEOF != 0 } // Read Packets as Field Packets until EOF-Packet or an Error appears @@ -649,13 +653,21 @@ func isEOFPacket(data []byte) bool { func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { columns := make([]mysqlField, count) - for i := 0; i < count; i++ { + // If we set clientDeprecateEOF capability flag, + // the EOF will be no longer sent after all columns. + packets := count + if mc.flags&clientDeprecateEOF == 0 { + // Legacy way, read one more EOF packet. + packets += 1 + } + + for i := 0; i < packets; i++ { data, err := mc.readPacket() if err != nil { return nil, err } - if mc.flags&clientDeprecateEOF == 0 && isEOFPacket(data) { + if mc.isEOFPacket(data) { if i == count { return columns, nil } @@ -759,12 +771,13 @@ func (rows *textRows) readRow(dest []driver.Value) error { } // EOF Packet - if isEOFPacket(data) { + if mc.isEOFPacket(data) { if mc.flags&clientDeprecateEOF == 0 { // server_status [2 bytes] rows.mc.status = readStatus(data[3:]) } else { if err := mc.handleOkPacket(data); err != nil { + rows.mc = nil return err } } @@ -830,37 +843,22 @@ func (mc *mysqlConn) readUntilEOF() error { switch { case data[0] == iERR: return mc.handleErrorPacket(data) - case isEOFPacket(data): + case mc.isEOFPacket(data): if mc.flags&clientDeprecateEOF == 0 { mc.status = readStatus(data[3:]) - } else { - return mc.handleOkPacket(data) + return nil } - return nil + return mc.handleOkPacket(data) } } } -func (mc *mysqlConn) readPackets(num int) error { - - // we need to read EOF as well - if mc.flags&clientDeprecateEOF == 0 { - num++ - } - +func (mc *mysqlConn) readExactPackets(num int) error { for i := 0; i < num; i++ { - data, err := mc.readPacket() + _, err := mc.readPacket() if err != nil { return err } - - switch { - case data[0] == iERR: - return mc.handleErrorPacket(data) - case mc.flags&clientDeprecateEOF == 0 && isEOFPacket(data): - mc.status = readStatus(data[3:]) - return nil - } } return nil } @@ -1223,11 +1221,12 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // packet indicator [1 byte] if data[0] != iOK { - if isEOFPacket(data) { + if rows.mc.isEOFPacket(data) { if rows.mc.flags&clientDeprecateEOF == 0 { rows.mc.status = readStatus(data[3:]) } else { if err := rows.mc.handleOkPacket(data); err != nil { + rows.mc = nil return err } } diff --git a/rows.go b/rows.go index 888bdb5f0..1599ee03b 100644 --- a/rows.go +++ b/rows.go @@ -215,7 +215,6 @@ func (rows *textRows) Next(dest []driver.Value) error { if err := mc.error(); err != nil { return err } - // Fetch next row from stream return rows.readRow(dest) } diff --git a/statement.go b/statement.go index 18a3ae498..cc7c93f36 100644 --- a/statement.go +++ b/statement.go @@ -73,10 +73,9 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { if resLen > 0 { // Columns - if err = mc.readUntilEOF(); err != nil { + if err = mc.readExactPackets(resLen); err != nil { return nil, err } - // Rows if err := mc.readUntilEOF(); err != nil { return nil, err