Skip to content

Commit 1907645

Browse files
committed
MariaDB Metadata skipping
1 parent 0fd55eb commit 1907645

9 files changed

+307
-139
lines changed

auth_test.go

+14-14
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ func TestAuthFastCachingSHA256PasswordCached(t *testing.T) {
8989
if err != nil {
9090
t.Fatal(err)
9191
}
92-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
92+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
9393
if err != nil {
9494
t.Fatal(err)
9595
}
@@ -134,7 +134,7 @@ func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) {
134134
if err != nil {
135135
t.Fatal(err)
136136
}
137-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
137+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
138138
if err != nil {
139139
t.Fatal(err)
140140
}
@@ -176,7 +176,7 @@ func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) {
176176
if err != nil {
177177
t.Fatal(err)
178178
}
179-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
179+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
180180
if err != nil {
181181
t.Fatal(err)
182182
}
@@ -232,7 +232,7 @@ func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) {
232232
if err != nil {
233233
t.Fatal(err)
234234
}
235-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
235+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
236236
if err != nil {
237237
t.Fatal(err)
238238
}
@@ -284,7 +284,7 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) {
284284
if err != nil {
285285
t.Fatal(err)
286286
}
287-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
287+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
288288
if err != nil {
289289
t.Fatal(err)
290290
}
@@ -357,7 +357,7 @@ func TestAuthFastCleartextPassword(t *testing.T) {
357357
if err != nil {
358358
t.Fatal(err)
359359
}
360-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
360+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
361361
if err != nil {
362362
t.Fatal(err)
363363
}
@@ -400,7 +400,7 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) {
400400
if err != nil {
401401
t.Fatal(err)
402402
}
403-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
403+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
404404
if err != nil {
405405
t.Fatal(err)
406406
}
@@ -459,7 +459,7 @@ func TestAuthFastNativePassword(t *testing.T) {
459459
if err != nil {
460460
t.Fatal(err)
461461
}
462-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
462+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
463463
if err != nil {
464464
t.Fatal(err)
465465
}
@@ -502,7 +502,7 @@ func TestAuthFastNativePasswordEmpty(t *testing.T) {
502502
if err != nil {
503503
t.Fatal(err)
504504
}
505-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
505+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
506506
if err != nil {
507507
t.Fatal(err)
508508
}
@@ -544,7 +544,7 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) {
544544
if err != nil {
545545
t.Fatal(err)
546546
}
547-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
547+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
548548
if err != nil {
549549
t.Fatal(err)
550550
}
@@ -592,7 +592,7 @@ func TestAuthFastSHA256PasswordRSA(t *testing.T) {
592592
if err != nil {
593593
t.Fatal(err)
594594
}
595-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
595+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
596596
if err != nil {
597597
t.Fatal(err)
598598
}
@@ -641,7 +641,7 @@ func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) {
641641
if err != nil {
642642
t.Fatal(err)
643643
}
644-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
644+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
645645
if err != nil {
646646
t.Fatal(err)
647647
}
@@ -678,7 +678,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) {
678678
// unset TLS config to prevent the actual establishment of a TLS wrapper
679679
mc.cfg.TLS = nil
680680

681-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
681+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
682682
if err != nil {
683683
t.Fatal(err)
684684
}
@@ -1343,7 +1343,7 @@ func TestEd25519Auth(t *testing.T) {
13431343
if err != nil {
13441344
t.Fatal(err)
13451345
}
1346-
err = mc.writeHandshakeResponsePacket(authResp, plugin)
1346+
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
13471347
if err != nil {
13481348
t.Fatal(err)
13491349
}

benchmark_test.go

+57-2
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ func BenchmarkExec(b *testing.B) {
129129
b.ReportAllocs()
130130
b.ResetTimer()
131131

132-
for i := 0; i < concurrencyLevel; i++ {
132+
for i := 0; i < concurrencyLevel; i++ {
133133
go func() {
134134
for {
135135
if atomic.AddInt64(&remain, -1) < 0 {
@@ -400,7 +400,7 @@ func benchmark10kRows(b *testing.B, compress bool) {
400400
}
401401

402402
args := make([]any, 200)
403-
for i := 1; i < 200; i+=2 {
403+
for i := 1; i < 200; i += 2 {
404404
args[i] = sval
405405
}
406406
for i := 0; i < 10000; i += 100 {
@@ -455,3 +455,58 @@ func BenchmarkReceive10kRows(b *testing.B) {
455455
func BenchmarkReceive10kRowsCompressed(b *testing.B) {
456456
benchmark10kRows(b, true)
457457
}
458+
459+
// BenchmarkReceiveMetadata measures performance of receiving lots of metadata compare to data in rows
460+
func BenchmarkReceiveMetadata(b *testing.B) {
461+
tb := (*TB)(b)
462+
463+
// Create a table with 1000 integer fields
464+
createTableQuery := "CREATE TABLE large_integer_table ("
465+
for i := 0; i < 1000; i++ {
466+
createTableQuery += fmt.Sprintf("col_%d INT", i)
467+
if i < 999 {
468+
createTableQuery += ", "
469+
}
470+
}
471+
createTableQuery += ")"
472+
473+
// Initialize database
474+
db := initDB(b, false,
475+
"DROP TABLE IF EXISTS large_integer_table",
476+
createTableQuery,
477+
"INSERT INTO large_integer_table VALUES ("+
478+
strings.Repeat("0,", 999)+"0)", // Insert a row of zeros
479+
)
480+
defer db.Close()
481+
482+
b.Run("query", func(b *testing.B) {
483+
db.SetMaxIdleConns(0)
484+
db.SetMaxIdleConns(1)
485+
486+
// Create a slice to scan all columns
487+
values := make([]interface{}, 1000)
488+
valuePtrs := make([]interface{}, 1000)
489+
for j := range values {
490+
valuePtrs[j] = &values[j]
491+
}
492+
493+
b.ReportAllocs()
494+
b.ResetTimer()
495+
496+
// Prepare a SELECT query to retrieve metadata
497+
stmt := tb.checkStmt(db.Prepare("SELECT * FROM large_integer_table LIMIT 1"))
498+
defer stmt.Close()
499+
500+
// Benchmark metadata retrieval
501+
for i := 0; i < b.N; i++ {
502+
rows := tb.checkRows(stmt.Query())
503+
504+
rows.Next()
505+
// Scan the row
506+
err := rows.Scan(valuePtrs...)
507+
tb.check(err)
508+
509+
rows.Close()
510+
}
511+
})
512+
}

connection.go

+33-24
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,22 @@ import (
2424
)
2525

2626
type mysqlConn struct {
27-
buf buffer
28-
netConn net.Conn
29-
rawConn net.Conn // underlying connection when netConn is TLS connection.
30-
result mysqlResult // managed by clearResult() and handleOkPacket().
31-
compIO *compIO
32-
cfg *Config
33-
connector *connector
34-
maxAllowedPacket int
35-
maxWriteSize int
36-
flags clientFlag
37-
status statusFlag
38-
sequence uint8
39-
compressSequence uint8
40-
parseTime bool
41-
compress bool
27+
buf buffer
28+
netConn net.Conn
29+
rawConn net.Conn // underlying connection when netConn is TLS connection.
30+
result mysqlResult // managed by clearResult() and handleOkPacket().
31+
compIO *compIO
32+
cfg *Config
33+
connector *connector
34+
maxAllowedPacket int
35+
maxWriteSize int
36+
clientCapabilities capabilityFlag
37+
clientExtCapabilities extendedCapabilityFlag
38+
status statusFlag
39+
sequence uint8
40+
compressSequence uint8
41+
parseTime bool
42+
compress bool
4243

4344
// for context support (Go 1.8+)
4445
watching bool
@@ -223,13 +224,21 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
223224
columnCount, err := stmt.readPrepareResultPacket()
224225
if err == nil {
225226
if stmt.paramCount > 0 {
226-
if err = mc.readUntilEOF(); err != nil {
227+
if err = mc.skipColumns(stmt.paramCount); err != nil {
227228
return nil, err
228229
}
229230
}
230231

231232
if columnCount > 0 {
232-
err = mc.readUntilEOF()
233+
if mc.clientExtCapabilities&clientCacheMetadata != 0 {
234+
if stmt.columns, err = mc.readColumns(int(columnCount)); err != nil {
235+
return nil, err
236+
}
237+
} else {
238+
if err = mc.skipColumns(int(columnCount)); err != nil {
239+
return nil, err
240+
}
241+
}
233242
}
234243
}
235244

@@ -370,19 +379,19 @@ func (mc *mysqlConn) exec(query string) error {
370379
}
371380

372381
// Read Result
373-
resLen, err := handleOk.readResultSetHeaderPacket()
382+
resLen, _, err := handleOk.readResultSetHeaderPacket()
374383
if err != nil {
375384
return err
376385
}
377386

378387
if resLen > 0 {
379388
// columns
380-
if err := mc.readUntilEOF(); err != nil {
389+
if err := mc.skipColumns(resLen); err != nil {
381390
return err
382391
}
383392

384393
// rows
385-
if err := mc.readUntilEOF(); err != nil {
394+
if err := mc.skipResultSetRows(); err != nil {
386395
return err
387396
}
388397
}
@@ -419,7 +428,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
419428

420429
// Read Result
421430
var resLen int
422-
resLen, err = handleOk.readResultSetHeaderPacket()
431+
resLen, _, err = handleOk.readResultSetHeaderPacket()
423432
if err != nil {
424433
return nil, err
425434
}
@@ -453,22 +462,22 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
453462
}
454463

455464
// Read Result
456-
resLen, err := handleOk.readResultSetHeaderPacket()
465+
resLen, _, err := handleOk.readResultSetHeaderPacket()
457466
if err == nil {
458467
rows := new(textRows)
459468
rows.mc = mc
460469
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
461470

462471
if resLen > 0 {
463472
// Columns
464-
if err := mc.readUntilEOF(); err != nil {
473+
if err := mc.skipColumns(resLen); err != nil {
465474
return nil, err
466475
}
467476
}
468477

469478
dest := make([]driver.Value, resLen)
470479
if err = rows.readRow(dest); err == nil {
471-
return dest[0].([]byte), mc.readUntilEOF()
480+
return dest[0].([]byte), mc.skipResultSetRows()
472481
}
473482
}
474483
return nil, err

connector.go

+7-3
Original file line numberDiff line numberDiff line change
@@ -131,12 +131,16 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
131131
mc.buf = newBuffer()
132132

133133
// Reading Handshake Initialization Packet
134-
authData, plugin, err := mc.readHandshakePacket()
134+
authData, serverCapabilities, serverExtendedCapabilities, plugin, err := mc.readHandshakePacket()
135135
if err != nil {
136136
mc.cleanup()
137137
return nil, err
138138
}
139139

140+
if mc.cfg.TLS != nil && serverCapabilities&clientSSL == 0 {
141+
return nil, fmt.Errorf("TLS is required, but server doesn't support it")
142+
}
143+
140144
if plugin == "" {
141145
plugin = defaultAuthPlugin
142146
}
@@ -153,7 +157,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
153157
return nil, err
154158
}
155159
}
156-
if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
160+
if err = mc.writeHandshakeResponsePacket(authResp, serverCapabilities, serverExtendedCapabilities, plugin); err != nil {
157161
mc.cleanup()
158162
return nil, err
159163
}
@@ -167,7 +171,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
167171
return nil, err
168172
}
169173

170-
if mc.cfg.compress && mc.flags&clientCompress == clientCompress {
174+
if mc.cfg.compress && mc.clientCapabilities&clientCompress > 0 {
171175
mc.compress = true
172176
mc.compIO = newCompIO(mc)
173177
}

const.go

+14-2
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ const (
4343
)
4444

4545
// https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags
46-
type clientFlag uint32
46+
type capabilityFlag uint32
4747

4848
const (
49-
clientLongPassword clientFlag = 1 << iota
49+
clientMySQL capabilityFlag = 1 << iota
5050
clientFoundRows
5151
clientLongFlag
5252
clientConnectWithDB
@@ -73,6 +73,18 @@ const (
7373
clientDeprecateEOF
7474
)
7575

76+
// https://mariadb.com/kb/en/connection/#capabilities
77+
type extendedCapabilityFlag uint32
78+
79+
const (
80+
progressIndicator extendedCapabilityFlag = 1 << iota
81+
clientComMulti
82+
clientStmtBulkOperations
83+
clientExtendedMetadata
84+
clientCacheMetadata
85+
clientUnitBulkResult
86+
)
87+
7688
const (
7789
comQuit byte = iota + 1
7890
comInitDB

0 commit comments

Comments
 (0)