Skip to content

Commit a343408

Browse files
committed
send client connect attrs according to whether mysql server supports CLIENT_CONNECT_ATTRS
1 parent f62f523 commit a343408

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

Diff for: packets.go

+17-8
Original file line numberDiff line numberDiff line change
@@ -210,10 +210,12 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
210210
if len(data) > pos {
211211
// character set [1 byte]
212212
// status flags [2 bytes]
213+
pos += 1 + 2
213214
// capability flags (upper 2 bytes) [2 bytes]
215+
mc.flags |= clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16
214216
// length of auth-plugin-data [1 byte]
215217
// reserved (all [00]) [10 bytes]
216-
pos += 1 + 2 + 2 + 1 + 10
218+
pos += 2 + 1 + 10
217219

218220
// second part of the password cipher [minimum 13 bytes],
219221
// where len=MAX(13, length of auth-plugin-data - 8)
@@ -261,9 +263,11 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
261263
clientLocalFiles |
262264
clientPluginAuth |
263265
clientMultiResults |
264-
clientConnectAttrs |
266+
mc.flags&clientConnectAttrs |
265267
mc.flags&clientLongFlag
266268

269+
serverSupportClientConnectAttrs := mc.flags&clientConnectAttrs != 0
270+
267271
if mc.cfg.ClientFoundRows {
268272
clientFlags |= clientFoundRows
269273
}
@@ -295,11 +299,14 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
295299
pktLen += n + 1
296300
}
297301

302+
var connAttrsLEI []byte
298303
// encode length of the connection attributes
299-
var connAttrsLEIBuf [9]byte
300-
connAttrsLen := len(mc.connector.encodedAttributes)
301-
connAttrsLEI := appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen))
302-
pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes)
304+
if serverSupportClientConnectAttrs {
305+
var connAttrsLEIBuf [9]byte
306+
connAttrsLen := len(mc.connector.encodedAttributes)
307+
connAttrsLEI = appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen))
308+
pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes)
309+
}
303310

304311
// Calculate packet length and get buffer with that size
305312
data, err := mc.buf.takeBuffer(pktLen + 4)
@@ -382,8 +389,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
382389
pos++
383390

384391
// Connection Attributes
385-
pos += copy(data[pos:], connAttrsLEI)
386-
pos += copy(data[pos:], []byte(mc.connector.encodedAttributes))
392+
if serverSupportClientConnectAttrs {
393+
pos += copy(data[pos:], connAttrsLEI)
394+
pos += copy(data[pos:], []byte(mc.connector.encodedAttributes))
395+
}
387396

388397
// Send Auth packet
389398
return mc.writePacket(data[:pos])

0 commit comments

Comments
 (0)