Skip to content

Commit db0cc0e

Browse files
committed
remove errBadConnNoWrite and markBadConn
1 parent af8d793 commit db0cc0e

File tree

5 files changed

+87
-104
lines changed

5 files changed

+87
-104
lines changed

connection.go

+34-41
Original file line numberDiff line numberDiff line change
@@ -99,23 +99,12 @@ func (mc *mysqlConn) handleParams() (err error) {
9999
return
100100
}
101101

102-
func (mc *mysqlConn) markBadConn(err error) error {
103-
if mc == nil {
104-
return err
105-
}
106-
if err != errBadConnNoWrite {
107-
return err
108-
}
109-
return driver.ErrBadConn
110-
}
111-
112102
func (mc *mysqlConn) Begin() (driver.Tx, error) {
113103
return mc.begin(false)
114104
}
115105

116106
func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
117107
if mc.closed.Load() {
118-
mc.log(ErrInvalidConn)
119108
return nil, driver.ErrBadConn
120109
}
121110
var q string
@@ -128,7 +117,7 @@ func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
128117
if err == nil {
129118
return &mysqlTx{mc}, err
130119
}
131-
return nil, mc.markBadConn(err)
120+
return nil, err
132121
}
133122

134123
func (mc *mysqlConn) Close() (err error) {
@@ -177,7 +166,6 @@ func (mc *mysqlConn) error() error {
177166

178167
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
179168
if mc.closed.Load() {
180-
mc.log(ErrInvalidConn)
181169
return nil, driver.ErrBadConn
182170
}
183171
// Send command
@@ -218,8 +206,8 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
218206
buf, err := mc.buf.takeCompleteBuffer()
219207
if err != nil {
220208
// can not take the buffer. Something must be wrong with the connection
221-
mc.log(err)
222-
return "", ErrInvalidConn
209+
mc.cleanup()
210+
return "", err
223211
}
224212
buf = buf[:0]
225213
argPos := 0
@@ -310,7 +298,6 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
310298

311299
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
312300
if mc.closed.Load() {
313-
mc.log(ErrInvalidConn)
314301
return nil, driver.ErrBadConn
315302
}
316303
if len(args) != 0 {
@@ -330,15 +317,15 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
330317
copied := mc.result
331318
return &copied, err
332319
}
333-
return nil, mc.markBadConn(err)
320+
return nil, err
334321
}
335322

336323
// Internal function to execute commands
337324
func (mc *mysqlConn) exec(query string) error {
338325
handleOk := mc.clearResult()
339326
// Send command
340327
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
341-
return mc.markBadConn(err)
328+
return err
342329
}
343330

344331
// Read Result
@@ -370,7 +357,6 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
370357
handleOk := mc.clearResult()
371358

372359
if mc.closed.Load() {
373-
mc.log(ErrInvalidConn)
374360
return nil, driver.ErrBadConn
375361
}
376362
if len(args) != 0 {
@@ -384,33 +370,37 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
384370
}
385371
query = prepared
386372
}
373+
387374
// Send command
388375
err := mc.writeCommandPacketStr(comQuery, query)
389-
if err == nil {
390-
// Read Result
391-
var resLen int
392-
resLen, err = handleOk.readResultSetHeaderPacket()
393-
if err == nil {
394-
rows := new(textRows)
395-
rows.mc = mc
376+
if err != nil {
377+
return nil, err
378+
}
396379

397-
if resLen == 0 {
398-
rows.rs.done = true
380+
// Read Result
381+
var resLen int
382+
resLen, err = handleOk.readResultSetHeaderPacket()
383+
if err != nil {
384+
return nil, err
385+
}
399386

400-
switch err := rows.NextResultSet(); err {
401-
case nil, io.EOF:
402-
return rows, nil
403-
default:
404-
return nil, err
405-
}
406-
}
387+
rows := new(textRows)
388+
rows.mc = mc
407389

408-
// Columns
409-
rows.rs.columns, err = mc.readColumns(resLen)
410-
return rows, err
390+
if resLen == 0 {
391+
rows.rs.done = true
392+
393+
switch err := rows.NextResultSet(); err {
394+
case nil, io.EOF:
395+
return rows, nil
396+
default:
397+
return nil, err
411398
}
412399
}
413-
return nil, mc.markBadConn(err)
400+
401+
// Columns
402+
rows.rs.columns, err = mc.readColumns(resLen)
403+
return rows, err
414404
}
415405

416406
// Gets the value of the given MySQL System Variable
@@ -465,7 +455,6 @@ func (mc *mysqlConn) finish() {
465455
// Ping implements driver.Pinger interface
466456
func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
467457
if mc.closed.Load() {
468-
mc.log(ErrInvalidConn)
469458
return driver.ErrBadConn
470459
}
471460

@@ -476,7 +465,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
476465

477466
handleOk := mc.clearResult()
478467
if err = mc.writeCommandPacket(comPing); err != nil {
479-
return mc.markBadConn(err)
468+
return err
480469
}
481470

482471
return handleOk.readResultOK()
@@ -682,8 +671,12 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error {
682671
return nil
683672
}
684673

674+
var _ driver.SessionResetter = &mysqlConn{}
675+
685676
// IsValid implements driver.Validator interface
686677
// (From Go 1.15)
687678
func (mc *mysqlConn) IsValid() bool {
688679
return !mc.closed.Load()
689680
}
681+
682+
var _ driver.Validator = &mysqlConn{}

connection_test.go

+5-4
Original file line numberDiff line numberDiff line change
@@ -163,12 +163,13 @@ func TestPingMarkBadConnection(t *testing.T) {
163163
netConn: nc,
164164
buf: newBuffer(nc),
165165
maxAllowedPacket: defaultMaxAllowedPacket,
166+
closech: make(chan struct{}),
166167
}
167168

168169
err := mc.Ping(context.Background())
169170

170-
if err != driver.ErrBadConn {
171-
t.Errorf("expected driver.ErrBadConn, got %#v", err)
171+
if !errors.Is(err, nc.err) {
172+
t.Errorf("expected %v, got %#v", nc.err, err)
172173
}
173174
}
174175

@@ -184,8 +185,8 @@ func TestPingErrInvalidConn(t *testing.T) {
184185

185186
err := mc.Ping(context.Background())
186187

187-
if err != ErrInvalidConn {
188-
t.Errorf("expected ErrInvalidConn, got %#v", err)
188+
if !errors.Is(err, nc.err) {
189+
t.Errorf("expected %v, got %#v", nc.err, err)
189190
}
190191
}
191192

errors.go

-6
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,6 @@ var (
2929
ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?")
3030
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the `Config.MaxAllowedPacket`")
3131
ErrBusyBuffer = errors.New("busy buffer")
32-
33-
// errBadConnNoWrite is used for connection errors where nothing was sent to the database yet.
34-
// If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn
35-
// to trigger a resend.
36-
// See https://github.com/go-sql-driver/mysql/pull/302
37-
errBadConnNoWrite = errors.New("bad connection")
3832
)
3933

4034
var defaultLogger = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile))

packets.go

+46-50
Original file line numberDiff line numberDiff line change
@@ -117,37 +117,32 @@ func (mc *mysqlConn) writePacket(data []byte) error {
117117
// Write packet
118118
if mc.writeTimeout > 0 {
119119
if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil {
120+
mc.log(err)
121+
mc.cleanup()
120122
return err
121123
}
122124
}
123125

124126
n, err := mc.netConn.Write(data[:4+size])
125-
if err == nil && n == 4+size {
126-
mc.sequence++
127-
if size != maxPacketSize {
128-
return nil
129-
}
130-
pktLen -= size
131-
data = data[size:]
132-
continue
133-
}
134-
135-
// Handle error
136-
if err == nil { // n != len(data)
127+
if err != nil {
137128
mc.cleanup()
138-
mc.log(ErrMalformPkt)
139-
} else {
140129
if cerr := mc.canceled.Value(); cerr != nil {
141130
return cerr
142131
}
143-
if n == 0 && pktLen == len(data)-4 {
144-
// only for the first loop iteration when nothing was written yet
145-
return errBadConnNoWrite
146-
}
132+
return err
133+
}
134+
if n != size+4 {
147135
mc.cleanup()
148-
mc.log(err)
136+
return io.ErrShortWrite
137+
}
138+
139+
mc.sequence++
140+
if size != maxPacketSize {
141+
return nil
149142
}
150-
return ErrInvalidConn
143+
pktLen -= size
144+
data = data[size:]
145+
continue
151146
}
152147
}
153148

@@ -303,8 +298,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
303298
data, err := mc.buf.takeBuffer(pktLen + 4)
304299
if err != nil {
305300
// cannot take the buffer. Something must be wrong with the connection
306-
mc.log(err)
307-
return errBadConnNoWrite
301+
mc.cleanup()
302+
return err
308303
}
309304

310305
// ClientFlags [32 bit]
@@ -392,8 +387,8 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
392387
data, err := mc.buf.takeSmallBuffer(pktLen)
393388
if err != nil {
394389
// cannot take the buffer. Something must be wrong with the connection
395-
mc.log(err)
396-
return errBadConnNoWrite
390+
mc.cleanup()
391+
return err
397392
}
398393

399394
// Add the auth data [EOF]
@@ -412,8 +407,8 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
412407
data, err := mc.buf.takeSmallBuffer(4 + 1)
413408
if err != nil {
414409
// cannot take the buffer. Something must be wrong with the connection
415-
mc.log(err)
416-
return errBadConnNoWrite
410+
mc.cleanup()
411+
return err
417412
}
418413

419414
// Add command byte
@@ -431,8 +426,8 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
431426
data, err := mc.buf.takeBuffer(pktLen + 4)
432427
if err != nil {
433428
// cannot take the buffer. Something must be wrong with the connection
434-
mc.log(err)
435-
return errBadConnNoWrite
429+
mc.cleanup()
430+
return err
436431
}
437432

438433
// Add command byte
@@ -452,8 +447,8 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
452447
data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
453448
if err != nil {
454449
// cannot take the buffer. Something must be wrong with the connection
455-
mc.log(err)
456-
return errBadConnNoWrite
450+
mc.cleanup()
451+
return err
457452
}
458453

459454
// Add command byte
@@ -522,32 +517,33 @@ func (mc *okHandler) readResultOK() error {
522517
}
523518

524519
// Result Set Header Packet
525-
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
520+
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response.html
526521
func (mc *okHandler) readResultSetHeaderPacket() (int, error) {
527522
// handleOkPacket replaces both values; other cases leave the values unchanged.
528523
mc.result.affectedRows = append(mc.result.affectedRows, 0)
529524
mc.result.insertIds = append(mc.result.insertIds, 0)
530525

531526
data, err := mc.conn().readPacket()
532-
if err == nil {
533-
switch data[0] {
534-
535-
case iOK:
536-
return 0, mc.handleOkPacket(data)
527+
if err != nil {
528+
return 0, err
529+
}
537530

538-
case iERR:
539-
return 0, mc.conn().handleErrorPacket(data)
531+
switch data[0] {
532+
case iOK:
533+
return 0, mc.handleOkPacket(data)
540534

541-
case iLocalInFile:
542-
return 0, mc.handleInFileRequest(string(data[1:]))
543-
}
535+
case iERR:
536+
return 0, mc.conn().handleErrorPacket(data)
544537

545-
// column count
546-
num, _, _ := readLengthEncodedInteger(data)
547-
// ignore remaining data in the packet. see #1478.
548-
return int(num), nil
538+
case iLocalInFile:
539+
return 0, mc.handleInFileRequest(string(data[1:]))
549540
}
550-
return 0, err
541+
542+
// column count
543+
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset.html
544+
num, _, _ := readLengthEncodedInteger(data)
545+
// ignore remaining data in the packet. see #1478.
546+
return int(num), nil
551547
}
552548

553549
// Error Packet
@@ -994,8 +990,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
994990
}
995991
if err != nil {
996992
// cannot take the buffer. Something must be wrong with the connection
997-
mc.log(err)
998-
return errBadConnNoWrite
993+
mc.cleanup()
994+
return err
999995
}
1000996

1001997
// command [1 byte]
@@ -1193,8 +1189,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
11931189
if valuesCap != cap(paramValues) {
11941190
data = append(data[:pos], paramValues...)
11951191
if err = mc.buf.store(data); err != nil {
1196-
mc.log(err)
1197-
return errBadConnNoWrite
1192+
mc.cleanup()
1193+
return err
11981194
}
11991195
}
12001196

0 commit comments

Comments
 (0)