From 32e5ceed8e5ba6be7fb6f029ffed7517ab145528 Mon Sep 17 00:00:00 2001 From: Arne Hormann Date: Tue, 29 Oct 2013 11:57:58 +0100 Subject: [PATCH 1/3] no panic on closed connection reuse --- connection.go | 15 +++++++++++++++ driver_test.go | 34 ++++++++++++++++++++++++++++++++++ rows.go | 4 ++-- statement.go | 6 ++++++ transaction.go | 4 ++-- 5 files changed, 59 insertions(+), 4 deletions(-) diff --git a/connection.go b/connection.go index 99d690954..0add75238 100644 --- a/connection.go +++ b/connection.go @@ -99,6 +99,9 @@ func (mc *mysqlConn) handleParams() (err error) { } func (mc *mysqlConn) Begin() (driver.Tx, error) { + if mc.buf == nil { + return nil, errInvalidConn + } err := mc.exec("START TRANSACTION") if err == nil { return &mysqlTx{mc}, err @@ -108,6 +111,9 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) { } func (mc *mysqlConn) Close() (err error) { + if mc.buf == nil { + return errInvalidConn + } // Makes Close idempotent if mc.netConn != nil { mc.writeCommandPacket(comQuit) @@ -122,6 +128,9 @@ func (mc *mysqlConn) Close() (err error) { } func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { + if mc.buf == nil { + return nil, errInvalidConn + } // Send command err := mc.writeCommandPacketStr(comStmtPrepare, query) if err != nil { @@ -150,6 +159,9 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { } func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { + if mc.buf == nil { + return nil, errInvalidConn + } if len(args) == 0 { // no args, fastpath mc.affectedRows = 0 mc.insertId = 0 @@ -191,6 +203,9 @@ func (mc *mysqlConn) exec(query string) error { } func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { + if mc.buf == nil { + return nil, errInvalidConn + } if len(args) == 0 { // no args, fastpath // Send command err := mc.writeCommandPacketStr(comQuery, query) diff --git a/driver_test.go b/driver_test.go index ee57ce02a..4bfb1398c 100644 --- a/driver_test.go +++ b/driver_test.go @@ -108,6 +108,40 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) return rows } +func TestClosedConnection(t *testing.T) { + // this test does not use sql.database, it uses the driver directly + if !available { + t.Skipf("MySQL-Server not running on %s", netAddr) + } + driver := &MySQLDriver{} + conn, err := driver.Open(dsn) + if err != nil { + t.Fatalf("Error connecting: %s", err.Error()) + } + stmt, err := conn.Prepare("SET @tmpif := 1") + if err != nil { + t.Fatalf("Error preparing statement: %s", err.Error()) + } + _, err = stmt.Exec(nil) + if err != nil { + t.Fatalf("Error executing statement: %s", err.Error()) + } + err = conn.Close() + if err != nil { + t.Fatalf("Error closing connection: %s", err.Error()) + } + defer func() { + if err := recover(); err != nil { + t.Errorf("Panic after reusing a closed connection: %v", err) + } + }() + _, err = stmt.Exec(nil) + if err != errInvalidConn { + t.Errorf("Unexpected error '%s', expected '%s'", + err.Error(), errInvalidConn.Error()) + } +} + func TestCharset(t *testing.T) { if !available { t.Skipf("MySQL-Server not running on %s", netAddr) diff --git a/rows.go b/rows.go index 73f55d210..dbc5a1d15 100644 --- a/rows.go +++ b/rows.go @@ -37,7 +37,7 @@ func (rows *mysqlRows) Columns() []string { func (rows *mysqlRows) Close() (err error) { // Remove unread packets from stream if !rows.eof { - if rows.mc == nil || rows.mc.netConn == nil { + if rows.mc == nil || rows.mc.buf == nil { return errInvalidConn } @@ -58,7 +58,7 @@ func (rows *mysqlRows) Next(dest []driver.Value) (err error) { return io.EOF } - if rows.mc == nil || rows.mc.netConn == nil { + if rows.mc == nil || rows.mc.buf == nil { return errInvalidConn } diff --git a/statement.go b/statement.go index cd6030497..8a70a81ed 100644 --- a/statement.go +++ b/statement.go @@ -34,6 +34,9 @@ func (stmt *mysqlStmt) NumInput() int { } func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { + if stmt.mc.buf == nil { + return nil, errInvalidConn + } // Send command err := stmt.writeExecutePacket(args) if err != nil { @@ -70,6 +73,9 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { } func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { + if stmt.mc.buf == nil { + return nil, errInvalidConn + } // Send command err := stmt.writeExecutePacket(args) if err != nil { diff --git a/transaction.go b/transaction.go index 4cac59f3e..606a2cf56 100644 --- a/transaction.go +++ b/transaction.go @@ -13,7 +13,7 @@ type mysqlTx struct { } func (tx *mysqlTx) Commit() (err error) { - if tx.mc == nil || tx.mc.netConn == nil { + if tx.mc == nil || tx.mc.buf == nil { return errInvalidConn } err = tx.mc.exec("COMMIT") @@ -22,7 +22,7 @@ func (tx *mysqlTx) Commit() (err error) { } func (tx *mysqlTx) Rollback() (err error) { - if tx.mc == nil || tx.mc.netConn == nil { + if tx.mc == nil || tx.mc.buf == nil { return errInvalidConn } err = tx.mc.exec("ROLLBACK") From 06d5483c02ca247e1bf9672a7eec58dc601d5e66 Mon Sep 17 00:00:00 2001 From: Arne Hormann Date: Wed, 30 Oct 2013 11:21:35 +0100 Subject: [PATCH 2/3] changed mc.buf == nil to mc.netConn == nil --- connection.go | 11 ++++------- rows.go | 4 ++-- statement.go | 4 ++-- transaction.go | 4 ++-- 4 files changed, 10 insertions(+), 13 deletions(-) diff --git a/connection.go b/connection.go index 0add75238..f770f782e 100644 --- a/connection.go +++ b/connection.go @@ -99,7 +99,7 @@ func (mc *mysqlConn) handleParams() (err error) { } func (mc *mysqlConn) Begin() (driver.Tx, error) { - if mc.buf == nil { + if mc.netConn == nil { return nil, errInvalidConn } err := mc.exec("START TRANSACTION") @@ -111,9 +111,6 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) { } func (mc *mysqlConn) Close() (err error) { - if mc.buf == nil { - return errInvalidConn - } // Makes Close idempotent if mc.netConn != nil { mc.writeCommandPacket(comQuit) @@ -128,7 +125,7 @@ func (mc *mysqlConn) Close() (err error) { } func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { - if mc.buf == nil { + if mc.netConn == nil { return nil, errInvalidConn } // Send command @@ -159,7 +156,7 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { } func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { - if mc.buf == nil { + if mc.netConn == nil { return nil, errInvalidConn } if len(args) == 0 { // no args, fastpath @@ -203,7 +200,7 @@ func (mc *mysqlConn) exec(query string) error { } func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { - if mc.buf == nil { + if mc.netConn == nil { return nil, errInvalidConn } if len(args) == 0 { // no args, fastpath diff --git a/rows.go b/rows.go index dbc5a1d15..73f55d210 100644 --- a/rows.go +++ b/rows.go @@ -37,7 +37,7 @@ func (rows *mysqlRows) Columns() []string { func (rows *mysqlRows) Close() (err error) { // Remove unread packets from stream if !rows.eof { - if rows.mc == nil || rows.mc.buf == nil { + if rows.mc == nil || rows.mc.netConn == nil { return errInvalidConn } @@ -58,7 +58,7 @@ func (rows *mysqlRows) Next(dest []driver.Value) (err error) { return io.EOF } - if rows.mc == nil || rows.mc.buf == nil { + if rows.mc == nil || rows.mc.netConn == nil { return errInvalidConn } diff --git a/statement.go b/statement.go index 8a70a81ed..bceb38917 100644 --- a/statement.go +++ b/statement.go @@ -34,7 +34,7 @@ func (stmt *mysqlStmt) NumInput() int { } func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { - if stmt.mc.buf == nil { + if stmt.mc.netConn == nil { return nil, errInvalidConn } // Send command @@ -73,7 +73,7 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { } func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { - if stmt.mc.buf == nil { + if stmt.mc.netConn == nil { return nil, errInvalidConn } // Send command diff --git a/transaction.go b/transaction.go index 606a2cf56..4cac59f3e 100644 --- a/transaction.go +++ b/transaction.go @@ -13,7 +13,7 @@ type mysqlTx struct { } func (tx *mysqlTx) Commit() (err error) { - if tx.mc == nil || tx.mc.buf == nil { + if tx.mc == nil || tx.mc.netConn == nil { return errInvalidConn } err = tx.mc.exec("COMMIT") @@ -22,7 +22,7 @@ func (tx *mysqlTx) Commit() (err error) { } func (tx *mysqlTx) Rollback() (err error) { - if tx.mc == nil || tx.mc.buf == nil { + if tx.mc == nil || tx.mc.netConn == nil { return errInvalidConn } err = tx.mc.exec("ROLLBACK") From baf9f1cdabf952cf7b40c28a95543b0c205d905f Mon Sep 17 00:00:00 2001 From: Arne Hormann Date: Wed, 30 Oct 2013 13:00:55 +0100 Subject: [PATCH 3/3] fix criticisms in PR 143 --- driver_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/driver_test.go b/driver_test.go index 4bfb1398c..d4422075f 100644 --- a/driver_test.go +++ b/driver_test.go @@ -108,7 +108,7 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) return rows } -func TestClosedConnection(t *testing.T) { +func TestReuseClosedConnection(t *testing.T) { // this test does not use sql.database, it uses the driver directly if !available { t.Skipf("MySQL-Server not running on %s", netAddr) @@ -118,7 +118,7 @@ func TestClosedConnection(t *testing.T) { if err != nil { t.Fatalf("Error connecting: %s", err.Error()) } - stmt, err := conn.Prepare("SET @tmpif := 1") + stmt, err := conn.Prepare("DO 1") if err != nil { t.Fatalf("Error preparing statement: %s", err.Error()) } @@ -136,7 +136,7 @@ func TestClosedConnection(t *testing.T) { } }() _, err = stmt.Exec(nil) - if err != errInvalidConn { + if err != nil && err != errInvalidConn { t.Errorf("Unexpected error '%s', expected '%s'", err.Error(), errInvalidConn.Error()) }