diff --git a/connection.go b/connection.go index 99d690954..f770f782e 100644 --- a/connection.go +++ b/connection.go @@ -99,6 +99,9 @@ func (mc *mysqlConn) handleParams() (err error) { } func (mc *mysqlConn) Begin() (driver.Tx, error) { + if mc.netConn == nil { + return nil, errInvalidConn + } err := mc.exec("START TRANSACTION") if err == nil { return &mysqlTx{mc}, err @@ -122,6 +125,9 @@ func (mc *mysqlConn) Close() (err error) { } func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { + if mc.netConn == nil { + return nil, errInvalidConn + } // Send command err := mc.writeCommandPacketStr(comStmtPrepare, query) if err != nil { @@ -150,6 +156,9 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { } func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { + if mc.netConn == nil { + return nil, errInvalidConn + } if len(args) == 0 { // no args, fastpath mc.affectedRows = 0 mc.insertId = 0 @@ -191,6 +200,9 @@ func (mc *mysqlConn) exec(query string) error { } func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { + if mc.netConn == nil { + return nil, errInvalidConn + } if len(args) == 0 { // no args, fastpath // Send command err := mc.writeCommandPacketStr(comQuery, query) diff --git a/driver_test.go b/driver_test.go index ee57ce02a..d4422075f 100644 --- a/driver_test.go +++ b/driver_test.go @@ -108,6 +108,40 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) return rows } +func TestReuseClosedConnection(t *testing.T) { + // this test does not use sql.database, it uses the driver directly + if !available { + t.Skipf("MySQL-Server not running on %s", netAddr) + } + driver := &MySQLDriver{} + conn, err := driver.Open(dsn) + if err != nil { + t.Fatalf("Error connecting: %s", err.Error()) + } + stmt, err := conn.Prepare("DO 1") + if err != nil { + t.Fatalf("Error preparing statement: %s", err.Error()) + } + _, err = stmt.Exec(nil) + if err != nil { + t.Fatalf("Error executing statement: %s", err.Error()) + } + err = conn.Close() + if err != nil { + t.Fatalf("Error closing connection: %s", err.Error()) + } + defer func() { + if err := recover(); err != nil { + t.Errorf("Panic after reusing a closed connection: %v", err) + } + }() + _, err = stmt.Exec(nil) + if err != nil && err != errInvalidConn { + t.Errorf("Unexpected error '%s', expected '%s'", + err.Error(), errInvalidConn.Error()) + } +} + func TestCharset(t *testing.T) { if !available { t.Skipf("MySQL-Server not running on %s", netAddr) diff --git a/statement.go b/statement.go index cd6030497..bceb38917 100644 --- a/statement.go +++ b/statement.go @@ -34,6 +34,9 @@ func (stmt *mysqlStmt) NumInput() int { } func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { + if stmt.mc.netConn == nil { + return nil, errInvalidConn + } // Send command err := stmt.writeExecutePacket(args) if err != nil { @@ -70,6 +73,9 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { } func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { + if stmt.mc.netConn == nil { + return nil, errInvalidConn + } // Send command err := stmt.writeExecutePacket(args) if err != nil {