diff --git a/connection.go b/connection.go index f770f782e..77a7f24e6 100644 --- a/connection.go +++ b/connection.go @@ -211,7 +211,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro var resLen int resLen, err = mc.readResultSetHeaderPacket() if err == nil { - rows := &mysqlRows{mc, nil, false, false} + rows := &mysqlRows{mc, nil, false} if resLen > 0 { // Columns @@ -238,7 +238,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { // Read Result resLen, err := mc.readResultSetHeaderPacket() if err == nil { - rows := &mysqlRows{mc, nil, false, false} + rows := &mysqlRows{mc, nil, false} if resLen > 0 { // Columns diff --git a/rows.go b/rows.go index 73f55d210..b76118758 100644 --- a/rows.go +++ b/rows.go @@ -23,7 +23,6 @@ type mysqlRows struct { mc *mysqlConn columns []mysqlField binary bool - eof bool } func (rows *mysqlRows) Columns() []string { @@ -34,43 +33,36 @@ func (rows *mysqlRows) Columns() []string { return columns } -func (rows *mysqlRows) Close() (err error) { - // Remove unread packets from stream - if !rows.eof { - if rows.mc == nil || rows.mc.netConn == nil { - return errInvalidConn - } - - err = rows.mc.readUntilEOF() - - // explicitly set because readUntilEOF might return early in case of an - // error - rows.eof = true +func (rows *mysqlRows) Close() error { + mc := rows.mc + if mc == nil { + return nil } - + if mc.netConn == nil { + return errInvalidConn + } + // Remove unread packets from stream + err := mc.readUntilEOF() rows.mc = nil - - return + return err } func (rows *mysqlRows) Next(dest []driver.Value) (err error) { - if rows.eof { + mc := rows.mc + if mc == nil { return io.EOF } - - if rows.mc == nil || rows.mc.netConn == nil { + if mc.netConn == nil { return errInvalidConn } - // Fetch next row from stream if rows.binary { err = rows.readBinaryRow(dest) } else { err = rows.readRow(dest) } - if err == io.EOF { - rows.eof = true + rows.mc = nil } - return err + return } diff --git a/statement.go b/statement.go index bceb38917..4a8e0dfe5 100644 --- a/statement.go +++ b/statement.go @@ -90,7 +90,7 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { return nil, err } - rows := &mysqlRows{mc, nil, true, false} + rows := &mysqlRows{mc, nil, true} if resLen > 0 { // Columns