Skip to content

Implement GTID tracking. #1633

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,12 @@ const (
cachingSha2PasswordFastAuthSuccess = 3
cachingSha2PasswordPerformFullAuthentication = 4
)

const (
sessionTrackSystemVariables = iota
sessionTrackSchema
sessionTrackStateChange
sessionTrackGtids
sessionTrackTransactionCharacteristics
sessionTrackTransactionState
)
138 changes: 138 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2602,6 +2602,144 @@ func TestExecMultipleResults(t *testing.T) {
})
}

func TestGTIDTracking(t *testing.T) {
ctx := context.Background()
runTests(t, dsn+"&trackSessionState=true", func(dbt *DBTest) {
dbt.mustExec(`
CREATE TABLE test (
id INT NOT NULL AUTO_INCREMENT,
value VARCHAR(255),
PRIMARY KEY (id)
)`)

// Check the current gtid_mode
var gtidMode string
if err := dbt.db.QueryRow("SELECT @@global.gtid_mode").Scan(&gtidMode); err != nil {
t.Fatalf("failed to get gtid_mode: %v", err)
}

if gtidMode == "OFF" {
_, err := dbt.db.Exec("SET GLOBAL gtid_mode = OFF_PERMISSIVE")
if err != nil {
t.Fatalf("failed while trying to change gtid_mode: %v", err)
}
defer func() {
_, err := dbt.db.Exec("SET GLOBAL gtid_mode = OFF")
if err != nil {
t.Fatalf("failed while trying to reset gtid_mode: %v", err)
}
}()
Comment on lines +2627 to +2631
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Avoid calling t.Fatalf inside deferred functions

Using t.Fatalf inside a deferred function can lead to unexpected behavior because it may not stop the test execution as intended. Instead, use t.Errorf to report the error and ensure the test fails by checking for errors after the deferred function executes.

Suggested fix:

 defer func() {
     _, err := dbt.db.Exec("SET GLOBAL gtid_mode = OFF")
     if err != nil {
-        t.Fatalf("failed while trying to reset gtid_mode: %v", err)
+        t.Errorf("failed while trying to reset gtid_mode: %v", err)
     }
 }()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
_, err := dbt.db.Exec("SET GLOBAL gtid_mode = OFF")
if err != nil {
t.Fatalf("failed while trying to reset gtid_mode: %v", err)
}
}()
_, err := dbt.db.Exec("SET GLOBAL gtid_mode = OFF")
if err != nil {
t.Errorf("failed while trying to reset gtid_mode: %v", err)
}
}()


gtidMode = "OFF_PERMISSIVE"
}

if gtidMode == "OFF_PERMISSIVE" {
_, err := dbt.db.Exec("SET GLOBAL gtid_mode = ON_PERMISSIVE")
if err != nil {
t.Fatalf("failed while trying to change gtid_mode: %v", err)
}
defer func() {
_, err := dbt.db.Exec("SET GLOBAL gtid_mode = OFF_PERMISSIVE")
if err != nil {
t.Fatalf("failed while trying to reset gtid_mode: %v", err)
}
}()
Comment on lines +2642 to +2646
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Avoid calling t.Fatalf inside deferred functions

Similar to the previous instance, replace t.Fatalf with t.Errorf inside deferred functions to prevent improper test termination.

Suggested fix:

 defer func() {
     _, err := dbt.db.Exec("SET GLOBAL gtid_mode = OFF_PERMISSIVE")
     if err != nil {
-        t.Fatalf("failed while trying to reset gtid_mode: %v", err)
+        t.Errorf("failed while trying to reset gtid_mode: %v", err)
     }
 }()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
_, err := dbt.db.Exec("SET GLOBAL gtid_mode = OFF_PERMISSIVE")
if err != nil {
t.Fatalf("failed while trying to reset gtid_mode: %v", err)
}
}()
_, err := dbt.db.Exec("SET GLOBAL gtid_mode = OFF_PERMISSIVE")
if err != nil {
t.Errorf("failed while trying to reset gtid_mode: %v", err)
}
}()


gtidMode = "ON_PERMISSIVE"
}

var enforceGTIDConsistency string
if err := dbt.db.QueryRow("SELECT @@global.enforce_gtid_consistency").Scan(&enforceGTIDConsistency); err != nil {
t.Fatalf("failed to get enforce_gtid_consistency: %v", err)
}

if enforceGTIDConsistency == "OFF" {
_, err := dbt.db.Exec("SET GLOBAL enforce_gtid_consistency = ON")
if err != nil {
t.Fatalf("failed while trying to change enforce_gtid_consistency: %v", err)
}
defer func() {
_, err := dbt.db.Exec("SET GLOBAL enforce_gtid_consistency = OFF")
if err != nil {
t.Fatalf("failed while trying to reset enforce_gtid_consistency: %v", err)
}
}()
}

if gtidMode == "ON_PERMISSIVE" {
_, err := dbt.db.Exec("SET GLOBAL gtid_mode = ON")
if err != nil {
t.Fatalf("failed while trying to change gtid_mode: %v", err)
}
defer func() {
_, err := dbt.db.Exec("SET GLOBAL gtid_mode = ON_PERMISSIVE")
if err != nil {
t.Fatalf("failed while trying to reset gtid_mode: %v", err)
}
}()
Comment on lines +2675 to +2679
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Avoid calling t.Fatalf inside deferred functions

Using t.Fatalf in deferred functions can cause interference with the test flow. Replace it with t.Errorf and handle the error appropriately after the deferred function.

Suggested fix:

 defer func() {
     _, err := dbt.db.Exec("SET GLOBAL gtid_mode = ON_PERMISSIVE")
     if err != nil {
-        t.Fatalf("failed while trying to reset gtid_mode: %v", err)
+        t.Errorf("failed while trying to reset gtid_mode: %v", err)
     }
 }()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
_, err := dbt.db.Exec("SET GLOBAL gtid_mode = ON_PERMISSIVE")
if err != nil {
t.Fatalf("failed while trying to reset gtid_mode: %v", err)
}
}()
_, err := dbt.db.Exec("SET GLOBAL gtid_mode = ON_PERMISSIVE")
if err != nil {
t.Errorf("failed while trying to reset gtid_mode: %v", err)
}
}()


gtidMode = "ON"
}

conn, err := dbt.db.Conn(ctx)
if err != nil {
t.Fatalf("failed to connect: %v", err)
}
defer conn.Close()

var gtid string

conn.Raw(func(conn any) error {
c := conn.(*mysqlConn)

Comment on lines +2693 to +2694
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Check type assertion when casting to *mysqlConn

The type assertion conn.(*mysqlConn) may panic if conn is not of the expected type. It's safer to check the assertion and handle the error gracefully.

Suggested fix:

-        c := conn.(*mysqlConn)
+        c, ok := conn.(*mysqlConn)
+        if !ok {
+            return fmt.Errorf("expected *mysqlConn, got %T", conn)
+        }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
c := conn.(*mysqlConn)
c, ok := conn.(*mysqlConn)
if !ok {
return fmt.Errorf("expected *mysqlConn, got %T", conn)
}

res, err := c.Exec("INSERT INTO test (value) VALUES ('a'), ('b')", nil)
if err != nil {
t.Fatal(err)
}

gtid, err = res.(Result).LastGTID()
if err != nil {
t.Fatal(err)
}

if gtid != "" {
t.Fatalf("expected empty gtid, got %v", gtid)
}

_, err = c.Exec("SET SESSION session_track_gtids = ALL_GTIDS", nil)
if err != nil {
t.Fatal(err)
}

res, err = c.Exec("INSERT INTO test (value) VALUES ('a'), ('b')", nil)
if err != nil {
t.Fatal(err)
}

gtid, err = res.(Result).LastGTID()
if err != nil {
t.Fatal(err)
}

if gtid == "" {
t.Fatal("expected non-empty gtid")
}
Comment on lines +2697 to +2726
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Avoid calling t.Fatal inside conn.Raw callbacks

Calling t.Fatal inside the conn.Raw callback can cause unexpected behavior because it might not correctly terminate the test when called from a nested function. Instead, return the error from the callback and handle it after the conn.Raw call.

Suggested fix:

 conn.Raw(func(conn any) error {
     c, ok := conn.(*mysqlConn)
     if !ok {
         return fmt.Errorf("expected *mysqlConn, got %T", conn)
     }

     res, err := c.Exec("INSERT INTO test (value) VALUES ('a'), ('b')", nil)
     if err != nil {
-        t.Fatal(err)
+        return err
     }

     gtid, err = res.(Result).LastGTID()
     if err != nil {
-        t.Fatal(err)
+        return err
     }

     if gtid != "" {
-        t.Fatalf("expected empty gtid, got %v", gtid)
+        return fmt.Errorf("expected empty gtid, got %v", gtid)
     }

     _, err = c.Exec("SET SESSION session_track_gtids = ALL_GTIDS", nil)
     if err != nil {
-        t.Fatal(err)
+        return err
     }

     res, err = c.Exec("INSERT INTO test (value) VALUES ('a'), ('b')", nil)
     if err != nil {
-        t.Fatal(err)
+        return err
     }

     gtid, err = res.(Result).LastGTID()
     if err != nil {
-        t.Fatal(err)
+        return err
     }

     if gtid == "" {
-        t.Fatal("expected non-empty gtid")
+        return fmt.Errorf("expected non-empty gtid")
     }

     return nil
 })
+if err != nil {
+    t.Fatal(err)
+}

Committable suggestion was skipped due to low confidence.


return nil
})

var gtidExecuted string
err = conn.QueryRowContext(ctx, "SELECT @@global.gtid_executed").Scan(&gtidExecuted)
if err != nil {
dbt.Fatalf("%s", err.Error())
}

if gtidExecuted != gtid {
t.Fatalf("expected gtid %v, got %v", gtidExecuted, gtid)
}
})
}

// tests if rows are set in a proper state if some results were ignored before
// calling rows.NextResultSet.
func TestSkipResults(t *testing.T) {
Expand Down
12 changes: 12 additions & 0 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ type Config struct {
MultiStatements bool // Allow multiple statements in one query
ParseTime bool // Parse time values to time.Time
RejectReadOnly bool // Reject read-only connections
TrackSessionState bool // Enable session state tracking (e.g. GTID values)

// unexported fields. new options should be come here

Expand Down Expand Up @@ -323,6 +324,10 @@ func (cfg *Config) FormatDSN() string {
writeDSNParam(&buf, &hasParam, "rejectReadOnly", "true")
}

if cfg.TrackSessionState {
writeDSNParam(&buf, &hasParam, "trackSessionState", "true")
}

if len(cfg.ServerPubKey) > 0 {
writeDSNParam(&buf, &hasParam, "serverPubKey", url.QueryEscape(cfg.ServerPubKey))
}
Expand Down Expand Up @@ -581,6 +586,13 @@ func parseDSNParams(cfg *Config, params string) (err error) {
return errors.New("invalid bool value: " + value)
}

case "trackSessionState":
var isBool bool
cfg.TrackSessionState, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}

// Server public key
case "serverPubKey":
name, err := url.QueryUnescape(value)
Expand Down
85 changes: 77 additions & 8 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,24 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro

if len(data) > pos {
// character set [1 byte]
// charset := data[pos + 1]
pos += 1

// status flags [2 bytes]
// statusFlags := binary.LittleEndian.Uint16(data[pos : pos + 2])
pos += 2

// capability flags (upper 2 bytes) [2 bytes]
upper := binary.LittleEndian.Uint16(data[pos : pos+2])
pos += 2

mc.flags += clientFlag((uint32(upper) << 16))

// length of auth-plugin-data [1 byte]
// reserved (all [00]) [10 bytes]
pos += 1 + 2 + 2 + 1 + 10

//mc.flags = flags
pos += 1 + 10

// second part of the password cipher [minimum 13 bytes],
// where len=MAX(13, length of auth-plugin-data - 8)
Expand Down Expand Up @@ -277,6 +290,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
clientFlags |= clientMultiStatements
}

if mc.cfg.TrackSessionState {
clientFlags |= clientSessionTrack
}

// encode length of the auth plugin data
var authRespLEIBuf [9]byte
authRespLen := len(authResp)
Expand Down Expand Up @@ -530,6 +547,7 @@ func (mc *okHandler) readResultSetHeaderPacket() (int, error) {
// handleOkPacket replaces both values; other cases leave the values unchanged.
mc.result.affectedRows = append(mc.result.affectedRows, 0)
mc.result.insertIds = append(mc.result.insertIds, 0)
mc.result.gtids = append(mc.result.gtids, "")

data, err := mc.conn().readPacket()
if err != nil {
Expand Down Expand Up @@ -638,16 +656,20 @@ func (mc *mysqlConn) clearResult() *okHandler {
// Ok Packet
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
func (mc *okHandler) handleOkPacket(data []byte) error {
var n, m int
var offset, length int

var affectedRows, insertId uint64

// 0x00 [1 byte]
offset += 1

// Affected rows [Length Coded Binary]
affectedRows, _, n = readLengthEncodedInteger(data[1:])
affectedRows, _, length = readLengthEncodedInteger(data[offset:])
offset += length

// Insert id [Length Coded Binary]
insertId, _, m = readLengthEncodedInteger(data[1+n:])
insertId, _, length = readLengthEncodedInteger(data[offset:])
offset += length

// Update for the current statement result (only used by
// readResultSetHeaderPacket).
Expand All @@ -659,12 +681,59 @@ func (mc *okHandler) handleOkPacket(data []byte) error {
}

// server_status [2 bytes]
mc.status = readStatus(data[1+n+m : 1+n+m+2])
if mc.status&statusMoreResultsExists != 0 {
return nil
}
mc.status = readStatus(data[offset : offset+2])
offset += 2

// warning count [2 bytes]
offset += 2

var gtid string
if (mc.flags & clientSessionTrack) == clientSessionTrack {
// Human readable status information (ignored)
num, _, length := readLengthEncodedInteger(data[offset:])
offset += length

offset += int(num)

if (mc.status & statusSessionStateChanged) == statusSessionStateChanged {
// Length of session state changes
num, _, length = readLengthEncodedInteger(data[offset:])
offset += length

for t := 0; t < int(num); {
infoType := data[offset]
offset += 1
t += 1

if infoType == sessionTrackGtids {
_, _, length := readLengthEncodedInteger(data[offset:])
offset += length
t += length

offset += 1
t += 1

gtidLength, _, length := readLengthEncodedInteger(data[offset:])
offset += length
t += length

gtid = string(data[offset : offset+int(gtidLength)])

offset += int(gtidLength)
t += int(gtidLength)
} else {
// increase the offset to skip the value
valueLength, _, length := readLengthEncodedInteger(data[offset:])
offset += length + int(valueLength)
t += length + int(valueLength)
}
}
}

if len(mc.result.gtids) > 0 {
mc.result.gtids[len(mc.result.gtids)-1] = gtid
}
}

return nil
}
Expand Down
17 changes: 17 additions & 0 deletions result.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,17 @@ import "database/sql/driver"
// res.(mysql.Result).AllRowsAffected()
type Result interface {
driver.Result

// LastGTID returns the GTID of the last result, if available.
LastGTID() (string, error)

// AllLastGTIDs returns a slice containing
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Incomplete method comment for AllLastGTIDs().

The comment for AllLastGTIDs() is incomplete:

// AllLastGTIDs returns a slice containing

Please complete the comment to accurately describe what the method returns.

AllLastGTIDs() []string

// AllRowsAffected returns a slice containing the affected rows for each
// executed statement.
AllRowsAffected() []int64

// AllLastInsertIds returns a slice containing the last inserted ID for each
// executed statement.
AllLastInsertIds() []int64
Expand All @@ -31,6 +39,7 @@ type mysqlResult struct {
// One entry in both slices is created for every executed statement result.
affectedRows []int64
insertIds []int64
gtids []string
}

func (res *mysqlResult) LastInsertId() (int64, error) {
Expand All @@ -41,10 +50,18 @@ func (res *mysqlResult) RowsAffected() (int64, error) {
return res.affectedRows[len(res.affectedRows)-1], nil
}

func (res *mysqlResult) LastGTID() (string, error) {
return res.gtids[len(res.gtids)-1], nil
}
Comment on lines +53 to +55
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Potential panic in LastGTID() due to empty gtids slice.

If res.gtids is empty, accessing res.gtids[len(res.gtids)-1] will cause an index out of range panic.

Consider adding a check to ensure res.gtids is not empty before accessing the last element:

func (res *mysqlResult) LastGTID() (string, error) {
+	if len(res.gtids) == 0 {
+		return "", errors.New("no GTIDs available")
+	}
	return res.gtids[len(res.gtids)-1], nil
}

Alternatively, define the expected behavior when gtids is empty and handle it appropriately.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
func (res *mysqlResult) LastGTID() (string, error) {
return res.gtids[len(res.gtids)-1], nil
}
func (res *mysqlResult) LastGTID() (string, error) {
if len(res.gtids) == 0 {
return "", errors.New("no GTIDs available")
}
return res.gtids[len(res.gtids)-1], nil
}


func (res *mysqlResult) AllLastInsertIds() []int64 {
return append([]int64{}, res.insertIds...) // defensive copy
}

func (res *mysqlResult) AllRowsAffected() []int64 {
return append([]int64{}, res.affectedRows...) // defensive copy
}

func (res *mysqlResult) AllLastGTIDs() []string {
return append([]string{}, res.gtids...)
}