From 98f445cc82fa2fe16d56f2eafb822599f0e9fdfc Mon Sep 17 00:00:00 2001 From: Diego Dupin Date: Mon, 31 Mar 2025 18:04:08 +0200 Subject: [PATCH 1/4] test stability improvement. * ensuring performance schema is enabled when testing some performance schema results * Added logic to check if the default collation is overridden by the server character_set_collations * ensure using IANA timezone in test, since tzinfo depending on system won't have deprecated tz like "US/Central" and "US/Pacific" --- driver_test.go | 52 +++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 47 insertions(+), 5 deletions(-) diff --git a/driver_test.go b/driver_test.go index 00e82865..8569494e 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1630,13 +1630,46 @@ func TestCollation(t *testing.T) { } runTests(t, tdsn, func(dbt *DBTest) { + // see https://mariadb.com/kb/en/setting-character-sets-and-collations/#changing-default-collation + // when character_set_collations is set for the charset, it overrides the default collation + // so we need to check if the default collation is overridden + forceExpected := expected + var defaultCollations string + err := dbt.db.QueryRow("SELECT @@character_set_collations").Scan(&defaultCollations) + if err == nil { + // Query succeeded, need to check if we should override expected collation + collationMap := make(map[string]string) + pairs := strings.Split(defaultCollations, ",") + for _, pair := range pairs { + parts := strings.Split(pair, "=") + if len(parts) == 2 { + collationMap[parts[0]] = parts[1] + } + } + + // Get charset prefix from expected collation + parts := strings.Split(expected, "_") + if len(parts) > 0 { + charset := parts[0] + if newCollation, ok := collationMap[charset]; ok { + forceExpected = newCollation + } + } + } + var got string if err := dbt.db.QueryRow("SELECT @@collation_connection").Scan(&got); err != nil { dbt.Fatal(err) } if got != expected { - dbt.Fatalf("expected connection collation %s but got %s", expected, got) + if forceExpected != expected { + if got != forceExpected { + dbt.Fatalf("expected forced connection collation %s but got %s", forceExpected, got) + } + } else { + dbt.Fatalf("expected connection collation %s but got %s", expected, got) + } } }) } @@ -1685,7 +1718,7 @@ func TestRawBytesResultExceedsBuffer(t *testing.T) { } func TestTimezoneConversion(t *testing.T) { - zones := []string{"UTC", "US/Central", "US/Pacific", "Local"} + zones := []string{"UTC", "America/New_York", "Asia/Hong_Kong", "Local"} // Regression test for timezone handling tzTest := func(dbt *DBTest) { @@ -1693,8 +1726,8 @@ func TestTimezoneConversion(t *testing.T) { dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)") // Insert local time into database (should be converted) - usCentral, _ := time.LoadLocation("US/Central") - reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, time.UTC).In(usCentral) + newYorkTz, _ := time.LoadLocation("America/New_York") + reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, time.UTC).In(newYorkTz) dbt.mustExec("INSERT INTO test VALUE (?)", reftime) // Retrieve time from DB @@ -1713,7 +1746,7 @@ func TestTimezoneConversion(t *testing.T) { // Check that dates match if reftime.Unix() != dbTime.Unix() { dbt.Errorf("times do not match.\n") - dbt.Errorf(" Now(%v)=%v\n", usCentral, reftime) + dbt.Errorf(" Now(%v)=%v\n", newYorkTz, reftime) dbt.Errorf(" Now(UTC)=%v\n", dbTime) } } @@ -3541,6 +3574,15 @@ func TestConnectionAttributes(t *testing.T) { dbt := &DBTest{t, db} + var varName string + var varValue string + err := dbt.db.QueryRow("SHOW VARIABLES LIKE 'performance_schema'").Scan(&varName, &varValue) + if err != nil { + t.Fatalf("error: %s", err.Error()) + } + if varValue != "ON" { + t.Skipf("Performance schema is not enabled. skipping") + } queryString := "SELECT ATTR_NAME, ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID()" rows := dbt.mustQuery(queryString) defer rows.Close() From 8e1f894eac45eeea3a3c62a9b84ef1dd02528bfc Mon Sep 17 00:00:00 2001 From: rusher Date: Thu, 27 Mar 2025 21:33:59 +0100 Subject: [PATCH 2/4] Implement MariaDB metadata skipping. Refactor handshake packet handling to support extended capabilities Updated the readHandshakePacket and writeHandshakeResponsePacket functions to include server capabilities and extended capabilities. Adjusted related tests and connection logic to accommodate these changes, ensuring compatibility with MariaDB and improved handling of client capabilities. --- auth_test.go | 28 +++++----- connection.go | 47 +++++++++------- connector.go | 6 +-- const.go | 18 ++++++- packets.go | 141 ++++++++++++++++++++++++++++-------------------- packets_test.go | 10 +++- rows.go | 2 +- statement.go | 20 +++++-- 8 files changed, 171 insertions(+), 101 deletions(-) diff --git a/auth_test.go b/auth_test.go index 46e1e3b4..a8f1d4bd 100644 --- a/auth_test.go +++ b/auth_test.go @@ -89,7 +89,7 @@ func TestAuthFastCachingSHA256PasswordCached(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -134,7 +134,7 @@ func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -176,7 +176,7 @@ func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -232,7 +232,7 @@ func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -284,7 +284,7 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -357,7 +357,7 @@ func TestAuthFastCleartextPassword(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -400,7 +400,7 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -459,7 +459,7 @@ func TestAuthFastNativePassword(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -502,7 +502,7 @@ func TestAuthFastNativePasswordEmpty(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -544,7 +544,7 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -592,7 +592,7 @@ func TestAuthFastSHA256PasswordRSA(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -641,7 +641,7 @@ func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -678,7 +678,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { // unset TLS config to prevent the actual establishment of a TLS wrapper mc.cfg.TLS = nil - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -1343,7 +1343,7 @@ func TestEd25519Auth(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } diff --git a/connection.go b/connection.go index 3e455a3f..b0e01280 100644 --- a/connection.go +++ b/connection.go @@ -24,21 +24,22 @@ import ( ) type mysqlConn struct { - buf buffer - netConn net.Conn - rawConn net.Conn // underlying connection when netConn is TLS connection. - result mysqlResult // managed by clearResult() and handleOkPacket(). - compIO *compIO - cfg *Config - connector *connector - maxAllowedPacket int - maxWriteSize int - flags clientFlag - status statusFlag - sequence uint8 - compressSequence uint8 - parseTime bool - compress bool + buf buffer + netConn net.Conn + rawConn net.Conn // underlying connection when netConn is TLS connection. + result mysqlResult // managed by clearResult() and handleOkPacket(). + compIO *compIO + cfg *Config + connector *connector + maxAllowedPacket int + maxWriteSize int + clientCapabilities capabilityFlag + clientExtCapabilities extendedCapabilityFlag + status statusFlag + sequence uint8 + compressSequence uint8 + parseTime bool + compress bool // for context support (Go 1.8+) watching bool @@ -229,7 +230,15 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { } if columnCount > 0 { - err = mc.readUntilEOF() + if mc.clientExtCapabilities&clientCacheMetadata != 0 { + stmt.columns, err = mc.readColumns(int(columnCount)) + if err != nil { + return nil, err + } + } else { + // skip column definition packets and intermediate EOF packet + err = mc.readUntilEOF() + } } } @@ -370,7 +379,7 @@ func (mc *mysqlConn) exec(query string) error { } // Read Result - resLen, err := handleOk.readResultSetHeaderPacket() + resLen, _, err := handleOk.readResultSetHeaderPacket() if err != nil { return err } @@ -419,7 +428,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) // Read Result var resLen int - resLen, err = handleOk.readResultSetHeaderPacket() + resLen, _, err = handleOk.readResultSetHeaderPacket() if err != nil { return nil, err } @@ -453,7 +462,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { } // Read Result - resLen, err := handleOk.readResultSetHeaderPacket() + resLen, _, err := handleOk.readResultSetHeaderPacket() if err == nil { rows := new(textRows) rows.mc = mc diff --git a/connector.go b/connector.go index bc1d46af..fec1c3dd 100644 --- a/connector.go +++ b/connector.go @@ -131,7 +131,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc.buf = newBuffer() // Reading Handshake Initialization Packet - authData, plugin, err := mc.readHandshakePacket() + authData, serverCapabilities, serverExtendedCapabilities, plugin, err := mc.readHandshakePacket() if err != nil { mc.cleanup() return nil, err @@ -153,7 +153,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } } - if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil { + if err = mc.writeHandshakeResponsePacket(authResp, serverCapabilities, serverExtendedCapabilities, plugin); err != nil { mc.cleanup() return nil, err } @@ -167,7 +167,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } - if mc.cfg.compress && mc.flags&clientCompress == clientCompress { + if mc.cfg.compress && mc.clientCapabilities&clientCompress > 0 { mc.compress = true mc.compIO = newCompIO(mc) } diff --git a/const.go b/const.go index 4aadcd64..b33b1452 100644 --- a/const.go +++ b/const.go @@ -43,10 +43,10 @@ const ( ) // https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags -type clientFlag uint32 +type capabilityFlag uint32 const ( - clientLongPassword clientFlag = 1 << iota + clientMySQL capabilityFlag = 1 << iota clientFoundRows clientLongFlag clientConnectWithDB @@ -73,6 +73,20 @@ const ( clientDeprecateEOF ) +// https://mariadb.com/kb/en/connection/#capabilities +type extendedCapabilityFlag uint32 + +const ( + progressIndicator extendedCapabilityFlag = 1 << iota + clientComMulti + clientStmtBulkOperations + clientExtendedMetadata + clientCacheMetadata + clientUnitBulkResult +) + +// https://mariadb.com/kb/en/connection/#capabilities + const ( comQuit byte = iota + 1 comInitDB diff --git a/packets.go b/packets.go index 4b836216..a46d0e37 100644 --- a/packets.go +++ b/packets.go @@ -174,19 +174,19 @@ func (mc *mysqlConn) writePacket(data []byte) error { // Handshake Initialization Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake -func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) { +func (mc *mysqlConn) readHandshakePacket() (data []byte, serverCapabilities capabilityFlag, serverExtendedCapabilities extendedCapabilityFlag, plugin string, err error) { data, err = mc.readPacket() if err != nil { return } if data[0] == iERR { - return nil, "", mc.handleErrorPacket(data) + return nil, 0, 0, "", mc.handleErrorPacket(data) } // protocol version [1 byte] if data[0] < minProtocolVersion { - return nil, "", fmt.Errorf( + return nil, 0, 0, "", fmt.Errorf( "unsupported protocol version %d. Version %d or higher is required", data[0], minProtocolVersion, @@ -204,15 +204,15 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro pos += 8 + 1 // capability flags (lower 2 bytes) [2 bytes] - mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) - if mc.flags&clientProtocol41 == 0 { - return nil, "", ErrOldProtocol + serverCapabilities = capabilityFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) + if serverCapabilities&clientProtocol41 == 0 { + return nil, serverCapabilities, 0, "", ErrOldProtocol } - if mc.flags&clientSSL == 0 && mc.cfg.TLS != nil { + if serverCapabilities&clientSSL == 0 && mc.cfg.TLS != nil { if mc.cfg.AllowFallbackToPlaintext { mc.cfg.TLS = nil } else { - return nil, "", ErrNoTLS + return nil, serverCapabilities, 0, "", ErrNoTLS } } pos += 2 @@ -222,11 +222,16 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro // status flags [2 bytes] pos += 3 // capability flags (upper 2 bytes) [2 bytes] - mc.flags |= clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16 + serverCapabilities |= capabilityFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16 pos += 2 // length of auth-plugin-data [1 byte] - // reserved (all [00]) [10 bytes] - pos += 11 + // reserved (all [00]) [6 bytes] + pos += 7 + if serverCapabilities&clientMySQL == 0 { + // MariaDB server, use extended capability flags + serverExtendedCapabilities = extendedCapabilityFlag(binary.LittleEndian.Uint32(data[pos : pos+4])) + } + pos += 4 // second part of the password cipher [minimum 13 bytes], // where len=MAX(13, length of auth-plugin-data - 8) @@ -250,67 +255,73 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro } else { plugin = string(data[pos:]) } - // make a memory safe copy of the cipher slice var b [20]byte copy(b[:], authData) - return b[:], plugin, nil + return b[:], serverCapabilities, serverExtendedCapabilities, plugin, nil } // make a memory safe copy of the cipher slice var b [8]byte copy(b[:], authData) - return b[:], plugin, nil + return b[:], serverCapabilities, 0, plugin, nil } -// Client Authentication Packet -// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error { - // Adjust client flags based on server support - clientFlags := clientProtocol41 | - clientSecureConn | - clientLongPassword | - clientTransactions | - clientLocalFiles | - clientPluginAuth | - clientMultiResults | - mc.flags&clientConnectAttrs | - mc.flags&clientLongFlag - - sendConnectAttrs := mc.flags&clientConnectAttrs != 0 - - if mc.cfg.ClientFoundRows { - clientFlags |= clientFoundRows +// initClientCapabilities initializes the client capabilities based on server support and configuration +func (mc *mysqlConn) initClientCapabilities(serverCapabilities capabilityFlag, cfg *Config) capabilityFlag { + + clientCapabilities := + clientMySQL | + clientLongFlag | + clientIgnoreSpace | + clientProtocol41 | + clientSecureConn | + clientTransactions | + clientPluginAuthLenEncClientData | + clientLocalFiles | + clientPluginAuth | + clientMultiResults | + clientConnectAttrs + + if cfg.ClientFoundRows { + clientCapabilities |= clientFoundRows } - if mc.cfg.compress && mc.flags&clientCompress == clientCompress { - clientFlags |= clientCompress + if cfg.compress { + clientCapabilities |= clientCompress } // To enable TLS / SSL if mc.cfg.TLS != nil { - clientFlags |= clientSSL + clientCapabilities |= clientSSL } if mc.cfg.MultiStatements { - clientFlags |= clientMultiStatements + clientCapabilities |= clientMultiStatements } + if n := len(cfg.DBName); n > 0 { + clientCapabilities |= clientConnectWithDB + } + + return clientCapabilities & serverCapabilities +} + +// Client Authentication Packet +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse +func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, serverCapabilities capabilityFlag, serverExtendedCapabilities extendedCapabilityFlag, plugin string) error { + // Adjust client flags based on server support + mc.clientCapabilities = mc.initClientCapabilities(serverCapabilities, mc.cfg) + mc.clientExtCapabilities = clientCacheMetadata & serverExtendedCapabilities + + sendConnectAttrs := mc.clientCapabilities&clientConnectAttrs != 0 // encode length of the auth plugin data var authRespLEIBuf [9]byte authRespLen := len(authResp) authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen)) - if len(authRespLEI) > 1 { - // if the length can not be written in 1 byte, it must be written as a - // length encoded integer - clientFlags |= clientPluginAuthLenEncClientData - } pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 // To specify a db name - if n := len(mc.cfg.DBName); n > 0 { - clientFlags |= clientConnectWithDB - pktLen += n + 1 - } + pktLen += len(mc.cfg.DBName) + 1 // encode length of the connection attributes var connAttrsLEI []byte @@ -328,8 +339,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string return err } - // ClientFlags [32 bit] - binary.LittleEndian.PutUint32(data[4:], uint32(clientFlags)) + // clientCapabilities [32 bit] + binary.LittleEndian.PutUint32(data[4:], uint32(mc.clientCapabilities)) // MaxPacketSize [32 bit] (none) binary.LittleEndian.PutUint32(data[8:], 0) @@ -348,9 +359,18 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string // Filler [23 bytes] (all 0x00) pos := 13 - for ; pos < 13+23; pos++ { + for ; pos < 13+19; pos++ { data[pos] = 0 } + if mc.clientCapabilities&clientMySQL == 0 { + // clientExtendedCapabilities [32 bit] + binary.LittleEndian.PutUint32(data[13+19:], uint32(mc.clientExtCapabilities)) + pos += 4 + } else { + for ; pos < 13+23; pos++ { + data[pos] = 0 + } + } // SSL Connection Request Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest @@ -385,9 +405,9 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string // Databasename [null terminated string] if len(mc.cfg.DBName) > 0 { pos += copy(data[pos:], mc.cfg.DBName) - data[pos] = 0x00 - pos++ } + data[pos] = 0x00 + pos++ pos += copy(data[pos:], plugin) data[pos] = 0x00 @@ -535,32 +555,37 @@ func (mc *okHandler) readResultOK() error { // Result Set Header Packet // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response.html -func (mc *okHandler) readResultSetHeaderPacket() (int, error) { +func (mc *okHandler) readResultSetHeaderPacket() (int, bool, error) { // handleOkPacket replaces both values; other cases leave the values unchanged. mc.result.affectedRows = append(mc.result.affectedRows, 0) mc.result.insertIds = append(mc.result.insertIds, 0) data, err := mc.conn().readPacket() if err != nil { - return 0, err + return 0, false, err } switch data[0] { case iOK: - return 0, mc.handleOkPacket(data) + return 0, false, mc.handleOkPacket(data) case iERR: - return 0, mc.conn().handleErrorPacket(data) + return 0, false, mc.conn().handleErrorPacket(data) case iLocalInFile: - return 0, mc.handleInFileRequest(string(data[1:])) + return 0, false, mc.handleInFileRequest(string(data[1:])) } // column count // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset.html - num, _, _ := readLengthEncodedInteger(data) + // https://mariadb.com/kb/en/result-set-packets/#column-count-packet + num, _, len := readLengthEncodedInteger(data) + + if mc.clientExtCapabilities&clientCacheMetadata != 0 { + return int(num), data[len] == 0x01, nil + } // ignore remaining data in the packet. see #1478. - return int(num), nil + return int(num), true, nil } // Error Packet @@ -1176,7 +1201,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // mc.affectedRows and mc.insertIds. func (mc *okHandler) discardResults() error { for mc.status&statusMoreResultsExists != 0 { - resLen, err := mc.readResultSetHeaderPacket() + resLen, _, err := mc.readResultSetHeaderPacket() if err != nil { return err } diff --git a/packets_test.go b/packets_test.go index 694b0564..b487051e 100644 --- a/packets_test.go +++ b/packets_test.go @@ -332,11 +332,19 @@ func TestRegression801(t *testing.T) { 112, 97, 115, 115, 119, 111, 114, 100} conn.maxReads = 1 - authData, pluginName, err := mc.readHandshakePacket() + authData, serverCapabilities, serverExtendedCapabilities, pluginName, err := mc.readHandshakePacket() if err != nil { t.Fatalf("got error: %v", err) } + if serverCapabilities != 2148530143 { + t.Fatalf("expected serverCapabilities to be 2148530143, got %v", serverCapabilities) + } + + if serverExtendedCapabilities != 0 { + t.Fatalf("expected serverExtendedCapabilities to be 0, got %v", serverExtendedCapabilities) + } + if pluginName != "mysql_native_password" { t.Errorf("expected plugin name 'mysql_native_password', got '%s'", pluginName) } diff --git a/rows.go b/rows.go index df98417b..e77cea5c 100644 --- a/rows.go +++ b/rows.go @@ -156,7 +156,7 @@ func (rows *mysqlRows) nextResultSet() (int, error) { rows.rs = resultSet{} // rows.mc.affectedRows and rows.mc.insertIds accumulate on each call to // nextResultSet. - resLen, err := rows.mc.resultUnchanged().readResultSetHeaderPacket() + resLen, _, err := rows.mc.resultUnchanged().readResultSetHeaderPacket() if err != nil { // Clean up about multi-results flag rows.rs.done = true diff --git a/statement.go b/statement.go index 35df8545..e2f9053f 100644 --- a/statement.go +++ b/statement.go @@ -20,6 +20,7 @@ type mysqlStmt struct { mc *mysqlConn id uint32 paramCount int + columns []mysqlField } func (stmt *mysqlStmt) Close() error { @@ -64,7 +65,7 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { handleOk := stmt.mc.clearResult() // Read Result - resLen, err := handleOk.readResultSetHeaderPacket() + resLen, _, err := handleOk.readResultSetHeaderPacket() if err != nil { return nil, err } @@ -107,7 +108,7 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { // Read Result handleOk := stmt.mc.clearResult() - resLen, err := handleOk.readResultSetHeaderPacket() + resLen, metadataFollows, err := handleOk.readResultSetHeaderPacket() if err != nil { return nil, err } @@ -116,7 +117,20 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { if resLen > 0 { rows.mc = mc - rows.rs.columns, err = mc.readColumns(resLen) + if metadataFollows { + rows.rs.columns, err = mc.readColumns(resLen) + if err != nil { + return nil, err + } + stmt.columns = rows.rs.columns + } else { + // skip EOF Packet + _, err := mc.readPacket() + if err != nil { + return nil, err + } + rows.rs.columns = stmt.columns + } } else { rows.rs.done = true From be6107b9ecef368be0c071f3114ec4dd29e4a0a7 Mon Sep 17 00:00:00 2001 From: Diego Dupin Date: Tue, 1 Apr 2025 15:52:53 +0200 Subject: [PATCH 3/4] Packet parsing small improvement * correct packet readLengthEncodedString that was returning byte array to readLengthEncodedBytes * have an readLengthEncodedString that effectively return string * faster column parsing: MariaDB/MySQL have an identifier limitation of 64 characters (https://dev.mysql.com/doc/refman/8.4/en/identifier-length.html) before: BenchmarkReceiveMetadata-16 1846 650394 ns/op 138776 B/op 3024 allocs/op after: BenchmarkReceiveMetadata-16 1772 639809 ns/op 138776 B/op 3024 allocs/op --- benchmark_test.go | 50 +++++++++++++++++++++++++++++++++++++++++++++++ packets.go | 39 ++++++++++++------------------------ utils.go | 35 +++++++++++++++++++++++---------- 3 files changed, 87 insertions(+), 37 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index 5c9a046b..8275ebc4 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -440,3 +440,53 @@ func BenchmarkReceiveMassiveRows(b *testing.B) { } }) } + +// BenchmarkReceiveMetadata measures performance of receiving more metadata than real data +func BenchmarkReceiveMetadata(b *testing.B) { + tb := (*TB)(b) + b.StopTimer() + b.ReportAllocs() + + // Create a table with 1000 integer fields + createTableQuery := "CREATE TABLE large_integer_table (" + for i := 0; i < 1000; i++ { + createTableQuery += fmt.Sprintf("col_%d INT", i) + if i < 999 { + createTableQuery += ", " + } + } + createTableQuery += ")" + + // Initialize database + db := initDB(b, false, + "DROP TABLE IF EXISTS large_integer_table", + createTableQuery, + "INSERT INTO large_integer_table VALUES ("+ + strings.Repeat("0,", 999)+"0)", // Insert a row of zeros + ) + defer db.Close() + + // Prepare a SELECT query to retrieve metadata + stmt := tb.checkStmt(db.Prepare("SELECT * FROM large_integer_table LIMIT 1")) + defer stmt.Close() + + b.StartTimer() + + // Benchmark metadata retrieval + for i := 0; i < b.N; i++ { + rows := tb.checkRows(stmt.Query()) + + // Create a slice to scan all columns + values := make([]interface{}, 1000) + valuePtrs := make([]interface{}, 1000) + for j := range values { + valuePtrs[j] = &values[j] + } + rows.Next() + // Scan the row + err := rows.Scan(valuePtrs...) + tb.check(err) + + rows.Close() + } +} diff --git a/packets.go b/packets.go index a46d0e37..080e61e4 100644 --- a/packets.go +++ b/packets.go @@ -724,28 +724,21 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { } // Catalog - pos, err := skipLengthEncodedString(data) - if err != nil { - return nil, err - } - + pos := int(data[0]) + 1 // Database [len coded string] - n, err := skipLengthEncodedString(data[pos:]) - if err != nil { - return nil, err - } - pos += n + pos += int(data[pos]) + 1 - // Table [len coded string] + // Table alias [len coded string] + // alias length can be up to 256 if mc.cfg.ColumnsWithAlias { tableName, _, n, err := readLengthEncodedString(data[pos:]) if err != nil { return nil, err } pos += n - columns[i].tableName = string(tableName) + columns[i].tableName = tableName } else { - n, err = skipLengthEncodedString(data[pos:]) + n, err := skipLengthEncodedString(data[pos:]) if err != nil { return nil, err } @@ -753,26 +746,18 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { } // Original table [len coded string] - n, err = skipLengthEncodedString(data[pos:]) - if err != nil { - return nil, err - } - pos += n + pos += int(data[pos]) + 1 - // Name [len coded string] + // Name alias [len coded string] name, _, n, err := readLengthEncodedString(data[pos:]) if err != nil { return nil, err } - columns[i].name = string(name) + columns[i].name = name pos += n // Original name [len coded string] - n, err = skipLengthEncodedString(data[pos:]) - if err != nil { - return nil, err - } - pos += n + pos += int(data[pos]) + 1 // Filler [uint8] pos++ @@ -843,7 +828,7 @@ func (rows *textRows) readRow(dest []driver.Value) error { for i := range dest { // Read bytes and convert to string var buf []byte - buf, isNull, n, err = readLengthEncodedString(data[pos:]) + buf, isNull, n, err = readLengthEncodedBytes(data[pos:]) pos += n if err != nil { @@ -1322,7 +1307,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { fieldTypeVector: var isNull bool var n int - dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) + dest[i], isNull, n, err = readLengthEncodedBytes(data[pos:]) pos += n if err == nil { if !isNull { diff --git a/utils.go b/utils.go index 8716c26c..92445a28 100644 --- a/utils.go +++ b/utils.go @@ -524,10 +524,7 @@ func uint64ToString(n uint64) []byte { return a[i:] } -// returns the string read as a bytes slice, whether the value is NULL, -// the number of bytes read and an error, in case the string is longer than -// the input slice -func readLengthEncodedString(b []byte) ([]byte, bool, int, error) { +func readLengthEncodedBytes(b []byte) ([]byte, bool, int, error) { // Get length num, isNull, n := readLengthEncodedInteger(b) if num < 1 { @@ -543,6 +540,25 @@ func readLengthEncodedString(b []byte) ([]byte, bool, int, error) { return nil, false, n, io.EOF } +// returns the string read as a bytes slice, whether the value is NULL, +// the number of bytes read and an error, in case the string is longer than +// the input slice +func readLengthEncodedString(b []byte) (string, bool, int, error) { + // Get length + num, isNull, n := readLengthEncodedInteger(b) + if num < 1 { + return "", isNull, n, nil + } + + n += int(num) + + // Check data length + if len(b) >= n { + return string(b[n-int(num) : n : n]), false, n, nil + } + return "", false, n, io.EOF +} + // returns the number of bytes skipped and an error, in case the string is // longer than the input slice func skipLengthEncodedString(b []byte) (int, error) { @@ -567,7 +583,9 @@ func readLengthEncodedInteger(b []byte) (uint64, bool, int) { if len(b) == 0 { return 0, true, 1 } - + if b[0] < 251 { + return uint64(b[0]), false, 1 + } switch b[0] { // 251: NULL case 0xfb: @@ -582,12 +600,9 @@ func readLengthEncodedInteger(b []byte) (uint64, bool, int) { return uint64(getUint24(b[1:])), false, 4 // 254: value of following 8 - case 0xfe: - return uint64(binary.LittleEndian.Uint64(b[1:])), false, 9 + default: + return binary.LittleEndian.Uint64(b[1:]), false, 9 } - - // 0-250: value of first byte - return uint64(b[0]), false, 1 } // encodes a uint64 value and appends it to the given bytes slice From 9908c108e82302617fded25aafe44c7db6b7dcf4 Mon Sep 17 00:00:00 2001 From: Diego Dupin Date: Wed, 2 Apr 2025 17:54:01 +0200 Subject: [PATCH 4/4] Implement clientDeprecateEOF flag. (not real performance improvement, but will permit further enhancement) --- AUTHORS | 1 + connection.go | 18 ++++----- packets.go | 104 +++++++++++++++++++++++++++++++++++--------------- rows.go | 4 +- statement.go | 11 ++---- 5 files changed, 89 insertions(+), 49 deletions(-) diff --git a/AUTHORS b/AUTHORS index 510b869b..a261819f 100644 --- a/AUTHORS +++ b/AUTHORS @@ -37,6 +37,7 @@ Daniel Montoya Daniel Nichter Daniƫl van Eeden Dave Protasowski +Diego Dupin Dirkjan Bussink DisposaBoy Egor Smolyakov diff --git a/connection.go b/connection.go index b0e01280..cd84f29d 100644 --- a/connection.go +++ b/connection.go @@ -224,20 +224,20 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { columnCount, err := stmt.readPrepareResultPacket() if err == nil { if stmt.paramCount > 0 { - if err = mc.readUntilEOF(); err != nil { + if err = mc.skipColumns(stmt.paramCount); err != nil { return nil, err } } if columnCount > 0 { if mc.clientExtCapabilities&clientCacheMetadata != 0 { - stmt.columns, err = mc.readColumns(int(columnCount)) - if err != nil { + if stmt.columns, err = mc.readColumns(int(columnCount)); err != nil { return nil, err } } else { - // skip column definition packets and intermediate EOF packet - err = mc.readUntilEOF() + if err = mc.skipColumns(int(columnCount)); err != nil { + return nil, err + } } } } @@ -386,12 +386,12 @@ func (mc *mysqlConn) exec(query string) error { if resLen > 0 { // columns - if err := mc.readUntilEOF(); err != nil { + if err := mc.skipColumns(resLen); err != nil { return err } // rows - if err := mc.readUntilEOF(); err != nil { + if err := mc.skipResultSetRows(); err != nil { return err } } @@ -470,14 +470,14 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { if resLen > 0 { // Columns - if err := mc.readUntilEOF(); err != nil { + if err := mc.skipColumns(resLen); err != nil { return nil, err } } dest := make([]driver.Value, resLen) if err = rows.readRow(dest); err == nil { - return dest[0].([]byte), mc.readUntilEOF() + return dest[0].([]byte), mc.skipResultSetRows() } } return nil, err diff --git a/packets.go b/packets.go index 080e61e4..1bdf2aea 100644 --- a/packets.go +++ b/packets.go @@ -281,7 +281,8 @@ func (mc *mysqlConn) initClientCapabilities(serverCapabilities capabilityFlag, c clientLocalFiles | clientPluginAuth | clientMultiResults | - clientConnectAttrs + clientConnectAttrs | + clientDeprecateEOF if cfg.ClientFoundRows { clientCapabilities |= clientFoundRows @@ -709,20 +710,12 @@ func (mc *okHandler) handleOkPacket(data []byte) error { func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { columns := make([]mysqlField, count) - for i := 0; ; i++ { + for i := 0; i < count; i++ { data, err := mc.readPacket() if err != nil { return nil, err } - // EOF Packet - if data[0] == iEOF && (len(data) == 5 || len(data) == 1) { - if i == count { - return columns, nil - } - return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns)) - } - // Catalog pos := int(data[0]) + 1 // Database [len coded string] @@ -780,13 +773,13 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { // Decimals [uint8] columns[i].decimals = data[pos] - //pos++ + } - // Default value [len coded binary] - //if pos < len(data) { - // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:]) - //} + // skip EOF packet if client does not support deprecateEOF + if err := mc.skipEof(); err != nil { + return nil, err } + return columns, nil } // Read Packets as Field Packets until EOF-Packet or an Error appears @@ -804,9 +797,16 @@ func (rows *textRows) readRow(dest []driver.Value) error { } // EOF Packet - if data[0] == iEOF && len(data) == 5 { - // server_status [2 bytes] - rows.mc.status = readStatus(data[3:]) + if data[0] == iEOF && len(data) < 0xffffff { + if mc.clientCapabilities&clientDeprecateEOF == 0 { + // EOF packet + mc.status = readStatus(data[3:]) + } else { + // Ok Packet with an 0xFE header + _, _, n := readLengthEncodedInteger(data[1:]) + _, _, m := readLengthEncodedInteger(data[1+n:]) + mc.status = readStatus(data[1+n+m:]) + } rows.rs.done = true if !rows.HasNextResultSet() { rows.mc = nil @@ -826,7 +826,7 @@ func (rows *textRows) readRow(dest []driver.Value) error { ) for i := range dest { - // Read bytes and convert to string + // Read field bytes var buf []byte buf, isNull, n, err = readLengthEncodedBytes(data[pos:]) pos += n @@ -871,6 +871,7 @@ func (rows *textRows) readRow(dest []driver.Value) error { default: dest[i] = buf + continue } if err != nil { return err @@ -880,8 +881,33 @@ func (rows *textRows) readRow(dest []driver.Value) error { return nil } -// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read -func (mc *mysqlConn) readUntilEOF() error { +func (mc *mysqlConn) skipPackets(number int) error { + for i := 0; i < number; i++ { + if _, err := mc.readPacket(); err != nil { + return err + } + } + return nil +} + +func (mc *mysqlConn) skipEof() error { + if mc.clientCapabilities&clientDeprecateEOF == 0 { + if _, err := mc.readPacket(); err != nil { + return err + } + } + return nil +} + +func (mc *mysqlConn) skipColumns(resLen int) error { + if err := mc.skipPackets(resLen); err != nil { + return err + } + return mc.skipEof() +} + +// Reads Packets until EOF-Packet or an Error appears. +func (mc *mysqlConn) skipResultSetRows() error { for { data, err := mc.readPacket() if err != nil { @@ -892,10 +918,18 @@ func (mc *mysqlConn) readUntilEOF() error { case iERR: return mc.handleErrorPacket(data) case iEOF: - if len(data) == 5 { - mc.status = readStatus(data[3:]) + if len(data) < 0xffffff { + if mc.clientCapabilities&clientDeprecateEOF == 0 { + // EOF packet + mc.status = readStatus(data[3:]) + } else { + // OK packet with an 0xFE header + _, _, n := readLengthEncodedInteger(data[1:]) + _, _, m := readLengthEncodedInteger(data[1+n:]) + mc.status = readStatus(data[1+n+m:]) + } + return nil } - return nil } } } @@ -1192,11 +1226,11 @@ func (mc *okHandler) discardResults() error { } if resLen > 0 { // columns - if err := mc.conn().readUntilEOF(); err != nil { + if err := mc.conn().skipColumns(resLen); err != nil { return err } // rows - if err := mc.conn().readUntilEOF(); err != nil { + if err := mc.conn().skipResultSetRows(); err != nil { return err } } @@ -1213,19 +1247,27 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // packet indicator [1 byte] if data[0] != iOK { - // EOF Packet - if data[0] == iEOF && len(data) == 5 { - rows.mc.status = readStatus(data[3:]) + // EOF/OK Packet + if data[0] == iEOF { + if rows.mc.clientCapabilities&clientDeprecateEOF == 0 { + // EOF packet + rows.mc.status = readStatus(data[3:]) + } else { + // OK Packet with an 0xFE header + _, _, n := readLengthEncodedInteger(data[1:]) + _, _, m := readLengthEncodedInteger(data[1+n:]) + rows.mc.status = readStatus(data[1+n+m:]) + } rows.rs.done = true if !rows.HasNextResultSet() { rows.mc = nil } return io.EOF } - mc := rows.mc - rows.mc = nil // Error otherwise + mc := rows.mc + rows.mc = nil return mc.handleErrorPacket(data) } diff --git a/rows.go b/rows.go index e77cea5c..bfb821dc 100644 --- a/rows.go +++ b/rows.go @@ -113,7 +113,7 @@ func (rows *mysqlRows) Close() (err error) { // Remove unread packets from stream if !rows.rs.done { - err = mc.readUntilEOF() + err = mc.skipResultSetRows() } if err == nil { handleOk := mc.clearResult() @@ -143,7 +143,7 @@ func (rows *mysqlRows) nextResultSet() (int, error) { // Remove unread packets from stream if !rows.rs.done { - if err := rows.mc.readUntilEOF(); err != nil { + if err := rows.mc.skipResultSetRows(); err != nil { return 0, err } rows.rs.done = true diff --git a/statement.go b/statement.go index e2f9053f..7c63f1ed 100644 --- a/statement.go +++ b/statement.go @@ -72,12 +72,12 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { if resLen > 0 { // Columns - if err = mc.readUntilEOF(); err != nil { + if err = mc.skipColumns(resLen); err != nil { return nil, err } // Rows - if err := mc.readUntilEOF(); err != nil { + if err = mc.skipResultSetRows(); err != nil { return nil, err } } @@ -118,15 +118,12 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { if resLen > 0 { rows.mc = mc if metadataFollows { - rows.rs.columns, err = mc.readColumns(resLen) - if err != nil { + if rows.rs.columns, err = mc.readColumns(resLen); err != nil { return nil, err } stmt.columns = rows.rs.columns } else { - // skip EOF Packet - _, err := mc.readPacket() - if err != nil { + if err = mc.skipEof(); err != nil { return nil, err } rows.rs.columns = stmt.columns