From 2b7a3e9811653b3161c0ed78c2588a3abcc4d269 Mon Sep 17 00:00:00 2001 From: Arne Hormann Date: Fri, 23 Jun 2017 22:08:11 +0200 Subject: [PATCH] fewer driver.ErrBadConn to prevent repeated queries According to the database/sql/driver documentation, ErrBadConn should only be used when the database was not affected. The driver restarts the same query on a different connection, then. The mysql driver did not follow this advice, so queries were repeated if ErrBadConn is returned but a query succeeded. This is fixed by changing most ErrBadConn errors to ErrInvalidConn. The only valid returns of ErrBadConn are at the beginning of a database interaction when no data was sent to the database yet. Those valid cases are located the following funcs before attempting to write to the network or if 0 bytes were written: * Begin * BeginTx * Exec * ExecContext * Prepare * PrepareContext * Query * QueryContext Commit and Rollback could arguably also be on that list, but are left out as some engines like MyISAM are not supporting transactions. Tests in b/packets_test.go were changed because they simulate a read not preceded by a write to the db. This cannot happen as the client has to send the query first. --- connection.go | 23 ++++++++++++++++------- errors.go | 6 ++++++ packets.go | 28 ++++++++++++++++------------ packets_test.go | 13 ++++++------- statement.go | 4 ++-- 5 files changed, 46 insertions(+), 28 deletions(-) diff --git a/connection.go b/connection.go index 2630f5211..b07528653 100644 --- a/connection.go +++ b/connection.go @@ -81,6 +81,16 @@ func (mc *mysqlConn) handleParams() (err error) { return } +func (mc *mysqlConn) markBadConn(err error) error { + if mc == nil { + return err + } + if err != errBadConnNoWrite { + return err + } + return driver.ErrBadConn +} + func (mc *mysqlConn) Begin() (driver.Tx, error) { if mc.closed.IsSet() { errLog.Print(ErrInvalidConn) @@ -90,8 +100,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) { if err == nil { return &mysqlTx{mc}, err } - - return nil, err + return nil, mc.markBadConn(err) } func (mc *mysqlConn) Close() (err error) { @@ -142,7 +151,7 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { // Send command err := mc.writeCommandPacketStr(comStmtPrepare, query) if err != nil { - return nil, err + return nil, mc.markBadConn(err) } stmt := &mysqlStmt{ @@ -176,7 +185,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin if buf == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) - return "", driver.ErrBadConn + return "", ErrInvalidConn } buf = buf[:0] argPos := 0 @@ -314,14 +323,14 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err insertId: int64(mc.insertId), }, err } - return nil, err + return nil, mc.markBadConn(err) } // Internal function to execute commands func (mc *mysqlConn) exec(query string) error { // Send command if err := mc.writeCommandPacketStr(comQuery, query); err != nil { - return err + return mc.markBadConn(err) } // Read Result @@ -390,7 +399,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) return rows, err } } - return nil, err + return nil, mc.markBadConn(err) } // Gets the value of the given MySQL System Variable diff --git a/errors.go b/errors.go index 857854e14..d0d0d2e11 100644 --- a/errors.go +++ b/errors.go @@ -31,6 +31,12 @@ var ( ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?") ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server") ErrBusyBuffer = errors.New("busy buffer") + + // errBadConnNoWrite is used for connection errors where nothing was sent to the database yet. + // If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn + // to trigger a resend. + // See https://github.com/go-sql-driver/mysql/pull/302 + errBadConnNoWrite = errors.New("bad connection") ) var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)) diff --git a/packets.go b/packets.go index 9715067c4..eef480b79 100644 --- a/packets.go +++ b/packets.go @@ -35,7 +35,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } errLog.Print(err) mc.Close() - return nil, driver.ErrBadConn + return nil, ErrInvalidConn } // packet length [24 bit] @@ -57,7 +57,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { if prevData == nil { errLog.Print(ErrMalformPkt) mc.Close() - return nil, driver.ErrBadConn + return nil, ErrInvalidConn } return prevData, nil @@ -71,7 +71,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } errLog.Print(err) mc.Close() - return nil, driver.ErrBadConn + return nil, ErrInvalidConn } // return data if this was the last packet @@ -137,10 +137,14 @@ func (mc *mysqlConn) writePacket(data []byte) error { if cerr := mc.canceled.Value(); cerr != nil { return cerr } + if n == 0 && pktLen == len(data)-4 { + // only for the first loop iteration when nothing was written yet + return errBadConnNoWrite + } mc.cleanup() errLog.Print(err) } - return driver.ErrBadConn + return ErrInvalidConn } } @@ -274,7 +278,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + return errBadConnNoWrite } // ClientFlags [32 bit] @@ -360,7 +364,7 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + return errBadConnNoWrite } // Add the scrambled password [null terminated string] @@ -379,7 +383,7 @@ func (mc *mysqlConn) writeClearAuthPacket() error { if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + return errBadConnNoWrite } // Add the clear password [null terminated string] @@ -400,7 +404,7 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + return errBadConnNoWrite } // Add the scramble @@ -421,7 +425,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + return errBadConnNoWrite } // Add command byte @@ -440,7 +444,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + return errBadConnNoWrite } // Add command byte @@ -461,7 +465,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + return errBadConnNoWrite } // Add command byte @@ -927,7 +931,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if data == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + return errBadConnNoWrite } // command [1 byte] diff --git a/packets_test.go b/packets_test.go index 31c892d85..2f8207511 100644 --- a/packets_test.go +++ b/packets_test.go @@ -9,7 +9,6 @@ package mysql import ( - "database/sql/driver" "errors" "net" "testing" @@ -252,8 +251,8 @@ func TestReadPacketFail(t *testing.T) { conn.data = []byte{0x00, 0x00, 0x00, 0x00} conn.maxReads = 1 _, err := mc.readPacket() - if err != driver.ErrBadConn { - t.Errorf("expected ErrBadConn, got %v", err) + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %v", err) } // reset @@ -264,8 +263,8 @@ func TestReadPacketFail(t *testing.T) { // fail to read header conn.closed = true _, err = mc.readPacket() - if err != driver.ErrBadConn { - t.Errorf("expected ErrBadConn, got %v", err) + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %v", err) } // reset @@ -277,7 +276,7 @@ func TestReadPacketFail(t *testing.T) { // fail to read body conn.maxReads = 1 _, err = mc.readPacket() - if err != driver.ErrBadConn { - t.Errorf("expected ErrBadConn, got %v", err) + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %v", err) } } diff --git a/statement.go b/statement.go index ae6d33b72..ae223507f 100644 --- a/statement.go +++ b/statement.go @@ -52,7 +52,7 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { // Send command err := stmt.writeExecutePacket(args) if err != nil { - return nil, err + return nil, stmt.mc.markBadConn(err) } mc := stmt.mc @@ -100,7 +100,7 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { // Send command err := stmt.writeExecutePacket(args) if err != nil { - return nil, err + return nil, stmt.mc.markBadConn(err) } mc := stmt.mc