Skip to content

Commit 32e5cee

Browse files
committed
no panic on closed connection reuse
1 parent ae73333 commit 32e5cee

File tree

5 files changed

+59
-4
lines changed

5 files changed

+59
-4
lines changed

connection.go

+15
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ func (mc *mysqlConn) handleParams() (err error) {
9999
}
100100

101101
func (mc *mysqlConn) Begin() (driver.Tx, error) {
102+
if mc.buf == nil {
103+
return nil, errInvalidConn
104+
}
102105
err := mc.exec("START TRANSACTION")
103106
if err == nil {
104107
return &mysqlTx{mc}, err
@@ -108,6 +111,9 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {
108111
}
109112

110113
func (mc *mysqlConn) Close() (err error) {
114+
if mc.buf == nil {
115+
return errInvalidConn
116+
}
111117
// Makes Close idempotent
112118
if mc.netConn != nil {
113119
mc.writeCommandPacket(comQuit)
@@ -122,6 +128,9 @@ func (mc *mysqlConn) Close() (err error) {
122128
}
123129

124130
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
131+
if mc.buf == nil {
132+
return nil, errInvalidConn
133+
}
125134
// Send command
126135
err := mc.writeCommandPacketStr(comStmtPrepare, query)
127136
if err != nil {
@@ -150,6 +159,9 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
150159
}
151160

152161
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
162+
if mc.buf == nil {
163+
return nil, errInvalidConn
164+
}
153165
if len(args) == 0 { // no args, fastpath
154166
mc.affectedRows = 0
155167
mc.insertId = 0
@@ -191,6 +203,9 @@ func (mc *mysqlConn) exec(query string) error {
191203
}
192204

193205
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
206+
if mc.buf == nil {
207+
return nil, errInvalidConn
208+
}
194209
if len(args) == 0 { // no args, fastpath
195210
// Send command
196211
err := mc.writeCommandPacketStr(comQuery, query)

driver_test.go

+34
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,40 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows)
108108
return rows
109109
}
110110

111+
func TestClosedConnection(t *testing.T) {
112+
// this test does not use sql.database, it uses the driver directly
113+
if !available {
114+
t.Skipf("MySQL-Server not running on %s", netAddr)
115+
}
116+
driver := &MySQLDriver{}
117+
conn, err := driver.Open(dsn)
118+
if err != nil {
119+
t.Fatalf("Error connecting: %s", err.Error())
120+
}
121+
stmt, err := conn.Prepare("SET @tmpif := 1")
122+
if err != nil {
123+
t.Fatalf("Error preparing statement: %s", err.Error())
124+
}
125+
_, err = stmt.Exec(nil)
126+
if err != nil {
127+
t.Fatalf("Error executing statement: %s", err.Error())
128+
}
129+
err = conn.Close()
130+
if err != nil {
131+
t.Fatalf("Error closing connection: %s", err.Error())
132+
}
133+
defer func() {
134+
if err := recover(); err != nil {
135+
t.Errorf("Panic after reusing a closed connection: %v", err)
136+
}
137+
}()
138+
_, err = stmt.Exec(nil)
139+
if err != errInvalidConn {
140+
t.Errorf("Unexpected error '%s', expected '%s'",
141+
err.Error(), errInvalidConn.Error())
142+
}
143+
}
144+
111145
func TestCharset(t *testing.T) {
112146
if !available {
113147
t.Skipf("MySQL-Server not running on %s", netAddr)

rows.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func (rows *mysqlRows) Columns() []string {
3737
func (rows *mysqlRows) Close() (err error) {
3838
// Remove unread packets from stream
3939
if !rows.eof {
40-
if rows.mc == nil || rows.mc.netConn == nil {
40+
if rows.mc == nil || rows.mc.buf == nil {
4141
return errInvalidConn
4242
}
4343

@@ -58,7 +58,7 @@ func (rows *mysqlRows) Next(dest []driver.Value) (err error) {
5858
return io.EOF
5959
}
6060

61-
if rows.mc == nil || rows.mc.netConn == nil {
61+
if rows.mc == nil || rows.mc.buf == nil {
6262
return errInvalidConn
6363
}
6464

statement.go

+6
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ func (stmt *mysqlStmt) NumInput() int {
3434
}
3535

3636
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
37+
if stmt.mc.buf == nil {
38+
return nil, errInvalidConn
39+
}
3740
// Send command
3841
err := stmt.writeExecutePacket(args)
3942
if err != nil {
@@ -70,6 +73,9 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
7073
}
7174

7275
func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
76+
if stmt.mc.buf == nil {
77+
return nil, errInvalidConn
78+
}
7379
// Send command
7480
err := stmt.writeExecutePacket(args)
7581
if err != nil {

transaction.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ type mysqlTx struct {
1313
}
1414

1515
func (tx *mysqlTx) Commit() (err error) {
16-
if tx.mc == nil || tx.mc.netConn == nil {
16+
if tx.mc == nil || tx.mc.buf == nil {
1717
return errInvalidConn
1818
}
1919
err = tx.mc.exec("COMMIT")
@@ -22,7 +22,7 @@ func (tx *mysqlTx) Commit() (err error) {
2222
}
2323

2424
func (tx *mysqlTx) Rollback() (err error) {
25-
if tx.mc == nil || tx.mc.netConn == nil {
25+
if tx.mc == nil || tx.mc.buf == nil {
2626
return errInvalidConn
2727
}
2828
err = tx.mc.exec("ROLLBACK")

0 commit comments

Comments
 (0)