Skip to content

Commit aea98e2

Browse files
committed
remove errBadConnNoWrite and markBadConn
1 parent 52c1917 commit aea98e2

File tree

5 files changed

+43
-62
lines changed

5 files changed

+43
-62
lines changed

connection.go

+7-15
Original file line numberDiff line numberDiff line change
@@ -111,15 +111,6 @@ func (mc *mysqlConn) handleParams() (err error) {
111111
return
112112
}
113113

114-
// markBadConn replaces errBadConnNoWrite with driver.ErrBadConn.
115-
// This function is used to return driver.ErrBadConn only when safe to retry.
116-
func (mc *mysqlConn) markBadConn(err error) error {
117-
if err == errBadConnNoWrite {
118-
return driver.ErrBadConn
119-
}
120-
return err
121-
}
122-
123114
func (mc *mysqlConn) Begin() (driver.Tx, error) {
124115
return mc.begin(false)
125116
}
@@ -138,7 +129,7 @@ func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
138129
if err == nil {
139130
return &mysqlTx{mc}, err
140131
}
141-
return nil, mc.markBadConn(err)
132+
return nil, err
142133
}
143134

144135
func (mc *mysqlConn) Close() (err error) {
@@ -340,15 +331,15 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
340331
copied := mc.result
341332
return &copied, err
342333
}
343-
return nil, mc.markBadConn(err)
334+
return nil, err
344335
}
345336

346337
// Internal function to execute commands
347338
func (mc *mysqlConn) exec(query string) error {
348339
handleOk := mc.clearResult()
349340
// Send command
350341
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
351-
return mc.markBadConn(err)
342+
return err
352343
}
353344

354345
// Read Result
@@ -378,10 +369,10 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
378369

379370
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
380371
handleOk := mc.clearResult()
381-
382372
if mc.closed.Load() {
383373
return nil, driver.ErrBadConn
384374
}
375+
385376
if len(args) != 0 {
386377
if !mc.cfg.InterpolateParams {
387378
return nil, driver.ErrSkip
@@ -393,10 +384,11 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
393384
}
394385
query = prepared
395386
}
387+
396388
// Send command
397389
err := mc.writeCommandPacketStr(comQuery, query)
398390
if err != nil {
399-
return nil, mc.markBadConn(err)
391+
return nil, err
400392
}
401393

402394
// Read Result
@@ -487,7 +479,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
487479

488480
handleOk := mc.clearResult()
489481
if err = mc.writeCommandPacket(comPing); err != nil {
490-
return mc.markBadConn(err)
482+
return err
491483
}
492484

493485
return handleOk.readResultOK()

connection_test.go

+5-4
Original file line numberDiff line numberDiff line change
@@ -163,12 +163,13 @@ func TestPingMarkBadConnection(t *testing.T) {
163163
netConn: nc,
164164
buf: newBuffer(nc),
165165
maxAllowedPacket: defaultMaxAllowedPacket,
166+
closech: make(chan struct{}),
166167
}
167168

168169
err := mc.Ping(context.Background())
169170

170-
if err != driver.ErrBadConn {
171-
t.Errorf("expected driver.ErrBadConn, got %#v", err)
171+
if !errors.Is(err, nc.err) {
172+
t.Errorf("expected %v, got %#v", nc.err, err)
172173
}
173174
}
174175

@@ -184,8 +185,8 @@ func TestPingErrInvalidConn(t *testing.T) {
184185

185186
err := mc.Ping(context.Background())
186187

187-
if err != ErrInvalidConn {
188-
t.Errorf("expected ErrInvalidConn, got %#v", err)
188+
if !errors.Is(err, nc.err) {
189+
t.Errorf("expected %v, got %#v", nc.err, err)
189190
}
190191
}
191192

errors.go

-6
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,6 @@ var (
2929
ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?")
3030
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the `Config.MaxAllowedPacket`")
3131
ErrBusyBuffer = errors.New("busy buffer")
32-
33-
// errBadConnNoWrite is used for connection errors where nothing was sent to the database yet.
34-
// If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn
35-
// to trigger a resend. Use mc.markBadConn(err) to do this.
36-
// See https://github.com/go-sql-driver/mysql/pull/302
37-
errBadConnNoWrite = errors.New("bad connection")
3832
)
3933

4034
var defaultLogger = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime))

packets.go

+29-35
Original file line numberDiff line numberDiff line change
@@ -117,39 +117,33 @@ func (mc *mysqlConn) writePacket(data []byte) error {
117117
// Write packet
118118
if mc.writeTimeout > 0 {
119119
if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil {
120-
mc.cleanup()
121120
mc.log(err)
121+
mc.cleanup()
122122
return err
123123
}
124124
}
125125

126126
n, err := mc.netConn.Write(data[:4+size])
127-
if err == nil && n == 4+size {
128-
mc.sequence++
129-
if size != maxPacketSize {
130-
return nil
131-
}
132-
pktLen -= size
133-
data = data[size:]
134-
continue
135-
}
136-
137-
// Handle error
138-
if err == nil { // n != len(data)
127+
if err != nil {
139128
mc.cleanup()
140-
mc.log(ErrMalformPkt)
141-
} else {
142129
if cerr := mc.canceled.Value(); cerr != nil {
143130
return cerr
144131
}
145-
if n == 0 && pktLen == len(data)-4 {
146-
// only for the first loop iteration when nothing was written yet
147-
return errBadConnNoWrite
148-
}
132+
return err
133+
}
134+
if n != 4+size {
135+
// io.Writer(b) must return a non-nil error if it cannot write len(b) bytes.
136+
// The io.ErrShortWrite error is used to indicate that this rule has not been followed.
149137
mc.cleanup()
150-
mc.log(err)
138+
return io.ErrShortWrite
151139
}
152-
return ErrInvalidConn
140+
141+
mc.sequence++
142+
if size != maxPacketSize {
143+
return nil
144+
}
145+
pktLen -= size
146+
data = data[size:]
153147
}
154148
}
155149

@@ -305,8 +299,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
305299
data, err := mc.buf.takeBuffer(pktLen + 4)
306300
if err != nil {
307301
// cannot take the buffer. Something must be wrong with the connection
308-
mc.log(err)
309-
return errBadConnNoWrite
302+
mc.cleanup()
303+
return err
310304
}
311305

312306
// ClientFlags [32 bit]
@@ -394,8 +388,8 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
394388
data, err := mc.buf.takeSmallBuffer(pktLen)
395389
if err != nil {
396390
// cannot take the buffer. Something must be wrong with the connection
397-
mc.log(err)
398-
return errBadConnNoWrite
391+
mc.cleanup()
392+
return err
399393
}
400394

401395
// Add the auth data [EOF]
@@ -414,8 +408,8 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
414408
data, err := mc.buf.takeSmallBuffer(4 + 1)
415409
if err != nil {
416410
// cannot take the buffer. Something must be wrong with the connection
417-
mc.log(err)
418-
return errBadConnNoWrite
411+
mc.cleanup()
412+
return err
419413
}
420414

421415
// Add command byte
@@ -433,8 +427,8 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
433427
data, err := mc.buf.takeBuffer(pktLen + 4)
434428
if err != nil {
435429
// cannot take the buffer. Something must be wrong with the connection
436-
mc.log(err)
437-
return errBadConnNoWrite
430+
mc.cleanup()
431+
return err
438432
}
439433

440434
// Add command byte
@@ -454,8 +448,8 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
454448
data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
455449
if err != nil {
456450
// cannot take the buffer. Something must be wrong with the connection
457-
mc.log(err)
458-
return errBadConnNoWrite
451+
mc.cleanup()
452+
return err
459453
}
460454

461455
// Add command byte
@@ -997,8 +991,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
997991
}
998992
if err != nil {
999993
// cannot take the buffer. Something must be wrong with the connection
1000-
mc.log(err)
1001-
return errBadConnNoWrite
994+
mc.cleanup()
995+
return err
1002996
}
1003997

1004998
// command [1 byte]
@@ -1196,8 +1190,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
11961190
if valuesCap != cap(paramValues) {
11971191
data = append(data[:pos], paramValues...)
11981192
if err = mc.buf.store(data); err != nil {
1199-
mc.log(err)
1200-
return errBadConnNoWrite
1193+
mc.cleanup()
1194+
return err
12011195
}
12021196
}
12031197

statement.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
5656
// Send command
5757
err := stmt.writeExecutePacket(args)
5858
if err != nil {
59-
return nil, stmt.mc.markBadConn(err)
59+
return nil, err
6060
}
6161

6262
mc := stmt.mc
@@ -99,7 +99,7 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
9999
// Send command
100100
err := stmt.writeExecutePacket(args)
101101
if err != nil {
102-
return nil, stmt.mc.markBadConn(err)
102+
return nil, err
103103
}
104104

105105
mc := stmt.mc

0 commit comments

Comments
 (0)