Skip to content

Commit e1dc557

Browse files
committed
wip
1 parent 77d86ec commit e1dc557

File tree

5 files changed

+52
-20
lines changed

5 files changed

+52
-20
lines changed

auth.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
338338
return authEd25519(authData, mc.cfg.Passwd)
339339

340340
default:
341-
mc.cfg.Logger.Print("unknown auth plugin:", plugin)
341+
mc.log("unknown auth plugin:", plugin)
342342
return nil, ErrUnknownPlugin
343343
}
344344
}

compress.go

+11-11
Original file line numberDiff line numberDiff line change
@@ -123,17 +123,18 @@ func (c *decompressor) uncompressPacket() error {
123123
uncompressedLength := int(uint32(header[4]) | uint32(header[5])<<8 | uint32(header[6])<<16)
124124
compressionSequence := uint8(header[3])
125125
if debugTrace {
126-
c.mc.cfg.Logger.Print(
127-
fmt.Sprintf("uncompress cmplen=%v uncomplen=%v pkt_cmp_seq=%v expected_cmp_seq=%v\n",
128-
comprLength, uncompressedLength, compressionSequence, c.mc.sequence))
126+
traceLogger.Printf("uncompress cmplen=%v uncomplen=%v pkt_cmp_seq=%v expected_cmp_seq=%v\n",
127+
comprLength, uncompressedLength, compressionSequence, c.mc.sequence)
129128
}
130129
if compressionSequence != c.mc.sequence {
131130
// return ErrPktSync
132131
// server may return error packet (e.g. 1153 Got a packet bigger than 'max_allowed_packet' bytes)
133132
// before receiving all packets from client. In this case, seqnr is younger than expected.
134-
c.mc.cfg.Logger.Print(
135-
fmt.Sprintf("[warn] unexpected cmpress seq nr: expected %v, got %v",
136-
c.mc.sequence, compressionSequence))
133+
if debugTrace {
134+
traceLogger.Printf("WARN: unexpected cmpress seq nr: expected %v, got %v",
135+
c.mc.sequence, compressionSequence)
136+
}
137+
c.mc.invalid = true
137138
}
138139
c.mc.sequence = compressionSequence + 1
139140
c.mc.compressSequence = c.mc.sequence
@@ -218,10 +219,9 @@ func (mc *mysqlConn) writeCompressed(packets []byte) (int, error) {
218219
func (mc *mysqlConn) writeCompressedPacket(data []byte, uncompressedLen int) error {
219220
comprLength := len(data) - 7
220221
if debugTrace {
221-
mc.cfg.Logger.Print(
222-
fmt.Sprintf(
223-
"writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v",
224-
comprLength, uncompressedLen, mc.compressSequence))
222+
traceLogger.Printf(
223+
"writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v",
224+
comprLength, uncompressedLen, mc.compressSequence)
225225
}
226226

227227
// compression header
@@ -237,7 +237,7 @@ func (mc *mysqlConn) writeCompressedPacket(data []byte, uncompressedLen int) err
237237
data[6] = byte(0xff & (uncompressedLen >> 16))
238238

239239
if _, err := mc.netConn.Write(data); err != nil {
240-
mc.cfg.Logger.Print(err)
240+
mc.log("writing compressed packet:", err)
241241
return err
242242
}
243243

connection.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ type mysqlConn struct {
3737
compressSequence uint8
3838
parseTime bool
3939
compress bool
40+
invalid bool // true when the connection is in invalid state and will be closed later.
4041

4142
// for context support (Go 1.8+)
4243
watching bool
@@ -132,7 +133,6 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {
132133

133134
func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
134135
if mc.closed.Load() {
135-
mc.cfg.Logger.Print(ErrInvalidConn)
136136
return nil, driver.ErrBadConn
137137
}
138138
var q string
@@ -173,7 +173,7 @@ func (mc *mysqlConn) cleanup() {
173173
return
174174
}
175175
if err := mc.netConn.Close(); err != nil {
176-
mc.cfg.Logger.Print(err)
176+
mc.log("closing connection:", err)
177177
}
178178
mc.clearResult()
179179
}
@@ -698,5 +698,5 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error {
698698
// IsValid implements driver.Validator interface
699699
// (From Go 1.15)
700700
func (mc *mysqlConn) IsValid() bool {
701-
return !mc.closed.Load()
701+
return !mc.closed.Load() && !mc.invalid
702702
}

errors.go

+26-3
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,41 @@ var (
3737
errBadConnNoWrite = errors.New("bad connection")
3838
)
3939

40-
var defaultLogger = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile))
40+
// traceLogger is used for debug trace log.
41+
var traceLogger *log.Logger
42+
43+
func init() {
44+
if debugTrace {
45+
traceLogger = log.New(os.Stderr, "[mysql.trace]", log.Lmicroseconds|log.Lshortfile)
46+
}
47+
}
48+
49+
func trace(format string, v ...any) {
50+
if debugTrace {
51+
traceLogger.Printf(format, v...)
52+
}
53+
}
4154

4255
// Logger is used to log critical error messages.
4356
type Logger interface {
44-
Print(v ...interface{})
57+
Print(v ...any)
58+
}
59+
60+
var defaultLogger = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile))
61+
62+
func (mc *mysqlConn) log(v ...any) {
63+
mc.cfg.Logger.Print(v...)
64+
}
65+
66+
func (mc *mysqlConn) logf(format string, v ...any) {
67+
mc.cfg.Logger.Print(fmt.Sprintf(format, v...))
4568
}
4669

4770
// NopLogger is a nop implementation of the Logger interface.
4871
type NopLogger struct{}
4972

5073
// Print implements Logger interface.
51-
func (nl *NopLogger) Print(_ ...interface{}) {}
74+
func (nl *NopLogger) Print(_ ...any) {}
5275

5376
// SetLogger is used to set the default logger for critical errors.
5477
// The initial logger is os.Stderr.

packets.go

+11-2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
// Read packet to buffer 'data'
2929
func (mc *mysqlConn) readPacket() ([]byte, error) {
3030
var prevData []byte
31+
var rerr error = nil
3132
for {
3233
// read packet header
3334
data, err := mc.packetReader.readNext(4)
@@ -46,10 +47,18 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
4647
// packet length [24 bit]
4748
pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16)
4849

49-
if !mc.compress { // MySQL and MariaDB doesn't check packet nr in compressed packet.
50+
if mc.compress {
51+
// MySQL and MariaDB doesn't check packet nr in compressed packet.
52+
if debugTrace && data[3] != mc.compressSequence {
53+
mc.cfg.Logger.Print
54+
}
55+
mc.compressSequence = data[3]+1
56+
} else mc.compress {
5057
// check packet sync [8 bit]
5158
if data[3] != mc.sequence {
5259
mc.cfg.Logger.Print(fmt.Sprintf("[warn] unexpected seq nr: expected %v, got %v", mc.sequence, data[3]))
60+
mc.invalid = true
61+
rerr = ErrInvalidConn
5362
}
5463
mc.sequence++
5564
}
@@ -117,7 +126,7 @@ func (mc *mysqlConn) writePacket(data []byte) error {
117126

118127
// Write packet
119128
if debugTrace {
120-
mc.cfg.Logger.Print(fmt.Sprintf("writePacket: size=%v seq=%v", size, mc.sequence))
129+
traceLogger.Printf("writePacket: size=%v seq=%v", size, mc.sequence)
121130
}
122131
if mc.writeTimeout > 0 {
123132
if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil {

0 commit comments

Comments
 (0)