From 19076451fd9dcce1752d7efcc57177097a64bef3 Mon Sep 17 00:00:00 2001 From: Diego Dupin Date: Thu, 24 Apr 2025 18:04:04 +0200 Subject: [PATCH 1/6] MariaDB Metadata skipping --- auth_test.go | 28 +++--- benchmark_test.go | 59 +++++++++++- connection.go | 57 ++++++----- connector.go | 10 +- const.go | 16 +++- packets.go | 239 +++++++++++++++++++++++++++++----------------- packets_test.go | 10 +- rows.go | 6 +- statement.go | 21 +++- 9 files changed, 307 insertions(+), 139 deletions(-) diff --git a/auth_test.go b/auth_test.go index 46e1e3b4..a8f1d4bd 100644 --- a/auth_test.go +++ b/auth_test.go @@ -89,7 +89,7 @@ func TestAuthFastCachingSHA256PasswordCached(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -134,7 +134,7 @@ func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -176,7 +176,7 @@ func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -232,7 +232,7 @@ func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -284,7 +284,7 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -357,7 +357,7 @@ func TestAuthFastCleartextPassword(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -400,7 +400,7 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -459,7 +459,7 @@ func TestAuthFastNativePassword(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -502,7 +502,7 @@ func TestAuthFastNativePasswordEmpty(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -544,7 +544,7 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -592,7 +592,7 @@ func TestAuthFastSHA256PasswordRSA(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -641,7 +641,7 @@ func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -678,7 +678,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { // unset TLS config to prevent the actual establishment of a TLS wrapper mc.cfg.TLS = nil - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } @@ -1343,7 +1343,7 @@ func TestEd25519Auth(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) if err != nil { t.Fatal(err) } diff --git a/benchmark_test.go b/benchmark_test.go index 1c3f64d3..e9d26a17 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -129,7 +129,7 @@ func BenchmarkExec(b *testing.B) { b.ReportAllocs() b.ResetTimer() - for i := 0; i < concurrencyLevel; i++ { + for i := 0; i < concurrencyLevel; i++ { go func() { for { if atomic.AddInt64(&remain, -1) < 0 { @@ -400,7 +400,7 @@ func benchmark10kRows(b *testing.B, compress bool) { } args := make([]any, 200) - for i := 1; i < 200; i+=2 { + for i := 1; i < 200; i += 2 { args[i] = sval } for i := 0; i < 10000; i += 100 { @@ -455,3 +455,58 @@ func BenchmarkReceive10kRows(b *testing.B) { func BenchmarkReceive10kRowsCompressed(b *testing.B) { benchmark10kRows(b, true) } + +// BenchmarkReceiveMetadata measures performance of receiving lots of metadata compare to data in rows +func BenchmarkReceiveMetadata(b *testing.B) { + tb := (*TB)(b) + + // Create a table with 1000 integer fields + createTableQuery := "CREATE TABLE large_integer_table (" + for i := 0; i < 1000; i++ { + createTableQuery += fmt.Sprintf("col_%d INT", i) + if i < 999 { + createTableQuery += ", " + } + } + createTableQuery += ")" + + // Initialize database + db := initDB(b, false, + "DROP TABLE IF EXISTS large_integer_table", + createTableQuery, + "INSERT INTO large_integer_table VALUES ("+ + strings.Repeat("0,", 999)+"0)", // Insert a row of zeros + ) + defer db.Close() + + b.Run("query", func(b *testing.B) { + db.SetMaxIdleConns(0) + db.SetMaxIdleConns(1) + + // Create a slice to scan all columns + values := make([]interface{}, 1000) + valuePtrs := make([]interface{}, 1000) + for j := range values { + valuePtrs[j] = &values[j] + } + + b.ReportAllocs() + b.ResetTimer() + + // Prepare a SELECT query to retrieve metadata + stmt := tb.checkStmt(db.Prepare("SELECT * FROM large_integer_table LIMIT 1")) + defer stmt.Close() + + // Benchmark metadata retrieval + for i := 0; i < b.N; i++ { + rows := tb.checkRows(stmt.Query()) + + rows.Next() + // Scan the row + err := rows.Scan(valuePtrs...) + tb.check(err) + + rows.Close() + } + }) +} diff --git a/connection.go b/connection.go index 3e455a3f..cd84f29d 100644 --- a/connection.go +++ b/connection.go @@ -24,21 +24,22 @@ import ( ) type mysqlConn struct { - buf buffer - netConn net.Conn - rawConn net.Conn // underlying connection when netConn is TLS connection. - result mysqlResult // managed by clearResult() and handleOkPacket(). - compIO *compIO - cfg *Config - connector *connector - maxAllowedPacket int - maxWriteSize int - flags clientFlag - status statusFlag - sequence uint8 - compressSequence uint8 - parseTime bool - compress bool + buf buffer + netConn net.Conn + rawConn net.Conn // underlying connection when netConn is TLS connection. + result mysqlResult // managed by clearResult() and handleOkPacket(). + compIO *compIO + cfg *Config + connector *connector + maxAllowedPacket int + maxWriteSize int + clientCapabilities capabilityFlag + clientExtCapabilities extendedCapabilityFlag + status statusFlag + sequence uint8 + compressSequence uint8 + parseTime bool + compress bool // for context support (Go 1.8+) watching bool @@ -223,13 +224,21 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { columnCount, err := stmt.readPrepareResultPacket() if err == nil { if stmt.paramCount > 0 { - if err = mc.readUntilEOF(); err != nil { + if err = mc.skipColumns(stmt.paramCount); err != nil { return nil, err } } if columnCount > 0 { - err = mc.readUntilEOF() + if mc.clientExtCapabilities&clientCacheMetadata != 0 { + if stmt.columns, err = mc.readColumns(int(columnCount)); err != nil { + return nil, err + } + } else { + if err = mc.skipColumns(int(columnCount)); err != nil { + return nil, err + } + } } } @@ -370,19 +379,19 @@ func (mc *mysqlConn) exec(query string) error { } // Read Result - resLen, err := handleOk.readResultSetHeaderPacket() + resLen, _, err := handleOk.readResultSetHeaderPacket() if err != nil { return err } if resLen > 0 { // columns - if err := mc.readUntilEOF(); err != nil { + if err := mc.skipColumns(resLen); err != nil { return err } // rows - if err := mc.readUntilEOF(); err != nil { + if err := mc.skipResultSetRows(); err != nil { return err } } @@ -419,7 +428,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) // Read Result var resLen int - resLen, err = handleOk.readResultSetHeaderPacket() + resLen, _, err = handleOk.readResultSetHeaderPacket() if err != nil { return nil, err } @@ -453,7 +462,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { } // Read Result - resLen, err := handleOk.readResultSetHeaderPacket() + resLen, _, err := handleOk.readResultSetHeaderPacket() if err == nil { rows := new(textRows) rows.mc = mc @@ -461,14 +470,14 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { if resLen > 0 { // Columns - if err := mc.readUntilEOF(); err != nil { + if err := mc.skipColumns(resLen); err != nil { return nil, err } } dest := make([]driver.Value, resLen) if err = rows.readRow(dest); err == nil { - return dest[0].([]byte), mc.readUntilEOF() + return dest[0].([]byte), mc.skipResultSetRows() } } return nil, err diff --git a/connector.go b/connector.go index bc1d46af..217239cc 100644 --- a/connector.go +++ b/connector.go @@ -131,12 +131,16 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc.buf = newBuffer() // Reading Handshake Initialization Packet - authData, plugin, err := mc.readHandshakePacket() + authData, serverCapabilities, serverExtendedCapabilities, plugin, err := mc.readHandshakePacket() if err != nil { mc.cleanup() return nil, err } + if mc.cfg.TLS != nil && serverCapabilities&clientSSL == 0 { + return nil, fmt.Errorf("TLS is required, but server doesn't support it") + } + if plugin == "" { plugin = defaultAuthPlugin } @@ -153,7 +157,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } } - if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil { + if err = mc.writeHandshakeResponsePacket(authResp, serverCapabilities, serverExtendedCapabilities, plugin); err != nil { mc.cleanup() return nil, err } @@ -167,7 +171,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } - if mc.cfg.compress && mc.flags&clientCompress == clientCompress { + if mc.cfg.compress && mc.clientCapabilities&clientCompress > 0 { mc.compress = true mc.compIO = newCompIO(mc) } diff --git a/const.go b/const.go index 4aadcd64..e451d4a2 100644 --- a/const.go +++ b/const.go @@ -43,10 +43,10 @@ const ( ) // https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags -type clientFlag uint32 +type capabilityFlag uint32 const ( - clientLongPassword clientFlag = 1 << iota + clientMySQL capabilityFlag = 1 << iota clientFoundRows clientLongFlag clientConnectWithDB @@ -73,6 +73,18 @@ const ( clientDeprecateEOF ) +// https://mariadb.com/kb/en/connection/#capabilities +type extendedCapabilityFlag uint32 + +const ( + progressIndicator extendedCapabilityFlag = 1 << iota + clientComMulti + clientStmtBulkOperations + clientExtendedMetadata + clientCacheMetadata + clientUnitBulkResult +) + const ( comQuit byte = iota + 1 comInitDB diff --git a/packets.go b/packets.go index e6e1704b..fc10113f 100644 --- a/packets.go +++ b/packets.go @@ -185,19 +185,19 @@ func (mc *mysqlConn) writePacket(data []byte) error { // Handshake Initialization Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake -func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) { +func (mc *mysqlConn) readHandshakePacket() (data []byte, serverCapabilities capabilityFlag, serverExtendedCapabilities extendedCapabilityFlag, plugin string, err error) { data, err = mc.readPacket() if err != nil { return } if data[0] == iERR { - return nil, "", mc.handleErrorPacket(data) + return nil, 0, 0, "", mc.handleErrorPacket(data) } // protocol version [1 byte] if data[0] < minProtocolVersion { - return nil, "", fmt.Errorf( + return nil, 0, 0, "", fmt.Errorf( "unsupported protocol version %d. Version %d or higher is required", data[0], minProtocolVersion, @@ -215,15 +215,15 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro pos += 8 + 1 // capability flags (lower 2 bytes) [2 bytes] - mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) - if mc.flags&clientProtocol41 == 0 { - return nil, "", ErrOldProtocol + serverCapabilities = capabilityFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) + if serverCapabilities&clientProtocol41 == 0 { + return nil, serverCapabilities, 0, "", ErrOldProtocol } - if mc.flags&clientSSL == 0 && mc.cfg.TLS != nil { + if serverCapabilities&clientSSL == 0 && mc.cfg.TLS != nil { if mc.cfg.AllowFallbackToPlaintext { mc.cfg.TLS = nil } else { - return nil, "", ErrNoTLS + return nil, serverCapabilities, 0, "", ErrNoTLS } } pos += 2 @@ -233,11 +233,16 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro // status flags [2 bytes] pos += 3 // capability flags (upper 2 bytes) [2 bytes] - mc.flags |= clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16 + serverCapabilities |= capabilityFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16 pos += 2 // length of auth-plugin-data [1 byte] - // reserved (all [00]) [10 bytes] - pos += 11 + // reserved (all [00]) [6 bytes] + pos += 7 + if serverCapabilities&clientMySQL == 0 { + // MariaDB server extended flag + serverExtendedCapabilities = extendedCapabilityFlag(binary.LittleEndian.Uint32(data[pos : pos+4])) + } + pos += 4 // second part of the password cipher [minimum 13 bytes], // where len=MAX(13, length of auth-plugin-data - 8) @@ -265,63 +270,73 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro // make a memory safe copy of the cipher slice var b [20]byte copy(b[:], authData) - return b[:], plugin, nil + return b[:], serverCapabilities, serverExtendedCapabilities, plugin, nil } // make a memory safe copy of the cipher slice var b [8]byte copy(b[:], authData) - return b[:], plugin, nil + return b[:], serverCapabilities, 0, plugin, nil } -// Client Authentication Packet -// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error { - // Adjust client flags based on server support - clientFlags := clientProtocol41 | - clientSecureConn | - clientLongPassword | - clientTransactions | - clientLocalFiles | - clientPluginAuth | - clientMultiResults | - mc.flags&clientConnectAttrs | - mc.flags&clientLongFlag - - sendConnectAttrs := mc.flags&clientConnectAttrs != 0 - - if mc.cfg.ClientFoundRows { - clientFlags |= clientFoundRows - } - if mc.cfg.compress && mc.flags&clientCompress == clientCompress { - clientFlags |= clientCompress +// initClientCapabilities initializes the client capabilities based on server support and configuration +func (mc *mysqlConn) initClientCapabilities(serverCapabilities capabilityFlag, cfg *Config) capabilityFlag { + + clientCapabilities := + clientMySQL | + clientLongFlag | + clientProtocol41 | + clientSecureConn | + clientTransactions | + clientPluginAuthLenEncClientData | + clientLocalFiles | + clientPluginAuth | + clientMultiResults | + clientConnectAttrs | + clientDeprecateEOF + + if cfg.ClientFoundRows { + clientCapabilities |= clientFoundRows + } + if cfg.compress { + clientCapabilities |= clientCompress } // To enable TLS / SSL if mc.cfg.TLS != nil { - clientFlags |= clientSSL + clientCapabilities |= clientSSL } if mc.cfg.MultiStatements { - clientFlags |= clientMultiStatements + clientCapabilities |= clientMultiStatements + } + if n := len(cfg.DBName); n > 0 { + clientCapabilities |= clientConnectWithDB } + // only keep client capabilities that server have + return clientCapabilities & serverCapabilities +} + +// Client Authentication Packet +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse +func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, serverCapabilities capabilityFlag, serverExtendedCapabilities extendedCapabilityFlag, plugin string) error { + // Adjust client capabilities based on server support + mc.clientCapabilities = mc.initClientCapabilities(serverCapabilities, mc.cfg) + + // set MariaDB extended clientCacheMetadata capability if server support it + mc.clientExtCapabilities = clientCacheMetadata & serverExtendedCapabilities + + sendConnectAttrs := mc.clientCapabilities&clientConnectAttrs != 0 + // encode length of the auth plugin data var authRespLEIBuf [9]byte authRespLen := len(authResp) authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen)) - if len(authRespLEI) > 1 { - // if the length can not be written in 1 byte, it must be written as a - // length encoded integer - clientFlags |= clientPluginAuthLenEncClientData - } pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 // To specify a db name - if n := len(mc.cfg.DBName); n > 0 { - clientFlags |= clientConnectWithDB - pktLen += n + 1 - } + pktLen += len(mc.cfg.DBName) + 1 // encode length of the connection attributes var connAttrsLEI []byte @@ -339,8 +354,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string return err } - // ClientFlags [32 bit] - binary.LittleEndian.PutUint32(data[4:], uint32(clientFlags)) + // clientCapabilities [32 bit] + binary.LittleEndian.PutUint32(data[4:], uint32(mc.clientCapabilities)) // MaxPacketSize [32 bit] (none) binary.LittleEndian.PutUint32(data[8:], 0) @@ -359,9 +374,18 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string // Filler [23 bytes] (all 0x00) pos := 13 - for ; pos < 13+23; pos++ { + for ; pos < 13+19; pos++ { data[pos] = 0 } + if mc.clientCapabilities&clientMySQL == 0 { + // MariaDB Extended Capabilities + binary.LittleEndian.PutUint32(data[13+19:], uint32(mc.clientExtCapabilities)) + pos += 4 + } else { + for ; pos < 13+23; pos++ { + data[pos] = 0 + } + } // SSL Connection Request Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest @@ -393,8 +417,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string pos += copy(data[pos:], authRespLEI) pos += copy(data[pos:], authResp) - // Databasename [null terminated string] - if len(mc.cfg.DBName) > 0 { + // Database name [null terminated string] + if mc.clientCapabilities&clientConnectWithDB != 0 { pos += copy(data[pos:], mc.cfg.DBName) data[pos] = 0x00 pos++ @@ -546,32 +570,37 @@ func (mc *okHandler) readResultOK() error { // Result Set Header Packet // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response.html -func (mc *okHandler) readResultSetHeaderPacket() (int, error) { +func (mc *okHandler) readResultSetHeaderPacket() (int, bool, error) { // handleOkPacket replaces both values; other cases leave the values unchanged. mc.result.affectedRows = append(mc.result.affectedRows, 0) mc.result.insertIds = append(mc.result.insertIds, 0) data, err := mc.conn().readPacket() if err != nil { - return 0, err + return 0, false, err } switch data[0] { case iOK: - return 0, mc.handleOkPacket(data) + return 0, false, mc.handleOkPacket(data) case iERR: - return 0, mc.conn().handleErrorPacket(data) + return 0, false, mc.conn().handleErrorPacket(data) case iLocalInFile: - return 0, mc.handleInFileRequest(string(data[1:])) + return 0, false, mc.handleInFileRequest(string(data[1:])) } // column count // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset.html - num, _, _ := readLengthEncodedInteger(data) + // https://mariadb.com/kb/en/result-set-packets/#column-count-packet + num, _, len := readLengthEncodedInteger(data) + + if mc.clientExtCapabilities&clientCacheMetadata != 0 { + return int(num), data[len] == 0x01, nil + } // ignore remaining data in the packet. see #1478. - return int(num), nil + return int(num), true, nil } // Error Packet @@ -695,20 +724,12 @@ func (mc *okHandler) handleOkPacket(data []byte) error { func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { columns := make([]mysqlField, count) - for i := 0; ; i++ { + for i := 0; i < count; i++ { data, err := mc.readPacket() if err != nil { return nil, err } - // EOF Packet - if data[0] == iEOF && (len(data) == 5 || len(data) == 1) { - if i == count { - return columns, nil - } - return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns)) - } - // Catalog pos, err := skipLengthEncodedString(data) if err != nil { @@ -781,13 +802,13 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { // Decimals [uint8] columns[i].decimals = data[pos] - //pos++ + } - // Default value [len coded binary] - //if pos < len(data) { - // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:]) - //} + // skip EOF packet if client does not support deprecateEOF + if err := mc.skipEof(); err != nil { + return nil, err } + return columns, nil } // Read Packets as Field Packets until EOF-Packet or an Error appears @@ -805,9 +826,16 @@ func (rows *textRows) readRow(dest []driver.Value) error { } // EOF Packet - if data[0] == iEOF && len(data) == 5 { - // server_status [2 bytes] - rows.mc.status = readStatus(data[3:]) + if data[0] == iEOF && len(data) < 0xffffff { + if mc.clientCapabilities&clientDeprecateEOF == 0 { + // EOF packet + mc.status = readStatus(data[3:]) + } else { + // Ok Packet with an 0xFE header + _, _, n := readLengthEncodedInteger(data[1:]) + _, _, m := readLengthEncodedInteger(data[1+n:]) + mc.status = readStatus(data[1+n+m:]) + } rows.rs.done = true if !rows.HasNextResultSet() { rows.mc = nil @@ -881,8 +909,33 @@ func (rows *textRows) readRow(dest []driver.Value) error { return nil } -// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read -func (mc *mysqlConn) readUntilEOF() error { +func (mc *mysqlConn) skipPackets(number int) error { + for i := 0; i < number; i++ { + if _, err := mc.readPacket(); err != nil { + return err + } + } + return nil +} + +func (mc *mysqlConn) skipEof() error { + if mc.clientCapabilities&clientDeprecateEOF == 0 { + if _, err := mc.readPacket(); err != nil { + return err + } + } + return nil +} + +func (mc *mysqlConn) skipColumns(resLen int) error { + if err := mc.skipPackets(resLen); err != nil { + return err + } + return mc.skipEof() +} + +// Reads Packets until EOF-Packet or an Error appears. +func (mc *mysqlConn) skipResultSetRows() error { for { data, err := mc.readPacket() if err != nil { @@ -893,10 +946,18 @@ func (mc *mysqlConn) readUntilEOF() error { case iERR: return mc.handleErrorPacket(data) case iEOF: - if len(data) == 5 { - mc.status = readStatus(data[3:]) + if len(data) < 0xffffff { + if mc.clientCapabilities&clientDeprecateEOF == 0 { + // EOF packet + mc.status = readStatus(data[3:]) + } else { + // OK packet with an 0xFE header + _, _, n := readLengthEncodedInteger(data[1:]) + _, _, m := readLengthEncodedInteger(data[1+n:]) + mc.status = readStatus(data[1+n+m:]) + } + return nil } - return nil } } } @@ -1184,17 +1245,17 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // mc.affectedRows and mc.insertIds. func (mc *okHandler) discardResults() error { for mc.status&statusMoreResultsExists != 0 { - resLen, err := mc.readResultSetHeaderPacket() + resLen, _, err := mc.readResultSetHeaderPacket() if err != nil { return err } if resLen > 0 { // columns - if err := mc.conn().readUntilEOF(); err != nil { + if err := mc.conn().skipColumns(resLen); err != nil { return err } // rows - if err := mc.conn().readUntilEOF(); err != nil { + if err := mc.conn().skipResultSetRows(); err != nil { return err } } @@ -1211,9 +1272,17 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // packet indicator [1 byte] if data[0] != iOK { - // EOF Packet - if data[0] == iEOF && len(data) == 5 { - rows.mc.status = readStatus(data[3:]) + // EOF/OK Packet + if data[0] == iEOF { + if rows.mc.clientCapabilities&clientDeprecateEOF == 0 { + // EOF packet + rows.mc.status = readStatus(data[3:]) + } else { + // OK Packet with an 0xFE header + _, _, n := readLengthEncodedInteger(data[1:]) + _, _, m := readLengthEncodedInteger(data[1+n:]) + rows.mc.status = readStatus(data[1+n+m:]) + } rows.rs.done = true if !rows.HasNextResultSet() { rows.mc = nil diff --git a/packets_test.go b/packets_test.go index 694b0564..b487051e 100644 --- a/packets_test.go +++ b/packets_test.go @@ -332,11 +332,19 @@ func TestRegression801(t *testing.T) { 112, 97, 115, 115, 119, 111, 114, 100} conn.maxReads = 1 - authData, pluginName, err := mc.readHandshakePacket() + authData, serverCapabilities, serverExtendedCapabilities, pluginName, err := mc.readHandshakePacket() if err != nil { t.Fatalf("got error: %v", err) } + if serverCapabilities != 2148530143 { + t.Fatalf("expected serverCapabilities to be 2148530143, got %v", serverCapabilities) + } + + if serverExtendedCapabilities != 0 { + t.Fatalf("expected serverExtendedCapabilities to be 0, got %v", serverExtendedCapabilities) + } + if pluginName != "mysql_native_password" { t.Errorf("expected plugin name 'mysql_native_password', got '%s'", pluginName) } diff --git a/rows.go b/rows.go index df98417b..bfb821dc 100644 --- a/rows.go +++ b/rows.go @@ -113,7 +113,7 @@ func (rows *mysqlRows) Close() (err error) { // Remove unread packets from stream if !rows.rs.done { - err = mc.readUntilEOF() + err = mc.skipResultSetRows() } if err == nil { handleOk := mc.clearResult() @@ -143,7 +143,7 @@ func (rows *mysqlRows) nextResultSet() (int, error) { // Remove unread packets from stream if !rows.rs.done { - if err := rows.mc.readUntilEOF(); err != nil { + if err := rows.mc.skipResultSetRows(); err != nil { return 0, err } rows.rs.done = true @@ -156,7 +156,7 @@ func (rows *mysqlRows) nextResultSet() (int, error) { rows.rs = resultSet{} // rows.mc.affectedRows and rows.mc.insertIds accumulate on each call to // nextResultSet. - resLen, err := rows.mc.resultUnchanged().readResultSetHeaderPacket() + resLen, _, err := rows.mc.resultUnchanged().readResultSetHeaderPacket() if err != nil { // Clean up about multi-results flag rows.rs.done = true diff --git a/statement.go b/statement.go index 35df8545..7c63f1ed 100644 --- a/statement.go +++ b/statement.go @@ -20,6 +20,7 @@ type mysqlStmt struct { mc *mysqlConn id uint32 paramCount int + columns []mysqlField } func (stmt *mysqlStmt) Close() error { @@ -64,19 +65,19 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { handleOk := stmt.mc.clearResult() // Read Result - resLen, err := handleOk.readResultSetHeaderPacket() + resLen, _, err := handleOk.readResultSetHeaderPacket() if err != nil { return nil, err } if resLen > 0 { // Columns - if err = mc.readUntilEOF(); err != nil { + if err = mc.skipColumns(resLen); err != nil { return nil, err } // Rows - if err := mc.readUntilEOF(); err != nil { + if err = mc.skipResultSetRows(); err != nil { return nil, err } } @@ -107,7 +108,7 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { // Read Result handleOk := stmt.mc.clearResult() - resLen, err := handleOk.readResultSetHeaderPacket() + resLen, metadataFollows, err := handleOk.readResultSetHeaderPacket() if err != nil { return nil, err } @@ -116,7 +117,17 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { if resLen > 0 { rows.mc = mc - rows.rs.columns, err = mc.readColumns(resLen) + if metadataFollows { + if rows.rs.columns, err = mc.readColumns(resLen); err != nil { + return nil, err + } + stmt.columns = rows.rs.columns + } else { + if err = mc.skipEof(); err != nil { + return nil, err + } + rows.rs.columns = stmt.columns + } } else { rows.rs.done = true From 9db718f94427f80e979c9d1819346692a4e8b9df Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 25 Apr 2025 17:55:40 +0900 Subject: [PATCH 2/6] simplify --- auth_test.go | 28 +++++------ connection.go | 34 ++++++------- connector.go | 8 +-- const.go | 3 +- packets.go | 135 ++++++++++++++++++++++---------------------------- 5 files changed, 97 insertions(+), 111 deletions(-) diff --git a/auth_test.go b/auth_test.go index a8f1d4bd..46e1e3b4 100644 --- a/auth_test.go +++ b/auth_test.go @@ -89,7 +89,7 @@ func TestAuthFastCachingSHA256PasswordCached(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -134,7 +134,7 @@ func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -176,7 +176,7 @@ func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -232,7 +232,7 @@ func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -284,7 +284,7 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -357,7 +357,7 @@ func TestAuthFastCleartextPassword(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -400,7 +400,7 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -459,7 +459,7 @@ func TestAuthFastNativePassword(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -502,7 +502,7 @@ func TestAuthFastNativePasswordEmpty(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -544,7 +544,7 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -592,7 +592,7 @@ func TestAuthFastSHA256PasswordRSA(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -641,7 +641,7 @@ func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -678,7 +678,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { // unset TLS config to prevent the actual establishment of a TLS wrapper mc.cfg.TLS = nil - err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -1343,7 +1343,7 @@ func TestEd25519Auth(t *testing.T) { if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } diff --git a/connection.go b/connection.go index cd84f29d..6024d45f 100644 --- a/connection.go +++ b/connection.go @@ -24,22 +24,22 @@ import ( ) type mysqlConn struct { - buf buffer - netConn net.Conn - rawConn net.Conn // underlying connection when netConn is TLS connection. - result mysqlResult // managed by clearResult() and handleOkPacket(). - compIO *compIO - cfg *Config - connector *connector - maxAllowedPacket int - maxWriteSize int - clientCapabilities capabilityFlag - clientExtCapabilities extendedCapabilityFlag - status statusFlag - sequence uint8 - compressSequence uint8 - parseTime bool - compress bool + buf buffer + netConn net.Conn + rawConn net.Conn // underlying connection when netConn is TLS connection. + result mysqlResult // managed by clearResult() and handleOkPacket(). + compIO *compIO + cfg *Config + connector *connector + maxAllowedPacket int + maxWriteSize int + capabilities capabilityFlag + extCapabilities extendedCapabilityFlag + status statusFlag + sequence uint8 + compressSequence uint8 + parseTime bool + compress bool // for context support (Go 1.8+) watching bool @@ -230,7 +230,7 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { } if columnCount > 0 { - if mc.clientExtCapabilities&clientCacheMetadata != 0 { + if mc.extCapabilities&clientCacheMetadata != 0 { if stmt.columns, err = mc.readColumns(int(columnCount)); err != nil { return nil, err } diff --git a/connector.go b/connector.go index 217239cc..9650cc4f 100644 --- a/connector.go +++ b/connector.go @@ -131,7 +131,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc.buf = newBuffer() // Reading Handshake Initialization Packet - authData, serverCapabilities, serverExtendedCapabilities, plugin, err := mc.readHandshakePacket() + authData, serverCapabilities, serverExtCapabilities, plugin, err := mc.readHandshakePacket() if err != nil { mc.cleanup() return nil, err @@ -157,7 +157,8 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } } - if err = mc.writeHandshakeResponsePacket(authResp, serverCapabilities, serverExtendedCapabilities, plugin); err != nil { + mc.initCapabilities(serverCapabilities, serverExtCapabilities, mc.cfg) + if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil { mc.cleanup() return nil, err } @@ -171,7 +172,8 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } - if mc.cfg.compress && mc.clientCapabilities&clientCompress > 0 { + // compression is enabled after auth, not right after sending handshake response. + if mc.capabilities&clientCompress > 0 { mc.compress = true mc.compIO = newCompIO(mc) } diff --git a/const.go b/const.go index e451d4a2..311e92ea 100644 --- a/const.go +++ b/const.go @@ -42,7 +42,8 @@ const ( iERR byte = 0xff ) -// https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags +// https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__capabilities__flags.html +// https://mariadb.com/kb/en/connection/#capabilities type capabilityFlag uint32 const ( diff --git a/packets.go b/packets.go index fc10113f..4cb84d77 100644 --- a/packets.go +++ b/packets.go @@ -184,15 +184,17 @@ func (mc *mysqlConn) writePacket(data []byte) error { ******************************************************************************/ // Handshake Initialization Packet -// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake -func (mc *mysqlConn) readHandshakePacket() (data []byte, serverCapabilities capabilityFlag, serverExtendedCapabilities extendedCapabilityFlag, plugin string, err error) { +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html +// https://mariadb.com/kb/en/connection/#initial-handshake-packet +func (mc *mysqlConn) readHandshakePacket() (data []byte, capabilities capabilityFlag, extendedCapabilities extendedCapabilityFlag, plugin string, err error) { data, err = mc.readPacket() if err != nil { return } if data[0] == iERR { - return nil, 0, 0, "", mc.handleErrorPacket(data) + err = mc.handleErrorPacket(data) + return } // protocol version [1 byte] @@ -215,15 +217,15 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, serverCapabilities capa pos += 8 + 1 // capability flags (lower 2 bytes) [2 bytes] - serverCapabilities = capabilityFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) - if serverCapabilities&clientProtocol41 == 0 { - return nil, serverCapabilities, 0, "", ErrOldProtocol + capabilities = capabilityFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) + if capabilities&clientProtocol41 == 0 { + return nil, capabilities, 0, "", ErrOldProtocol } - if serverCapabilities&clientSSL == 0 && mc.cfg.TLS != nil { + if capabilities&clientSSL == 0 && mc.cfg.TLS != nil { if mc.cfg.AllowFallbackToPlaintext { mc.cfg.TLS = nil } else { - return nil, serverCapabilities, 0, "", ErrNoTLS + return nil, capabilities, 0, "", ErrNoTLS } } pos += 2 @@ -233,14 +235,14 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, serverCapabilities capa // status flags [2 bytes] pos += 3 // capability flags (upper 2 bytes) [2 bytes] - serverCapabilities |= capabilityFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16 + capabilities |= capabilityFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16 pos += 2 // length of auth-plugin-data [1 byte] // reserved (all [00]) [6 bytes] pos += 7 - if serverCapabilities&clientMySQL == 0 { + if capabilities&clientMySQL == 0 { // MariaDB server extended flag - serverExtendedCapabilities = extendedCapabilityFlag(binary.LittleEndian.Uint32(data[pos : pos+4])) + extendedCapabilities = extendedCapabilityFlag(binary.LittleEndian.Uint32(data[pos : pos+4])) } pos += 4 @@ -270,18 +272,17 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, serverCapabilities capa // make a memory safe copy of the cipher slice var b [20]byte copy(b[:], authData) - return b[:], serverCapabilities, serverExtendedCapabilities, plugin, nil + return b[:], capabilities, extendedCapabilities, plugin, nil } // make a memory safe copy of the cipher slice var b [8]byte copy(b[:], authData) - return b[:], serverCapabilities, 0, plugin, nil + return b[:], capabilities, 0, plugin, nil } -// initClientCapabilities initializes the client capabilities based on server support and configuration -func (mc *mysqlConn) initClientCapabilities(serverCapabilities capabilityFlag, cfg *Config) capabilityFlag { - +// initCapabilities initializes the capabilities based on server support and configuration +func (mc *mysqlConn) initCapabilities(serverCapabilities capabilityFlag, serverExtCapabilities extendedCapabilityFlag, cfg *Config) { clientCapabilities := clientMySQL | clientLongFlag | @@ -314,48 +315,29 @@ func (mc *mysqlConn) initClientCapabilities(serverCapabilities capabilityFlag, c } // only keep client capabilities that server have - return clientCapabilities & serverCapabilities + mc.capabilities = clientCapabilities & serverCapabilities + + // set MariaDB extended clientCacheMetadata capability if server support it + mc.extCapabilities = clientCacheMetadata & serverExtCapabilities } // Client Authentication Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, serverCapabilities capabilityFlag, serverExtendedCapabilities extendedCapabilityFlag, plugin string) error { - // Adjust client capabilities based on server support - mc.clientCapabilities = mc.initClientCapabilities(serverCapabilities, mc.cfg) - - // set MariaDB extended clientCacheMetadata capability if server support it - mc.clientExtCapabilities = clientCacheMetadata & serverExtendedCapabilities - - sendConnectAttrs := mc.clientCapabilities&clientConnectAttrs != 0 - - // encode length of the auth plugin data - var authRespLEIBuf [9]byte - authRespLen := len(authResp) - authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen)) - - pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 - - // To specify a db name - pktLen += len(mc.cfg.DBName) + 1 - - // encode length of the connection attributes - var connAttrsLEI []byte - if sendConnectAttrs { - var connAttrsLEIBuf [9]byte - connAttrsLen := len(mc.connector.encodedAttributes) - connAttrsLEI = appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen)) - pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes) - } - - // Calculate packet length and get buffer with that size - data, err := mc.buf.takeBuffer(pktLen + 4) +func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error { + // packet header 4 + // capabilities 4 + // maxPacketSize 4 + // collation id 1 + // filler 23 + data, err := mc.buf.takeSmallBuffer(4*3 + 24) if err != nil { mc.cleanup() return err } + _ = data[4*3+23] // boundery check // clientCapabilities [32 bit] - binary.LittleEndian.PutUint32(data[4:], uint32(mc.clientCapabilities)) + binary.LittleEndian.PutUint32(data[4:], uint32(mc.capabilities)) // MaxPacketSize [32 bit] (none) binary.LittleEndian.PutUint32(data[8:], 0) @@ -373,25 +355,28 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, serverCapabil } // Filler [23 bytes] (all 0x00) + // or filler 19bytes + mariadb extCapabilities pos := 13 - for ; pos < 13+19; pos++ { - data[pos] = 0 - } - if mc.clientCapabilities&clientMySQL == 0 { + if mc.capabilities&clientMySQL == 0 { + for ; pos < 13+19; pos++ { + data[pos] = 0 + } // MariaDB Extended Capabilities - binary.LittleEndian.PutUint32(data[13+19:], uint32(mc.clientExtCapabilities)) + binary.LittleEndian.PutUint32(data[13+19:], uint32(mc.extCapabilities)) pos += 4 } else { for ; pos < 13+23; pos++ { data[pos] = 0 } } + // assert len(data) == pos // SSL Connection Request Packet - // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_ssl_request.html + // https://mariadb.com/kb/en/connection/#sslrequest-packet if mc.cfg.TLS != nil { // Send TLS / SSL request packet - if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil { + if err := mc.writePacket(data); err != nil { return err } @@ -408,34 +393,32 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, serverCapabil // User [null terminated string] if len(mc.cfg.User) > 0 { - pos += copy(data[pos:], mc.cfg.User) + data = append(data, mc.cfg.User...) } - data[pos] = 0x00 - pos++ + data = append(data, 0) // Auth Data [length encoded integer] - pos += copy(data[pos:], authRespLEI) - pos += copy(data[pos:], authResp) + data = appendLengthEncodedInteger(data, uint64(len(authResp))) + data = append(data, authResp...) // Database name [null terminated string] - if mc.clientCapabilities&clientConnectWithDB != 0 { - pos += copy(data[pos:], mc.cfg.DBName) - data[pos] = 0x00 - pos++ + if mc.capabilities&clientConnectWithDB != 0 { + data = append(data, mc.cfg.DBName...) + data = append(data, 0) } - pos += copy(data[pos:], plugin) - data[pos] = 0x00 - pos++ + data = append(data, plugin...) + data = append(data, 0) // Connection Attributes - if sendConnectAttrs { - pos += copy(data[pos:], connAttrsLEI) - pos += copy(data[pos:], []byte(mc.connector.encodedAttributes)) + if mc.capabilities&clientConnectAttrs != 0 { + connAttrsLen := len(mc.connector.encodedAttributes) + data = appendLengthEncodedInteger(data, uint64(connAttrsLen)) + data = append(data, mc.connector.encodedAttributes...) } // Send Auth packet - return mc.writePacket(data[:pos]) + return mc.writePacket(data) } // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse @@ -596,7 +579,7 @@ func (mc *okHandler) readResultSetHeaderPacket() (int, bool, error) { // https://mariadb.com/kb/en/result-set-packets/#column-count-packet num, _, len := readLengthEncodedInteger(data) - if mc.clientExtCapabilities&clientCacheMetadata != 0 { + if mc.extCapabilities&clientCacheMetadata != 0 { return int(num), data[len] == 0x01, nil } // ignore remaining data in the packet. see #1478. @@ -827,7 +810,7 @@ func (rows *textRows) readRow(dest []driver.Value) error { // EOF Packet if data[0] == iEOF && len(data) < 0xffffff { - if mc.clientCapabilities&clientDeprecateEOF == 0 { + if mc.capabilities&clientDeprecateEOF == 0 { // EOF packet mc.status = readStatus(data[3:]) } else { @@ -919,7 +902,7 @@ func (mc *mysqlConn) skipPackets(number int) error { } func (mc *mysqlConn) skipEof() error { - if mc.clientCapabilities&clientDeprecateEOF == 0 { + if mc.capabilities&clientDeprecateEOF == 0 { if _, err := mc.readPacket(); err != nil { return err } @@ -947,7 +930,7 @@ func (mc *mysqlConn) skipResultSetRows() error { return mc.handleErrorPacket(data) case iEOF: if len(data) < 0xffffff { - if mc.clientCapabilities&clientDeprecateEOF == 0 { + if mc.capabilities&clientDeprecateEOF == 0 { // EOF packet mc.status = readStatus(data[3:]) } else { @@ -1274,7 +1257,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { if data[0] != iOK { // EOF/OK Packet if data[0] == iEOF { - if rows.mc.clientCapabilities&clientDeprecateEOF == 0 { + if rows.mc.capabilities&clientDeprecateEOF == 0 { // EOF packet rows.mc.status = readStatus(data[3:]) } else { From 19dc2f7544cd310c7c183c1a8805eef333ad439c Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 25 Apr 2025 19:54:18 +0900 Subject: [PATCH 3/6] shorter names --- connection.go | 4 ++-- packets.go | 32 +++++++++++++++++++------------- rows.go | 4 ++-- statement.go | 2 +- 4 files changed, 24 insertions(+), 18 deletions(-) diff --git a/connection.go b/connection.go index 6024d45f..58c763fa 100644 --- a/connection.go +++ b/connection.go @@ -391,7 +391,7 @@ func (mc *mysqlConn) exec(query string) error { } // rows - if err := mc.skipResultSetRows(); err != nil { + if err := mc.skipRows(); err != nil { return err } } @@ -477,7 +477,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { dest := make([]driver.Value, resLen) if err = rows.readRow(dest); err == nil { - return dest[0].([]byte), mc.skipResultSetRows() + return dest[0].([]byte), mc.skipRows() } } return nil, err diff --git a/packets.go b/packets.go index 4cb84d77..da755aea 100644 --- a/packets.go +++ b/packets.go @@ -809,14 +809,18 @@ func (rows *textRows) readRow(dest []driver.Value) error { } // EOF Packet - if data[0] == iEOF && len(data) < 0xffffff { + // text row packets may starts with LengthEncodedString. + // In such case, 0xFE can mean string larger than 0xffffff. + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_dt_integers.html#sect_protocol_basic_dt_int_le + if data[0] == iEOF && len(data) <= 0xffffff { if mc.capabilities&clientDeprecateEOF == 0 { - // EOF packet + // Deprecated EOF packet + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_eof_packet.html mc.status = readStatus(data[3:]) } else { // Ok Packet with an 0xFE header - _, _, n := readLengthEncodedInteger(data[1:]) - _, _, m := readLengthEncodedInteger(data[1+n:]) + _, _, n := readLengthEncodedInteger(data[1:]) // affected_rows + _, _, m := readLengthEncodedInteger(data[1+n:]) // last_insert_id mc.status = readStatus(data[1+n+m:]) } rows.rs.done = true @@ -892,8 +896,8 @@ func (rows *textRows) readRow(dest []driver.Value) error { return nil } -func (mc *mysqlConn) skipPackets(number int) error { - for i := 0; i < number; i++ { +func (mc *mysqlConn) skipPackets(n int) error { + for i := 0; i < n; i++ { if _, err := mc.readPacket(); err != nil { return err } @@ -910,15 +914,15 @@ func (mc *mysqlConn) skipEof() error { return nil } -func (mc *mysqlConn) skipColumns(resLen int) error { - if err := mc.skipPackets(resLen); err != nil { +func (mc *mysqlConn) skipColumns(n int) error { + if err := mc.skipPackets(n); err != nil { return err } return mc.skipEof() } // Reads Packets until EOF-Packet or an Error appears. -func (mc *mysqlConn) skipResultSetRows() error { +func (mc *mysqlConn) skipRows() error { for { data, err := mc.readPacket() if err != nil { @@ -929,14 +933,16 @@ func (mc *mysqlConn) skipResultSetRows() error { case iERR: return mc.handleErrorPacket(data) case iEOF: - if len(data) < 0xffffff { + // text row packets may starts with LengthEncodedString. + // In such case, 0xFE can mean string larger than 0xffffff. + if len(data) <= 0xffffff { if mc.capabilities&clientDeprecateEOF == 0 { // EOF packet mc.status = readStatus(data[3:]) } else { // OK packet with an 0xFE header - _, _, n := readLengthEncodedInteger(data[1:]) - _, _, m := readLengthEncodedInteger(data[1+n:]) + _, _, n := readLengthEncodedInteger(data[1:]) // affected_rows + _, _, m := readLengthEncodedInteger(data[1+n:]) // last_insert_id mc.status = readStatus(data[1+n+m:]) } return nil @@ -1238,7 +1244,7 @@ func (mc *okHandler) discardResults() error { return err } // rows - if err := mc.conn().skipResultSetRows(); err != nil { + if err := mc.conn().skipRows(); err != nil { return err } } diff --git a/rows.go b/rows.go index bfb821dc..e41fda6f 100644 --- a/rows.go +++ b/rows.go @@ -113,7 +113,7 @@ func (rows *mysqlRows) Close() (err error) { // Remove unread packets from stream if !rows.rs.done { - err = mc.skipResultSetRows() + err = mc.skipRows() } if err == nil { handleOk := mc.clearResult() @@ -143,7 +143,7 @@ func (rows *mysqlRows) nextResultSet() (int, error) { // Remove unread packets from stream if !rows.rs.done { - if err := rows.mc.skipResultSetRows(); err != nil { + if err := rows.mc.skipRows(); err != nil { return 0, err } rows.rs.done = true diff --git a/statement.go b/statement.go index 7c63f1ed..0a2f0c42 100644 --- a/statement.go +++ b/statement.go @@ -77,7 +77,7 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { } // Rows - if err = mc.skipResultSetRows(); err != nil { + if err = mc.skipRows(); err != nil { return nil, err } } From 90db6835b395592ac009a1d201956268ea394854 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 25 Apr 2025 20:15:23 +0900 Subject: [PATCH 4/6] fix review points --- packets.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/packets.go b/packets.go index da755aea..1319f9e6 100644 --- a/packets.go +++ b/packets.go @@ -363,13 +363,11 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string } // MariaDB Extended Capabilities binary.LittleEndian.PutUint32(data[13+19:], uint32(mc.extCapabilities)) - pos += 4 } else { for ; pos < 13+23; pos++ { data[pos] = 0 } } - // assert len(data) == pos // SSL Connection Request Packet // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_ssl_request.html @@ -707,7 +705,7 @@ func (mc *okHandler) handleOkPacket(data []byte) error { func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { columns := make([]mysqlField, count) - for i := 0; i < count; i++ { + for i := range count { data, err := mc.readPacket() if err != nil { return nil, err @@ -905,6 +903,7 @@ func (mc *mysqlConn) skipPackets(n int) error { return nil } +// skips EOF packet after n * ColumnDefinition packets when clientDeprecateEOF is not set func (mc *mysqlConn) skipEof() error { if mc.capabilities&clientDeprecateEOF == 0 { if _, err := mc.readPacket(); err != nil { From ccdef72e2bccc07a638ed243b56377a8ce8a0d99 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 25 Apr 2025 20:46:49 +0900 Subject: [PATCH 5/6] stmt.Exec() should cache metadata --- statement.go | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/statement.go b/statement.go index 0a2f0c42..0f6c65a3 100644 --- a/statement.go +++ b/statement.go @@ -65,15 +65,22 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { handleOk := stmt.mc.clearResult() // Read Result - resLen, _, err := handleOk.readResultSetHeaderPacket() + resLen, metadataFollows, err := handleOk.readResultSetHeaderPacket() if err != nil { return nil, err } if resLen > 0 { // Columns - if err = mc.skipColumns(resLen); err != nil { - return nil, err + if metadataFollows && stmt.mc.extCapabilities&clientCacheMetadata != 0 { + // we can not skip column metadata because next stmt.Query() may use it. + if stmt.columns, err = mc.readColumns(resLen); err != nil { + return nil, err + } + } else { + if err = mc.skipColumns(resLen); err != nil { + return nil, err + } } // Rows From 1ef9cdec11c8d720fade5c2c6bd105c067f64f94 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 25 Apr 2025 20:51:46 +0900 Subject: [PATCH 6/6] modernize BenchmarkReceiveMetadata --- benchmark_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index e9d26a17..b246f4ac 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -484,8 +484,8 @@ func BenchmarkReceiveMetadata(b *testing.B) { db.SetMaxIdleConns(1) // Create a slice to scan all columns - values := make([]interface{}, 1000) - valuePtrs := make([]interface{}, 1000) + values := make([]any, 1000) + valuePtrs := make([]any, 1000) for j := range values { valuePtrs[j] = &values[j] } @@ -498,7 +498,7 @@ func BenchmarkReceiveMetadata(b *testing.B) { defer stmt.Close() // Benchmark metadata retrieval - for i := 0; i < b.N; i++ { + for range b.N { rows := tb.checkRows(stmt.Query()) rows.Next()