From f25a889b6d1dcac1023cfcc5fd204bd32cd37f88 Mon Sep 17 00:00:00 2001 From: Arthur Schreiber Date: Wed, 2 Oct 2024 07:54:08 +0000 Subject: [PATCH 1/3] Implement GTID tracking. Co-authored-by: Daniel Joos --- const.go | 9 ++++ driver_test.go | 138 +++++++++++++++++++++++++++++++++++++++++++++++++ dsn.go | 8 +++ packets.go | 86 +++++++++++++++++++++++++++--- result.go | 17 ++++++ 5 files changed, 250 insertions(+), 8 deletions(-) diff --git a/const.go b/const.go index 0cee9b2ee..c51830eb4 100644 --- a/const.go +++ b/const.go @@ -188,3 +188,12 @@ const ( cachingSha2PasswordFastAuthSuccess = 3 cachingSha2PasswordPerformFullAuthentication = 4 ) + +const ( + sessionTrackSystemVariables = iota + sessionTrackSchema + sessionTrackStateChange + sessionTrackGtids + sessionTrackTransactionCharacteristics + sessionTrackTransactionState +) diff --git a/driver_test.go b/driver_test.go index 24d73c34f..d0602380a 100644 --- a/driver_test.go +++ b/driver_test.go @@ -2602,6 +2602,144 @@ func TestExecMultipleResults(t *testing.T) { }) } +func TestGTIDTracking(t *testing.T) { + ctx := context.Background() + runTests(t, dsn+"&trackSessionState=true", func(dbt *DBTest) { + dbt.mustExec(` + CREATE TABLE test ( + id INT NOT NULL AUTO_INCREMENT, + value VARCHAR(255), + PRIMARY KEY (id) + )`) + + // Check the current gtid_mode + var gtidMode string + if err := dbt.db.QueryRow("SELECT @@global.gtid_mode").Scan(>idMode); err != nil { + t.Fatalf("failed to get gtid_mode: %v", err) + } + + if gtidMode == "OFF" { + _, err := dbt.db.Exec("SET GLOBAL gtid_mode = OFF_PERMISSIVE") + if err != nil { + t.Fatalf("failed while trying to change gtid_mode: %v", err) + } + defer func() { + _, err := dbt.db.Exec("SET GLOBAL gtid_mode = OFF") + if err != nil { + t.Fatalf("failed while trying to reset gtid_mode: %v", err) + } + }() + + gtidMode = "OFF_PERMISSIVE" + } + + if gtidMode == "OFF_PERMISSIVE" { + _, err := dbt.db.Exec("SET GLOBAL gtid_mode = ON_PERMISSIVE") + if err != nil { + t.Fatalf("failed while trying to change gtid_mode: %v", err) + } + defer func() { + _, err := dbt.db.Exec("SET GLOBAL gtid_mode = OFF_PERMISSIVE") + if err != nil { + t.Fatalf("failed while trying to reset gtid_mode: %v", err) + } + }() + + gtidMode = "ON_PERMISSIVE" + } + + var enforceGTIDConsistency string + if err := dbt.db.QueryRow("SELECT @@global.enforce_gtid_consistency").Scan(&enforceGTIDConsistency); err != nil { + t.Fatalf("failed to get enforce_gtid_consistency: %v", err) + } + + if enforceGTIDConsistency == "OFF" { + _, err := dbt.db.Exec("SET GLOBAL enforce_gtid_consistency = ON") + if err != nil { + t.Fatalf("failed while trying to change enforce_gtid_consistency: %v", err) + } + defer func() { + _, err := dbt.db.Exec("SET GLOBAL enforce_gtid_consistency = OFF") + if err != nil { + t.Fatalf("failed while trying to reset enforce_gtid_consistency: %v", err) + } + }() + } + + if gtidMode == "ON_PERMISSIVE" { + _, err := dbt.db.Exec("SET GLOBAL gtid_mode = ON") + if err != nil { + t.Fatalf("failed while trying to change gtid_mode: %v", err) + } + defer func() { + _, err := dbt.db.Exec("SET GLOBAL gtid_mode = ON_PERMISSIVE") + if err != nil { + t.Fatalf("failed while trying to reset gtid_mode: %v", err) + } + }() + + gtidMode = "ON" + } + + conn, err := dbt.db.Conn(ctx) + if err != nil { + t.Fatalf("failed to connect: %v", err) + } + defer conn.Close() + + var gtid string + + conn.Raw(func(conn any) error { + c := conn.(*mysqlConn) + + res, err := c.Exec("INSERT INTO test (value) VALUES ('a'), ('b')", nil) + if err != nil { + t.Fatal(err) + } + + gtid, err = res.(Result).LastGTID() + if err != nil { + t.Fatal(err) + } + + if gtid != "" { + t.Fatalf("expected empty gtid, got %v", gtid) + } + + _, err = c.Exec("SET SESSION session_track_gtids = ALL_GTIDS", nil) + if err != nil { + t.Fatal(err) + } + + res, err = c.Exec("INSERT INTO test (value) VALUES ('a'), ('b')", nil) + if err != nil { + t.Fatal(err) + } + + gtid, err = res.(Result).LastGTID() + if err != nil { + t.Fatal(err) + } + + if gtid == "" { + t.Fatal("expected non-empty gtid") + } + + return nil + }) + + var gtidExecuted string + err = conn.QueryRowContext(ctx, "SELECT @@global.gtid_executed").Scan(>idExecuted) + if err != nil { + dbt.Fatalf("%s", err.Error()) + } + + if gtidExecuted != gtid { + t.Fatalf("expected gtid %v, got %v", gtidExecuted, gtid) + } + }) +} + // tests if rows are set in a proper state if some results were ignored before // calling rows.NextResultSet. func TestSkipResults(t *testing.T) { diff --git a/dsn.go b/dsn.go index 3c7a6e215..2304242a4 100644 --- a/dsn.go +++ b/dsn.go @@ -70,6 +70,7 @@ type Config struct { MultiStatements bool // Allow multiple statements in one query ParseTime bool // Parse time values to time.Time RejectReadOnly bool // Reject read-only connections + TrackSessionState bool // Enable session state tracking (e.g. GTID values) // unexported fields. new options should be come here @@ -581,6 +582,13 @@ func parseDSNParams(cfg *Config, params string) (err error) { return errors.New("invalid bool value: " + value) } + case "trackSessionState": + var isBool bool + cfg.TrackSessionState, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + // Server public key case "serverPubKey": name, err := url.QueryUnescape(value) diff --git a/packets.go b/packets.go index 014a1deee..0ad5d00a0 100644 --- a/packets.go +++ b/packets.go @@ -209,11 +209,24 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro if len(data) > pos { // character set [1 byte] + // charset := data[pos + 1] + pos += 1 + // status flags [2 bytes] + // statusFlags := binary.LittleEndian.Uint16(data[pos : pos + 2]) + pos += 2 + // capability flags (upper 2 bytes) [2 bytes] + upper := binary.LittleEndian.Uint16(data[pos : pos+2]) + pos += 2 + + mc.flags += clientFlag((uint32(upper) << 16)) + // length of auth-plugin-data [1 byte] // reserved (all [00]) [10 bytes] - pos += 1 + 2 + 2 + 1 + 10 + + //mc.flags = flags + pos += 1 + 10 // second part of the password cipher [minimum 13 bytes], // where len=MAX(13, length of auth-plugin-data - 8) @@ -277,6 +290,11 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string clientFlags |= clientMultiStatements } + if mc.cfg.TrackSessionState { + fmt.Println("Setting TrackSessionState") + clientFlags |= clientSessionTrack + } + // encode length of the auth plugin data var authRespLEIBuf [9]byte authRespLen := len(authResp) @@ -530,6 +548,7 @@ func (mc *okHandler) readResultSetHeaderPacket() (int, error) { // handleOkPacket replaces both values; other cases leave the values unchanged. mc.result.affectedRows = append(mc.result.affectedRows, 0) mc.result.insertIds = append(mc.result.insertIds, 0) + mc.result.gtids = append(mc.result.gtids, "") data, err := mc.conn().readPacket() if err != nil { @@ -638,16 +657,20 @@ func (mc *mysqlConn) clearResult() *okHandler { // Ok Packet // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet func (mc *okHandler) handleOkPacket(data []byte) error { - var n, m int + var offset, length int + var affectedRows, insertId uint64 // 0x00 [1 byte] + offset += 1 // Affected rows [Length Coded Binary] - affectedRows, _, n = readLengthEncodedInteger(data[1:]) + affectedRows, _, length = readLengthEncodedInteger(data[offset:]) + offset += length // Insert id [Length Coded Binary] - insertId, _, m = readLengthEncodedInteger(data[1+n:]) + insertId, _, length = readLengthEncodedInteger(data[offset:]) + offset += length // Update for the current statement result (only used by // readResultSetHeaderPacket). @@ -659,12 +682,59 @@ func (mc *okHandler) handleOkPacket(data []byte) error { } // server_status [2 bytes] - mc.status = readStatus(data[1+n+m : 1+n+m+2]) - if mc.status&statusMoreResultsExists != 0 { - return nil - } + mc.status = readStatus(data[offset : offset+2]) + offset += 2 // warning count [2 bytes] + offset += 2 + + var gtid string + if (mc.flags & clientSessionTrack) == clientSessionTrack { + // Human readable status information (ignored) + num, _, length := readLengthEncodedInteger(data[offset:]) + offset += length + + offset += int(num) + + if (mc.status & statusSessionStateChanged) == statusSessionStateChanged { + // Length of session state changes + num, _, length = readLengthEncodedInteger(data[offset:]) + offset += length + + for t := 0; t < int(num); { + infoType := data[offset] + offset += 1 + t += 1 + + if infoType == sessionTrackGtids { + _, _, length := readLengthEncodedInteger(data[offset:]) + offset += length + t += length + + offset += 1 + t += 1 + + gtidLength, _, length := readLengthEncodedInteger(data[offset:]) + offset += length + t += length + + gtid = string(data[offset : offset+int(gtidLength)]) + + offset += int(gtidLength) + t += int(gtidLength) + } else { + // increase the offset to skip the value + valueLength, _, length := readLengthEncodedInteger(data[offset:]) + offset += length + int(valueLength) + t += length + int(valueLength) + } + } + } + + if len(mc.result.gtids) > 0 { + mc.result.gtids[len(mc.result.gtids)-1] = gtid + } + } return nil } diff --git a/result.go b/result.go index d51631468..5edc57877 100644 --- a/result.go +++ b/result.go @@ -19,9 +19,17 @@ import "database/sql/driver" // res.(mysql.Result).AllRowsAffected() type Result interface { driver.Result + + // LastGTID returns the GTID of the last result, if available. + LastGTID() (string, error) + + // AllLastGTIDs returns a slice containing + AllLastGTIDs() []string + // AllRowsAffected returns a slice containing the affected rows for each // executed statement. AllRowsAffected() []int64 + // AllLastInsertIds returns a slice containing the last inserted ID for each // executed statement. AllLastInsertIds() []int64 @@ -31,6 +39,7 @@ type mysqlResult struct { // One entry in both slices is created for every executed statement result. affectedRows []int64 insertIds []int64 + gtids []string } func (res *mysqlResult) LastInsertId() (int64, error) { @@ -41,6 +50,10 @@ func (res *mysqlResult) RowsAffected() (int64, error) { return res.affectedRows[len(res.affectedRows)-1], nil } +func (res *mysqlResult) LastGTID() (string, error) { + return res.gtids[len(res.gtids)-1], nil +} + func (res *mysqlResult) AllLastInsertIds() []int64 { return append([]int64{}, res.insertIds...) // defensive copy } @@ -48,3 +61,7 @@ func (res *mysqlResult) AllLastInsertIds() []int64 { func (res *mysqlResult) AllRowsAffected() []int64 { return append([]int64{}, res.affectedRows...) // defensive copy } + +func (res *mysqlResult) AllLastGTIDs() []string { + return append([]string{}, res.gtids...) +} From 987ee65080508d263ad232900d12b1f0c70f53f8 Mon Sep 17 00:00:00 2001 From: Arthur Schreiber Date: Wed, 2 Oct 2024 08:05:20 +0000 Subject: [PATCH 2/3] Add missing support for `trackSessionState` when generating dsn strings. --- dsn.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dsn.go b/dsn.go index 2304242a4..503f7388a 100644 --- a/dsn.go +++ b/dsn.go @@ -324,6 +324,10 @@ func (cfg *Config) FormatDSN() string { writeDSNParam(&buf, &hasParam, "rejectReadOnly", "true") } + if cfg.TrackSessionState { + writeDSNParam(&buf, &hasParam, "trackSessionState", "true") + } + if len(cfg.ServerPubKey) > 0 { writeDSNParam(&buf, &hasParam, "serverPubKey", url.QueryEscape(cfg.ServerPubKey)) } From 55420ba60138543c44f9c3bd2025299c977555ff Mon Sep 17 00:00:00 2001 From: Arthur Schreiber Date: Wed, 2 Oct 2024 08:05:47 +0000 Subject: [PATCH 3/3] Remove stray log output. --- packets.go | 1 - 1 file changed, 1 deletion(-) diff --git a/packets.go b/packets.go index 0ad5d00a0..8f3de4593 100644 --- a/packets.go +++ b/packets.go @@ -291,7 +291,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string } if mc.cfg.TrackSessionState { - fmt.Println("Setting TrackSessionState") clientFlags |= clientSessionTrack }