From 98f445cc82fa2fe16d56f2eafb822599f0e9fdfc Mon Sep 17 00:00:00 2001 From: Diego Dupin Date: Mon, 31 Mar 2025 18:04:08 +0200 Subject: [PATCH 1/2] test stability improvement. * ensuring performance schema is enabled when testing some performance schema results * Added logic to check if the default collation is overridden by the server character_set_collations * ensure using IANA timezone in test, since tzinfo depending on system won't have deprecated tz like "US/Central" and "US/Pacific" --- driver_test.go | 52 +++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 47 insertions(+), 5 deletions(-) diff --git a/driver_test.go b/driver_test.go index 00e82865..8569494e 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1630,13 +1630,46 @@ func TestCollation(t *testing.T) { } runTests(t, tdsn, func(dbt *DBTest) { + // see https://mariadb.com/kb/en/setting-character-sets-and-collations/#changing-default-collation + // when character_set_collations is set for the charset, it overrides the default collation + // so we need to check if the default collation is overridden + forceExpected := expected + var defaultCollations string + err := dbt.db.QueryRow("SELECT @@character_set_collations").Scan(&defaultCollations) + if err == nil { + // Query succeeded, need to check if we should override expected collation + collationMap := make(map[string]string) + pairs := strings.Split(defaultCollations, ",") + for _, pair := range pairs { + parts := strings.Split(pair, "=") + if len(parts) == 2 { + collationMap[parts[0]] = parts[1] + } + } + + // Get charset prefix from expected collation + parts := strings.Split(expected, "_") + if len(parts) > 0 { + charset := parts[0] + if newCollation, ok := collationMap[charset]; ok { + forceExpected = newCollation + } + } + } + var got string if err := dbt.db.QueryRow("SELECT @@collation_connection").Scan(&got); err != nil { dbt.Fatal(err) } if got != expected { - dbt.Fatalf("expected connection collation %s but got %s", expected, got) + if forceExpected != expected { + if got != forceExpected { + dbt.Fatalf("expected forced connection collation %s but got %s", forceExpected, got) + } + } else { + dbt.Fatalf("expected connection collation %s but got %s", expected, got) + } } }) } @@ -1685,7 +1718,7 @@ func TestRawBytesResultExceedsBuffer(t *testing.T) { } func TestTimezoneConversion(t *testing.T) { - zones := []string{"UTC", "US/Central", "US/Pacific", "Local"} + zones := []string{"UTC", "America/New_York", "Asia/Hong_Kong", "Local"} // Regression test for timezone handling tzTest := func(dbt *DBTest) { @@ -1693,8 +1726,8 @@ func TestTimezoneConversion(t *testing.T) { dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)") // Insert local time into database (should be converted) - usCentral, _ := time.LoadLocation("US/Central") - reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, time.UTC).In(usCentral) + newYorkTz, _ := time.LoadLocation("America/New_York") + reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, time.UTC).In(newYorkTz) dbt.mustExec("INSERT INTO test VALUE (?)", reftime) // Retrieve time from DB @@ -1713,7 +1746,7 @@ func TestTimezoneConversion(t *testing.T) { // Check that dates match if reftime.Unix() != dbTime.Unix() { dbt.Errorf("times do not match.\n") - dbt.Errorf(" Now(%v)=%v\n", usCentral, reftime) + dbt.Errorf(" Now(%v)=%v\n", newYorkTz, reftime) dbt.Errorf(" Now(UTC)=%v\n", dbTime) } } @@ -3541,6 +3574,15 @@ func TestConnectionAttributes(t *testing.T) { dbt := &DBTest{t, db} + var varName string + var varValue string + err := dbt.db.QueryRow("SHOW VARIABLES LIKE 'performance_schema'").Scan(&varName, &varValue) + if err != nil { + t.Fatalf("error: %s", err.Error()) + } + if varValue != "ON" { + t.Skipf("Performance schema is not enabled. skipping") + } queryString := "SELECT ATTR_NAME, ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID()" rows := dbt.mustQuery(queryString) defer rows.Close() From f0eee818234dda9e29ae8c78ed63e1c5ebd73041 Mon Sep 17 00:00:00 2001 From: Diego Dupin Date: Fri, 18 Apr 2025 18:55:01 +0200 Subject: [PATCH 2/2] readPacket optimization Since configuration options doesn't change at runtime, after connection is established, use dedicated function, in order to avoid multiple test test compress, checking ReadTimeout configuration option --- AUTHORS | 1 + benchmark_test.go | 41 ++++++++++++++++++++++++++++++++++++++++ compress_test.go | 2 ++ connection.go | 14 +++----------- connection_test.go | 7 ++++++- connector.go | 18 ++++++++++++++++++ packets.go | 10 +++------- packets_test.go | 47 ++++++++++++++++++++++++++++++---------------- 8 files changed, 105 insertions(+), 35 deletions(-) diff --git a/AUTHORS b/AUTHORS index 510b869b..a261819f 100644 --- a/AUTHORS +++ b/AUTHORS @@ -37,6 +37,7 @@ Daniel Montoya Daniel Nichter Daniƫl van Eeden Dave Protasowski +Diego Dupin Dirkjan Bussink DisposaBoy Egor Smolyakov diff --git a/benchmark_test.go b/benchmark_test.go index 5c9a046b..92974f4d 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -113,6 +113,47 @@ func benchmarkQueryHelper(b *testing.B, compr bool) { } } +func BenchmarkSelect10000rows(b *testing.B) { + db := initDB(b, false) + defer db.Close() + + // Check if we're using MariaDB + var version string + err := db.QueryRow("SELECT @@version").Scan(&version) + if err != nil { + b.Fatalf("Failed to get server version: %v", err) + } + + if !strings.Contains(strings.ToLower(version), "mariadb") { + b.Skip("Skipping benchmark as it requires MariaDB sequence table") + return + } + + b.StartTimer() + stmt, err := db.Prepare("SELECT * FROM seq_1_to_10000") + if err != nil { + b.Fatalf("Failed to prepare statement: %v", err) + } + defer stmt.Close() + for n := 0; n < b.N; n++ { + rows, err := stmt.Query() + if err != nil { + b.Fatalf("Failed to query 10000rows: %v", err) + } + + var id int64 + for rows.Next() { + err = rows.Scan(&id) + if err != nil { + rows.Close() + b.Fatalf("Failed to scan row: %v", err) + } + } + rows.Close() + } + b.StopTimer() +} + func BenchmarkExec(b *testing.B) { tb := (*TB)(b) b.StopTimer() diff --git a/compress_test.go b/compress_test.go index 030deaef..72696a4d 100644 --- a/compress_test.go +++ b/compress_test.go @@ -40,6 +40,8 @@ func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte) []by conn := new(mockConn) conn.data = compressedPacket mc.netConn = conn + mc.readNextFunc = mc.compIO.readNext + mc.readFunc = conn.Read uncompressedPacket, err := mc.readPacket() if err != nil { diff --git a/connection.go b/connection.go index 3e455a3f..64b5a502 100644 --- a/connection.go +++ b/connection.go @@ -39,6 +39,8 @@ type mysqlConn struct { compressSequence uint8 parseTime bool compress bool + readFunc func([]byte) (int, error) + readNextFunc func(int, readerFunc) ([]byte, error) // for context support (Go 1.8+) watching bool @@ -64,16 +66,6 @@ func (mc *mysqlConn) log(v ...any) { mc.cfg.Logger.Print(v...) } -func (mc *mysqlConn) readWithTimeout(b []byte) (int, error) { - to := mc.cfg.ReadTimeout - if to > 0 { - if err := mc.netConn.SetReadDeadline(time.Now().Add(to)); err != nil { - return 0, err - } - } - return mc.netConn.Read(b) -} - func (mc *mysqlConn) writeWithTimeout(b []byte) (int, error) { to := mc.cfg.WriteTimeout if to > 0 { @@ -247,7 +239,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin // can not take the buffer. Something must be wrong with the connection mc.cleanup() // interpolateParams would be called before sending any query. - // So its safe to retry. + // So it's safe to retry. return "", driver.ErrBadConn } buf = buf[:0] diff --git a/connection_test.go b/connection_test.go index f7740898..e8091323 100644 --- a/connection_test.go +++ b/connection_test.go @@ -18,12 +18,17 @@ import ( ) func TestInterpolateParams(t *testing.T) { + buf := newBuffer() + nc := &net.TCPConn{} mc := &mysqlConn{ - buf: newBuffer(), + buf: buf, + netConn: nc, maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, }, + readNextFunc: buf.readNext, + readFunc: nc.Read, } q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"}) diff --git a/connector.go b/connector.go index bc1d46af..ea121923 100644 --- a/connector.go +++ b/connector.go @@ -16,6 +16,7 @@ import ( "os" "strconv" "strings" + "time" ) type connector struct { @@ -130,6 +131,22 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc.buf = newBuffer() + // setting readNext/read functions + mc.readNextFunc = mc.buf.readNext + + // Initialize read function based on configuration + if mc.cfg.ReadTimeout > 0 { + mc.readFunc = func(b []byte) (int, error) { + deadline := time.Now().Add(mc.cfg.ReadTimeout) + if err := mc.netConn.SetReadDeadline(deadline); err != nil { + return 0, err + } + return mc.netConn.Read(b) + } + } else { + mc.readFunc = mc.netConn.Read + } + // Reading Handshake Initialization Packet authData, plugin, err := mc.readHandshakePacket() if err != nil { @@ -170,6 +187,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { if mc.cfg.compress && mc.flags&clientCompress == clientCompress { mc.compress = true mc.compIO = newCompIO(mc) + mc.readNextFunc = mc.compIO.readNext } if mc.cfg.MaxAllowedPacket > 0 { mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket diff --git a/packets.go b/packets.go index 4b836216..b08a4139 100644 --- a/packets.go +++ b/packets.go @@ -30,14 +30,9 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { var prevData []byte invalidSequence := false - readNext := mc.buf.readNext - if mc.compress { - readNext = mc.compIO.readNext - } - for { // read packet header - data, err := readNext(4, mc.readWithTimeout) + data, err := mc.readNextFunc(4, mc.readFunc) if err != nil { mc.close() if cerr := mc.canceled.Value(); cerr != nil { @@ -85,7 +80,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } // read packet body [pktLen bytes] - data, err = readNext(pktLen, mc.readWithTimeout) + data, err = mc.readNextFunc(pktLen, mc.readFunc) if err != nil { mc.close() if cerr := mc.canceled.Value(); cerr != nil { @@ -369,6 +364,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string return err } mc.netConn = tlsConn + mc.readFunc = mc.netConn.Read } // User [null terminated string] diff --git a/packets_test.go b/packets_test.go index 694b0564..71b071a8 100644 --- a/packets_test.go +++ b/packets_test.go @@ -97,24 +97,30 @@ var _ net.Conn = new(mockConn) func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { conn := new(mockConn) connector := newConnector(NewConfig()) + buf := newBuffer() mc := &mysqlConn{ - buf: newBuffer(), + buf: buf, cfg: connector.cfg, connector: connector, netConn: conn, closech: make(chan struct{}), maxAllowedPacket: defaultMaxAllowedPacket, sequence: sequence, + readNextFunc: buf.readNext, + readFunc: conn.Read, } return conn, mc } func TestReadPacketSingleByte(t *testing.T) { conn := new(mockConn) + buf := newBuffer() mc := &mysqlConn{ - netConn: conn, - buf: newBuffer(), - cfg: NewConfig(), + netConn: conn, + buf: buf, + cfg: NewConfig(), + readNextFunc: buf.readNext, + readFunc: conn.Read, } conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} @@ -165,10 +171,13 @@ func TestReadPacketWrongSequenceID(t *testing.T) { func TestReadPacketSplit(t *testing.T) { conn := new(mockConn) + buf := newBuffer() mc := &mysqlConn{ - netConn: conn, - buf: newBuffer(), - cfg: NewConfig(), + netConn: conn, + buf: buf, + cfg: NewConfig(), + readNextFunc: buf.readNext, + readFunc: conn.Read, } data := make([]byte, maxPacketSize*2+4*3) @@ -272,11 +281,14 @@ func TestReadPacketSplit(t *testing.T) { func TestReadPacketFail(t *testing.T) { conn := new(mockConn) + buf := newBuffer() mc := &mysqlConn{ - netConn: conn, - buf: newBuffer(), - closech: make(chan struct{}), - cfg: NewConfig(), + netConn: conn, + buf: buf, + closech: make(chan struct{}), + cfg: NewConfig(), + readNextFunc: buf.readNext, + readFunc: conn.Read, } // illegal empty (stand-alone) packet @@ -317,12 +329,15 @@ func TestReadPacketFail(t *testing.T) { // not-NUL terminated plugin_name in init packet func TestRegression801(t *testing.T) { conn := new(mockConn) + buf := newBuffer() mc := &mysqlConn{ - netConn: conn, - buf: newBuffer(), - cfg: new(Config), - sequence: 42, - closech: make(chan struct{}), + netConn: conn, + buf: buf, + cfg: new(Config), + sequence: 42, + closech: make(chan struct{}), + readNextFunc: buf.readNext, + readFunc: conn.Read, } conn.data = []byte{72, 0, 0, 42, 10, 53, 46, 53, 46, 56, 0, 165, 0, 0, 0,