Skip to content

Commit f25a889

Browse files
Implement GTID tracking.
Co-authored-by: Daniel Joos <[email protected]>
1 parent 00dc21a commit f25a889

File tree

5 files changed

+250
-8
lines changed

5 files changed

+250
-8
lines changed

Diff for: const.go

+9
Original file line numberDiff line numberDiff line change
@@ -188,3 +188,12 @@ const (
188188
cachingSha2PasswordFastAuthSuccess = 3
189189
cachingSha2PasswordPerformFullAuthentication = 4
190190
)
191+
192+
const (
193+
sessionTrackSystemVariables = iota
194+
sessionTrackSchema
195+
sessionTrackStateChange
196+
sessionTrackGtids
197+
sessionTrackTransactionCharacteristics
198+
sessionTrackTransactionState
199+
)

Diff for: driver_test.go

+138
Original file line numberDiff line numberDiff line change
@@ -2602,6 +2602,144 @@ func TestExecMultipleResults(t *testing.T) {
26022602
})
26032603
}
26042604

2605+
func TestGTIDTracking(t *testing.T) {
2606+
ctx := context.Background()
2607+
runTests(t, dsn+"&trackSessionState=true", func(dbt *DBTest) {
2608+
dbt.mustExec(`
2609+
CREATE TABLE test (
2610+
id INT NOT NULL AUTO_INCREMENT,
2611+
value VARCHAR(255),
2612+
PRIMARY KEY (id)
2613+
)`)
2614+
2615+
// Check the current gtid_mode
2616+
var gtidMode string
2617+
if err := dbt.db.QueryRow("SELECT @@global.gtid_mode").Scan(&gtidMode); err != nil {
2618+
t.Fatalf("failed to get gtid_mode: %v", err)
2619+
}
2620+
2621+
if gtidMode == "OFF" {
2622+
_, err := dbt.db.Exec("SET GLOBAL gtid_mode = OFF_PERMISSIVE")
2623+
if err != nil {
2624+
t.Fatalf("failed while trying to change gtid_mode: %v", err)
2625+
}
2626+
defer func() {
2627+
_, err := dbt.db.Exec("SET GLOBAL gtid_mode = OFF")
2628+
if err != nil {
2629+
t.Fatalf("failed while trying to reset gtid_mode: %v", err)
2630+
}
2631+
}()
2632+
2633+
gtidMode = "OFF_PERMISSIVE"
2634+
}
2635+
2636+
if gtidMode == "OFF_PERMISSIVE" {
2637+
_, err := dbt.db.Exec("SET GLOBAL gtid_mode = ON_PERMISSIVE")
2638+
if err != nil {
2639+
t.Fatalf("failed while trying to change gtid_mode: %v", err)
2640+
}
2641+
defer func() {
2642+
_, err := dbt.db.Exec("SET GLOBAL gtid_mode = OFF_PERMISSIVE")
2643+
if err != nil {
2644+
t.Fatalf("failed while trying to reset gtid_mode: %v", err)
2645+
}
2646+
}()
2647+
2648+
gtidMode = "ON_PERMISSIVE"
2649+
}
2650+
2651+
var enforceGTIDConsistency string
2652+
if err := dbt.db.QueryRow("SELECT @@global.enforce_gtid_consistency").Scan(&enforceGTIDConsistency); err != nil {
2653+
t.Fatalf("failed to get enforce_gtid_consistency: %v", err)
2654+
}
2655+
2656+
if enforceGTIDConsistency == "OFF" {
2657+
_, err := dbt.db.Exec("SET GLOBAL enforce_gtid_consistency = ON")
2658+
if err != nil {
2659+
t.Fatalf("failed while trying to change enforce_gtid_consistency: %v", err)
2660+
}
2661+
defer func() {
2662+
_, err := dbt.db.Exec("SET GLOBAL enforce_gtid_consistency = OFF")
2663+
if err != nil {
2664+
t.Fatalf("failed while trying to reset enforce_gtid_consistency: %v", err)
2665+
}
2666+
}()
2667+
}
2668+
2669+
if gtidMode == "ON_PERMISSIVE" {
2670+
_, err := dbt.db.Exec("SET GLOBAL gtid_mode = ON")
2671+
if err != nil {
2672+
t.Fatalf("failed while trying to change gtid_mode: %v", err)
2673+
}
2674+
defer func() {
2675+
_, err := dbt.db.Exec("SET GLOBAL gtid_mode = ON_PERMISSIVE")
2676+
if err != nil {
2677+
t.Fatalf("failed while trying to reset gtid_mode: %v", err)
2678+
}
2679+
}()
2680+
2681+
gtidMode = "ON"
2682+
}
2683+
2684+
conn, err := dbt.db.Conn(ctx)
2685+
if err != nil {
2686+
t.Fatalf("failed to connect: %v", err)
2687+
}
2688+
defer conn.Close()
2689+
2690+
var gtid string
2691+
2692+
conn.Raw(func(conn any) error {
2693+
c := conn.(*mysqlConn)
2694+
2695+
res, err := c.Exec("INSERT INTO test (value) VALUES ('a'), ('b')", nil)
2696+
if err != nil {
2697+
t.Fatal(err)
2698+
}
2699+
2700+
gtid, err = res.(Result).LastGTID()
2701+
if err != nil {
2702+
t.Fatal(err)
2703+
}
2704+
2705+
if gtid != "" {
2706+
t.Fatalf("expected empty gtid, got %v", gtid)
2707+
}
2708+
2709+
_, err = c.Exec("SET SESSION session_track_gtids = ALL_GTIDS", nil)
2710+
if err != nil {
2711+
t.Fatal(err)
2712+
}
2713+
2714+
res, err = c.Exec("INSERT INTO test (value) VALUES ('a'), ('b')", nil)
2715+
if err != nil {
2716+
t.Fatal(err)
2717+
}
2718+
2719+
gtid, err = res.(Result).LastGTID()
2720+
if err != nil {
2721+
t.Fatal(err)
2722+
}
2723+
2724+
if gtid == "" {
2725+
t.Fatal("expected non-empty gtid")
2726+
}
2727+
2728+
return nil
2729+
})
2730+
2731+
var gtidExecuted string
2732+
err = conn.QueryRowContext(ctx, "SELECT @@global.gtid_executed").Scan(&gtidExecuted)
2733+
if err != nil {
2734+
dbt.Fatalf("%s", err.Error())
2735+
}
2736+
2737+
if gtidExecuted != gtid {
2738+
t.Fatalf("expected gtid %v, got %v", gtidExecuted, gtid)
2739+
}
2740+
})
2741+
}
2742+
26052743
// tests if rows are set in a proper state if some results were ignored before
26062744
// calling rows.NextResultSet.
26072745
func TestSkipResults(t *testing.T) {

Diff for: dsn.go

+8
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ type Config struct {
7070
MultiStatements bool // Allow multiple statements in one query
7171
ParseTime bool // Parse time values to time.Time
7272
RejectReadOnly bool // Reject read-only connections
73+
TrackSessionState bool // Enable session state tracking (e.g. GTID values)
7374

7475
// unexported fields. new options should be come here
7576

@@ -581,6 +582,13 @@ func parseDSNParams(cfg *Config, params string) (err error) {
581582
return errors.New("invalid bool value: " + value)
582583
}
583584

585+
case "trackSessionState":
586+
var isBool bool
587+
cfg.TrackSessionState, isBool = readBool(value)
588+
if !isBool {
589+
return errors.New("invalid bool value: " + value)
590+
}
591+
584592
// Server public key
585593
case "serverPubKey":
586594
name, err := url.QueryUnescape(value)

Diff for: packets.go

+78-8
Original file line numberDiff line numberDiff line change
@@ -209,11 +209,24 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
209209

210210
if len(data) > pos {
211211
// character set [1 byte]
212+
// charset := data[pos + 1]
213+
pos += 1
214+
212215
// status flags [2 bytes]
216+
// statusFlags := binary.LittleEndian.Uint16(data[pos : pos + 2])
217+
pos += 2
218+
213219
// capability flags (upper 2 bytes) [2 bytes]
220+
upper := binary.LittleEndian.Uint16(data[pos : pos+2])
221+
pos += 2
222+
223+
mc.flags += clientFlag((uint32(upper) << 16))
224+
214225
// length of auth-plugin-data [1 byte]
215226
// reserved (all [00]) [10 bytes]
216-
pos += 1 + 2 + 2 + 1 + 10
227+
228+
//mc.flags = flags
229+
pos += 1 + 10
217230

218231
// second part of the password cipher [minimum 13 bytes],
219232
// where len=MAX(13, length of auth-plugin-data - 8)
@@ -277,6 +290,11 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
277290
clientFlags |= clientMultiStatements
278291
}
279292

293+
if mc.cfg.TrackSessionState {
294+
fmt.Println("Setting TrackSessionState")
295+
clientFlags |= clientSessionTrack
296+
}
297+
280298
// encode length of the auth plugin data
281299
var authRespLEIBuf [9]byte
282300
authRespLen := len(authResp)
@@ -530,6 +548,7 @@ func (mc *okHandler) readResultSetHeaderPacket() (int, error) {
530548
// handleOkPacket replaces both values; other cases leave the values unchanged.
531549
mc.result.affectedRows = append(mc.result.affectedRows, 0)
532550
mc.result.insertIds = append(mc.result.insertIds, 0)
551+
mc.result.gtids = append(mc.result.gtids, "")
533552

534553
data, err := mc.conn().readPacket()
535554
if err != nil {
@@ -638,16 +657,20 @@ func (mc *mysqlConn) clearResult() *okHandler {
638657
// Ok Packet
639658
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
640659
func (mc *okHandler) handleOkPacket(data []byte) error {
641-
var n, m int
660+
var offset, length int
661+
642662
var affectedRows, insertId uint64
643663

644664
// 0x00 [1 byte]
665+
offset += 1
645666

646667
// Affected rows [Length Coded Binary]
647-
affectedRows, _, n = readLengthEncodedInteger(data[1:])
668+
affectedRows, _, length = readLengthEncodedInteger(data[offset:])
669+
offset += length
648670

649671
// Insert id [Length Coded Binary]
650-
insertId, _, m = readLengthEncodedInteger(data[1+n:])
672+
insertId, _, length = readLengthEncodedInteger(data[offset:])
673+
offset += length
651674

652675
// Update for the current statement result (only used by
653676
// readResultSetHeaderPacket).
@@ -659,12 +682,59 @@ func (mc *okHandler) handleOkPacket(data []byte) error {
659682
}
660683

661684
// server_status [2 bytes]
662-
mc.status = readStatus(data[1+n+m : 1+n+m+2])
663-
if mc.status&statusMoreResultsExists != 0 {
664-
return nil
665-
}
685+
mc.status = readStatus(data[offset : offset+2])
686+
offset += 2
666687

667688
// warning count [2 bytes]
689+
offset += 2
690+
691+
var gtid string
692+
if (mc.flags & clientSessionTrack) == clientSessionTrack {
693+
// Human readable status information (ignored)
694+
num, _, length := readLengthEncodedInteger(data[offset:])
695+
offset += length
696+
697+
offset += int(num)
698+
699+
if (mc.status & statusSessionStateChanged) == statusSessionStateChanged {
700+
// Length of session state changes
701+
num, _, length = readLengthEncodedInteger(data[offset:])
702+
offset += length
703+
704+
for t := 0; t < int(num); {
705+
infoType := data[offset]
706+
offset += 1
707+
t += 1
708+
709+
if infoType == sessionTrackGtids {
710+
_, _, length := readLengthEncodedInteger(data[offset:])
711+
offset += length
712+
t += length
713+
714+
offset += 1
715+
t += 1
716+
717+
gtidLength, _, length := readLengthEncodedInteger(data[offset:])
718+
offset += length
719+
t += length
720+
721+
gtid = string(data[offset : offset+int(gtidLength)])
722+
723+
offset += int(gtidLength)
724+
t += int(gtidLength)
725+
} else {
726+
// increase the offset to skip the value
727+
valueLength, _, length := readLengthEncodedInteger(data[offset:])
728+
offset += length + int(valueLength)
729+
t += length + int(valueLength)
730+
}
731+
}
732+
}
733+
734+
if len(mc.result.gtids) > 0 {
735+
mc.result.gtids[len(mc.result.gtids)-1] = gtid
736+
}
737+
}
668738

669739
return nil
670740
}

Diff for: result.go

+17
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,17 @@ import "database/sql/driver"
1919
// res.(mysql.Result).AllRowsAffected()
2020
type Result interface {
2121
driver.Result
22+
23+
// LastGTID returns the GTID of the last result, if available.
24+
LastGTID() (string, error)
25+
26+
// AllLastGTIDs returns a slice containing
27+
AllLastGTIDs() []string
28+
2229
// AllRowsAffected returns a slice containing the affected rows for each
2330
// executed statement.
2431
AllRowsAffected() []int64
32+
2533
// AllLastInsertIds returns a slice containing the last inserted ID for each
2634
// executed statement.
2735
AllLastInsertIds() []int64
@@ -31,6 +39,7 @@ type mysqlResult struct {
3139
// One entry in both slices is created for every executed statement result.
3240
affectedRows []int64
3341
insertIds []int64
42+
gtids []string
3443
}
3544

3645
func (res *mysqlResult) LastInsertId() (int64, error) {
@@ -41,10 +50,18 @@ func (res *mysqlResult) RowsAffected() (int64, error) {
4150
return res.affectedRows[len(res.affectedRows)-1], nil
4251
}
4352

53+
func (res *mysqlResult) LastGTID() (string, error) {
54+
return res.gtids[len(res.gtids)-1], nil
55+
}
56+
4457
func (res *mysqlResult) AllLastInsertIds() []int64 {
4558
return append([]int64{}, res.insertIds...) // defensive copy
4659
}
4760

4861
func (res *mysqlResult) AllRowsAffected() []int64 {
4962
return append([]int64{}, res.affectedRows...) // defensive copy
5063
}
64+
65+
func (res *mysqlResult) AllLastGTIDs() []string {
66+
return append([]string{}, res.gtids...)
67+
}

0 commit comments

Comments
 (0)