Skip to content
This repository was archived by the owner on Jan 21, 2022. It is now read-only.

Commit 7c59dc2

Browse files
arnehormannDiego
authored and
Diego
committed
fewer driver.ErrBadConn to prevent repeated queries (go-sql-driver#302)
According to the database/sql/driver documentation, ErrBadConn should only be used when the database was not affected. The driver restarts the same query on a different connection, then. The mysql driver did not follow this advice, so queries were repeated if ErrBadConn is returned but a query succeeded. This is fixed by changing most ErrBadConn errors to ErrInvalidConn. The only valid returns of ErrBadConn are at the beginning of a database interaction when no data was sent to the database yet. Those valid cases are located the following funcs before attempting to write to the network or if 0 bytes were written: * Begin * BeginTx * Exec * ExecContext * Prepare * PrepareContext * Query * QueryContext Commit and Rollback could arguably also be on that list, but are left out as some engines like MyISAM are not supporting transactions. Tests in b/packets_test.go were changed because they simulate a read not preceded by a write to the db. This cannot happen as the client has to send the query first.
1 parent 50b949a commit 7c59dc2

File tree

5 files changed

+46
-28
lines changed

5 files changed

+46
-28
lines changed

Diff for: connection.go

+16-7
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,16 @@ func (mc *mysqlConn) handleParams() (err error) {
8181
return
8282
}
8383

84+
func (mc *mysqlConn) markBadConn(err error) error {
85+
if mc == nil {
86+
return err
87+
}
88+
if err != errBadConnNoWrite {
89+
return err
90+
}
91+
return driver.ErrBadConn
92+
}
93+
8494
func (mc *mysqlConn) Begin() (driver.Tx, error) {
8595
if mc.closed.IsSet() {
8696
errLog.Print(ErrInvalidConn)
@@ -90,8 +100,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {
90100
if err == nil {
91101
return &mysqlTx{mc}, err
92102
}
93-
94-
return nil, err
103+
return nil, mc.markBadConn(err)
95104
}
96105

97106
func (mc *mysqlConn) Close() (err error) {
@@ -142,7 +151,7 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
142151
// Send command
143152
err := mc.writeCommandPacketStr(comStmtPrepare, query)
144153
if err != nil {
145-
return nil, err
154+
return nil, mc.markBadConn(err)
146155
}
147156

148157
stmt := &mysqlStmt{
@@ -176,7 +185,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
176185
if buf == nil {
177186
// can not take the buffer. Something must be wrong with the connection
178187
errLog.Print(ErrBusyBuffer)
179-
return "", driver.ErrBadConn
188+
return "", ErrInvalidConn
180189
}
181190
buf = buf[:0]
182191
argPos := 0
@@ -314,14 +323,14 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
314323
insertId: int64(mc.insertId),
315324
}, err
316325
}
317-
return nil, err
326+
return nil, mc.markBadConn(err)
318327
}
319328

320329
// Internal function to execute commands
321330
func (mc *mysqlConn) exec(query string) error {
322331
// Send command
323332
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
324-
return err
333+
return mc.markBadConn(err)
325334
}
326335

327336
// Read Result
@@ -390,7 +399,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
390399
return rows, err
391400
}
392401
}
393-
return nil, err
402+
return nil, mc.markBadConn(err)
394403
}
395404

396405
// Gets the value of the given MySQL System Variable

Diff for: errors.go

+6
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ var (
3131
ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?")
3232
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server")
3333
ErrBusyBuffer = errors.New("busy buffer")
34+
35+
// errBadConnNoWrite is used for connection errors where nothing was sent to the database yet.
36+
// If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn
37+
// to trigger a resend.
38+
// See https://github.com/go-sql-driver/mysql/pull/302
39+
errBadConnNoWrite = errors.New("bad connection")
3440
)
3541

3642
var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile))

Diff for: packets.go

+16-12
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
3535
}
3636
errLog.Print(err)
3737
mc.Close()
38-
return nil, driver.ErrBadConn
38+
return nil, ErrInvalidConn
3939
}
4040

4141
// packet length [24 bit]
@@ -64,7 +64,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
6464
if prevData == nil {
6565
errLog.Print(ErrMalformPkt)
6666
mc.Close()
67-
return nil, driver.ErrBadConn
67+
return nil, ErrInvalidConn
6868
}
6969

7070
return prevData, nil
@@ -78,7 +78,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
7878
}
7979
errLog.Print(err)
8080
mc.Close()
81-
return nil, driver.ErrBadConn
81+
return nil, ErrInvalidConn
8282
}
8383

8484
// return data if this was the last packet
@@ -144,10 +144,14 @@ func (mc *mysqlConn) writePacket(data []byte) error {
144144
if cerr := mc.canceled.Value(); cerr != nil {
145145
return cerr
146146
}
147+
if n == 0 && pktLen == len(data)-4 {
148+
// only for the first loop iteration when nothing was written yet
149+
return errBadConnNoWrite
150+
}
147151
mc.cleanup()
148152
errLog.Print(err)
149153
}
150-
return driver.ErrBadConn
154+
return ErrInvalidConn
151155
}
152156
}
153157

@@ -281,7 +285,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
281285
if data == nil {
282286
// can not take the buffer. Something must be wrong with the connection
283287
errLog.Print(ErrBusyBuffer)
284-
return driver.ErrBadConn
288+
return errBadConnNoWrite
285289
}
286290

287291
// ClientFlags [32 bit]
@@ -369,7 +373,7 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
369373
if data == nil {
370374
// can not take the buffer. Something must be wrong with the connection
371375
errLog.Print(ErrBusyBuffer)
372-
return driver.ErrBadConn
376+
return errBadConnNoWrite
373377
}
374378

375379
// Add the scrambled password [null terminated string]
@@ -388,7 +392,7 @@ func (mc *mysqlConn) writeClearAuthPacket() error {
388392
if data == nil {
389393
// can not take the buffer. Something must be wrong with the connection
390394
errLog.Print(ErrBusyBuffer)
391-
return driver.ErrBadConn
395+
return errBadConnNoWrite
392396
}
393397

394398
// Add the clear password [null terminated string]
@@ -411,7 +415,7 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error {
411415
if data == nil {
412416
// can not take the buffer. Something must be wrong with the connection
413417
errLog.Print(ErrBusyBuffer)
414-
return driver.ErrBadConn
418+
return errBadConnNoWrite
415419
}
416420

417421
// Add the scramble
@@ -432,7 +436,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
432436
if data == nil {
433437
// can not take the buffer. Something must be wrong with the connection
434438
errLog.Print(ErrBusyBuffer)
435-
return driver.ErrBadConn
439+
return errBadConnNoWrite
436440
}
437441

438442
// Add command byte
@@ -451,7 +455,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
451455
if data == nil {
452456
// can not take the buffer. Something must be wrong with the connection
453457
errLog.Print(ErrBusyBuffer)
454-
return driver.ErrBadConn
458+
return errBadConnNoWrite
455459
}
456460

457461
// Add command byte
@@ -472,7 +476,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
472476
if data == nil {
473477
// can not take the buffer. Something must be wrong with the connection
474478
errLog.Print(ErrBusyBuffer)
475-
return driver.ErrBadConn
479+
return errBadConnNoWrite
476480
}
477481

478482
// Add command byte
@@ -945,7 +949,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
945949
if data == nil {
946950
// can not take the buffer. Something must be wrong with the connection
947951
errLog.Print(ErrBusyBuffer)
948-
return driver.ErrBadConn
952+
return errBadConnNoWrite
949953
}
950954

951955
// command [1 byte]

Diff for: packets_test.go

+6-7
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
package mysql
1010

1111
import (
12-
"database/sql/driver"
1312
"errors"
1413
"net"
1514
"testing"
@@ -252,8 +251,8 @@ func TestReadPacketFail(t *testing.T) {
252251
conn.data = []byte{0x00, 0x00, 0x00, 0x00}
253252
conn.maxReads = 1
254253
_, err := mc.readPacket()
255-
if err != driver.ErrBadConn {
256-
t.Errorf("expected ErrBadConn, got %v", err)
254+
if err != ErrInvalidConn {
255+
t.Errorf("expected ErrInvalidConn, got %v", err)
257256
}
258257

259258
// reset
@@ -264,8 +263,8 @@ func TestReadPacketFail(t *testing.T) {
264263
// fail to read header
265264
conn.closed = true
266265
_, err = mc.readPacket()
267-
if err != driver.ErrBadConn {
268-
t.Errorf("expected ErrBadConn, got %v", err)
266+
if err != ErrInvalidConn {
267+
t.Errorf("expected ErrInvalidConn, got %v", err)
269268
}
270269

271270
// reset
@@ -277,7 +276,7 @@ func TestReadPacketFail(t *testing.T) {
277276
// fail to read body
278277
conn.maxReads = 1
279278
_, err = mc.readPacket()
280-
if err != driver.ErrBadConn {
281-
t.Errorf("expected ErrBadConn, got %v", err)
279+
if err != ErrInvalidConn {
280+
t.Errorf("expected ErrInvalidConn, got %v", err)
282281
}
283282
}

Diff for: statement.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
5252
// Send command
5353
err := stmt.writeExecutePacket(args)
5454
if err != nil {
55-
return nil, err
55+
return nil, stmt.mc.markBadConn(err)
5656
}
5757

5858
mc := stmt.mc
@@ -100,7 +100,7 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
100100
// Send command
101101
err := stmt.writeExecutePacket(args)
102102
if err != nil {
103-
return nil, err
103+
return nil, stmt.mc.markBadConn(err)
104104
}
105105

106106
mc := stmt.mc

0 commit comments

Comments
 (0)