Read optimization #1697

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 3 commits into master (changes from 2 commits shown)
1 change: 1 addition & 0 deletions AUTHORS
@@ -37,6 +37,7 @@ Daniel Montoya <dsmontoyam at gmail.com>
Daniel Nichter <nil at codenode.com>
Daniël van Eeden <git at myname.nl>
Dave Protasowski <dprotaso at gmail.com>
Diego Dupin <diego.dupin at gmail.com>
Dirkjan Bussink <d.bussink at gmail.com>
DisposaBoy <disposaboy at dby.me>
Egor Smolyakov <egorsmkv at gmail.com>
41 changes: 41 additions & 0 deletions benchmark_test.go
@@ -113,6 +113,47 @@ func benchmarkQueryHelper(b *testing.B, compr bool) {
}
}

func BenchmarkSelect10000rows(b *testing.B) {
db := initDB(b, false)
defer db.Close()

// Check if we're using MariaDB
var version string
err := db.QueryRow("SELECT @@version").Scan(&version)
if err != nil {
b.Fatalf("Failed to get server version: %v", err)
}

if !strings.Contains(strings.ToLower(version), "mariadb") {
b.Skip("Skipping benchmark as it requires MariaDB sequence table")
return
}

stmt, err := db.Prepare("SELECT * FROM seq_1_to_10000")
if err != nil {
b.Fatalf("Failed to prepare statement: %v", err)
}
defer stmt.Close()

// Exclude connection setup and statement preparation from the measurement.
b.ResetTimer()
for n := 0; n < b.N; n++ {
rows, err := stmt.Query()
if err != nil {
b.Fatalf("Failed to query 10000rows: %v", err)
}

var id int64
for rows.Next() {
err = rows.Scan(&id)
if err != nil {
rows.Close()
b.Fatalf("Failed to scan row: %v", err)
}
}
rows.Close()
}
b.StopTimer()
}

func BenchmarkExec(b *testing.B) {
tb := (*TB)(b)
b.StopTimer()
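For context on the new benchmark: seq_1_to_10000 is one of the virtual tables provided by MariaDB's SEQUENCE storage engine, which is why the benchmark skips on servers that do not report themselves as MariaDB. A minimal standalone sketch of querying such a table through this driver; the DSN and database name are illustrative, not part of the change:

package main

import (
"database/sql"
"fmt"
"log"

_ "github.com/go-sql-driver/mysql"
)

func main() {
// The DSN is illustrative; adjust user, password, address and database.
db, err := sql.Open("mysql", "user:password@tcp(127.0.0.1:3306)/test")
if err != nil {
log.Fatal(err)
}
defer db.Close()

// seq_1_to_10000 is a virtual table from MariaDB's SEQUENCE engine; it
// yields the integers 1..10000 without anything having to be created.
rows, err := db.Query("SELECT seq FROM seq_1_to_10000 LIMIT 3")
if err != nil {
log.Fatal(err)
}
defer rows.Close()

for rows.Next() {
var seq int64
if err := rows.Scan(&seq); err != nil {
log.Fatal(err)
}
fmt.Println(seq)
}
}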
2 changes: 2 additions & 0 deletions compress_test.go
@@ -40,6 +40,8 @@ func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte) []by
conn := new(mockConn)
conn.data = compressedPacket
mc.netConn = conn
mc.readNextFunc = mc.compIO.readNext
mc.readFunc = conn.Read

uncompressedPacket, err := mc.readPacket()
if err != nil {
14 changes: 3 additions & 11 deletions connection.go
@@ -39,6 +39,8 @@ type mysqlConn struct {
compressSequence uint8
parseTime bool
compress bool
readFunc func([]byte) (int, error)
readNextFunc func(int, readerFunc) ([]byte, error)

// for context support (Go 1.8+)
watching bool
@@ -64,16 +66,6 @@ func (mc *mysqlConn) log(v ...any) {
mc.cfg.Logger.Print(v...)
}

func (mc *mysqlConn) readWithTimeout(b []byte) (int, error) {
to := mc.cfg.ReadTimeout
if to > 0 {
if err := mc.netConn.SetReadDeadline(time.Now().Add(to)); err != nil {
return 0, err
}
}
return mc.netConn.Read(b)
}

func (mc *mysqlConn) writeWithTimeout(b []byte) (int, error) {
to := mc.cfg.WriteTimeout
if to > 0 {
@@ -247,7 +239,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
// can not take the buffer. Something must be wrong with the connection
mc.cleanup()
// interpolateParams would be called before sending any query.
// So its safe to retry.
// So it's safe to retry.
return "", driver.ErrBadConn
}
buf = buf[:0]
7 changes: 6 additions & 1 deletion connection_test.go
@@ -18,12 +18,17 @@ import (
)

func TestInterpolateParams(t *testing.T) {
buf := newBuffer()
nc := &net.TCPConn{}
mc := &mysqlConn{
buf: newBuffer(),
buf: buf,
netConn: nc,
maxAllowedPacket: maxPacketSize,
cfg: &Config{
InterpolateParams: true,
},
readNextFunc: buf.readNext,
readFunc: nc.Read,
}

q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"})
18 changes: 18 additions & 0 deletions connector.go
@@ -16,6 +16,7 @@ import (
"os"
"strconv"
"strings"
"time"
)

type connector struct {
@@ -130,6 +131,22 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {

mc.buf = newBuffer()

// setting readNext/read functions
mc.readNextFunc = mc.buf.readNext

// Initialize read function based on configuration
if mc.cfg.ReadTimeout > 0 {
mc.readFunc = func(b []byte) (int, error) {
deadline := time.Now().Add(mc.cfg.ReadTimeout)
if err := mc.netConn.SetReadDeadline(deadline); err != nil {
return 0, err
}
return mc.netConn.Read(b)
}
} else {
mc.readFunc = mc.netConn.Read
}

// Reading Handshake Initialization Packet
authData, plugin, err := mc.readHandshakePacket()
if err != nil {
@@ -170,6 +187,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
if mc.cfg.compress && mc.flags&clientCompress == clientCompress {
mc.compress = true
mc.compIO = newCompIO(mc)
mc.readNextFunc = mc.compIO.readNext
}
if mc.cfg.MaxAllowedPacket > 0 {
mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
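The block above replaces the old readWithTimeout method (deleted from connection.go) with a read function chosen once per connection. A minimal standalone sketch of that pattern, using illustrative names rather than the driver's internals:

package main

import (
"net"
"time"
)

// pickReadFunc mirrors the selection done in Connect above: when no read
// timeout is configured, the connection's Read is used directly; otherwise a
// closure arms a read deadline before every Read. The branch on the timeout
// is taken once per connection instead of on every packet read.
func pickReadFunc(conn net.Conn, readTimeout time.Duration) func([]byte) (int, error) {
if readTimeout <= 0 {
return conn.Read
}
return func(b []byte) (int, error) {
if err := conn.SetReadDeadline(time.Now().Add(readTimeout)); err != nil {
return 0, err
}
return conn.Read(b)
}
}

func main() {
// Usage sketch: dial a local server (address is illustrative) and read
// with a fresh 5 second deadline armed before each read.
c, err := net.Dial("tcp", "127.0.0.1:3306")
if err != nil {
return
}
defer c.Close()

read := pickReadFunc(c, 5*time.Second)
buf := make([]byte, 4)
_, _ = read(buf)
}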
52 changes: 47 additions & 5 deletions driver_test.go
@@ -1630,13 +1630,46 @@ func TestCollation(t *testing.T) {
}

runTests(t, tdsn, func(dbt *DBTest) {
// see https://mariadb.com/kb/en/setting-character-sets-and-collations/#changing-default-collation
// When character_set_collations is set for a charset, it overrides that charset's default collation,
// so check whether the collation we expect has been overridden.
forceExpected := expected
var defaultCollations string
err := dbt.db.QueryRow("SELECT @@character_set_collations").Scan(&defaultCollations)
if err == nil {
// Query succeeded, need to check if we should override expected collation
collationMap := make(map[string]string)
pairs := strings.Split(defaultCollations, ",")
for _, pair := range pairs {
parts := strings.Split(pair, "=")
if len(parts) == 2 {
collationMap[parts[0]] = parts[1]
}
}

// Get charset prefix from expected collation
parts := strings.Split(expected, "_")
if len(parts) > 0 {
charset := parts[0]
if newCollation, ok := collationMap[charset]; ok {
forceExpected = newCollation
}
}
}

var got string
if err := dbt.db.QueryRow("SELECT @@collation_connection").Scan(&got); err != nil {
dbt.Fatal(err)
}

if got != expected {
dbt.Fatalf("expected connection collation %s but got %s", expected, got)
if forceExpected != expected {
if got != forceExpected {
dbt.Fatalf("expected forced connection collation %s but got %s", forceExpected, got)
}
} else {
dbt.Fatalf("expected connection collation %s but got %s", expected, got)
}
}
})
}
@@ -1685,16 +1718,16 @@ func TestRawBytesResultExceedsBuffer(t *testing.T) {
}

func TestTimezoneConversion(t *testing.T) {
zones := []string{"UTC", "US/Central", "US/Pacific", "Local"}
zones := []string{"UTC", "America/New_York", "Asia/Hong_Kong", "Local"}

// Regression test for timezone handling
tzTest := func(dbt *DBTest) {
// Create table
dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)")

// Insert local time into database (should be converted)
usCentral, _ := time.LoadLocation("US/Central")
reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, time.UTC).In(usCentral)
newYorkTz, _ := time.LoadLocation("America/New_York")
reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, time.UTC).In(newYorkTz)
dbt.mustExec("INSERT INTO test VALUE (?)", reftime)

// Retrieve time from DB
@@ -1713,7 +1746,7 @@ func TestTimezoneConversion(t *testing.T) {
// Check that dates match
if reftime.Unix() != dbTime.Unix() {
dbt.Errorf("times do not match.\n")
dbt.Errorf(" Now(%v)=%v\n", usCentral, reftime)
dbt.Errorf(" Now(%v)=%v\n", newYorkTz, reftime)
dbt.Errorf(" Now(UTC)=%v\n", dbTime)
}
}
@@ -3541,6 +3574,15 @@ func TestConnectionAttributes(t *testing.T) {

dbt := &DBTest{t, db}

var varName string
var varValue string
err := dbt.db.QueryRow("SHOW VARIABLES LIKE 'performance_schema'").Scan(&varName, &varValue)
if err != nil {
t.Fatalf("error: %s", err.Error())
}
if varValue != "ON" {
t.Skip("performance_schema is not enabled")
}
queryString := "SELECT ATTR_NAME, ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID()"
rows := dbt.mustQuery(queryString)
defer rows.Close()
10 changes: 3 additions & 7 deletions packets.go
@@ -30,14 +30,9 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
var prevData []byte
invalidSequence := false

readNext := mc.buf.readNext
if mc.compress {
readNext = mc.compIO.readNext
}

for {
// read packet header
data, err := readNext(4, mc.readWithTimeout)
data, err := mc.readNextFunc(4, mc.readFunc)
if err != nil {
mc.close()
if cerr := mc.canceled.Value(); cerr != nil {
@@ -85,7 +80,7 @@
}

// read packet body [pktLen bytes]
data, err = readNext(pktLen, mc.readWithTimeout)
data, err = mc.readNextFunc(pktLen, mc.readFunc)
if err != nil {
mc.close()
if cerr := mc.canceled.Value(); cerr != nil {
@@ -369,6 +364,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
return err
}
mc.netConn = tlsConn
// The ReadTimeout closure set in Connect reads mc.netConn at call time, so it
// keeps working after the TLS upgrade; only the plain method value, which was
// bound to the pre-TLS connection, needs to be rebound here.
if mc.cfg.ReadTimeout <= 0 {
mc.readFunc = mc.netConn.Read
}
}

// User [null terminated string]
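With this change, readPacket dispatches through the two function fields instead of branching on mc.compress and wrapping the timeout on every call. A rough standalone sketch of that indirection; the type and field names here are illustrative, not the driver's API:

package main

import "fmt"

// readerFunc mirrors the shape of the connection-level read function.
type readerFunc func([]byte) (int, error)

// conn carries the two hooks chosen once at connect time: how to obtain the
// next n bytes of a packet (plain buffer vs. compressed stream) and how to
// read raw bytes from the wire (with or without a read deadline).
type conn struct {
readNextFunc func(n int, read readerFunc) ([]byte, error)
readFunc     readerFunc
}

// readHeader shows the call shape readPacket uses: no per-call branching on
// compression or timeouts.
func (c *conn) readHeader() ([]byte, error) {
return c.readNextFunc(4, c.readFunc)
}

func main() {
src := []byte{0x01, 0x00, 0x00, 0x00, 0xff}
c := &conn{
// A trivial readNext that fills a fresh slice via the supplied reader.
readNextFunc: func(n int, read readerFunc) ([]byte, error) {
buf := make([]byte, n)
if _, err := read(buf); err != nil {
return nil, err
}
return buf, nil
},
// A reader backed by the fixed byte slice above.
readFunc: func(b []byte) (int, error) {
n := copy(b, src)
src = src[n:]
return n, nil
},
}
header, err := c.readHeader()
if err != nil {
panic(err)
}
fmt.Printf("header: % x\n", header) // header: 01 00 00 00
}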
47 changes: 31 additions & 16 deletions packets_test.go
@@ -97,24 +97,30 @@ var _ net.Conn = new(mockConn)
func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) {
conn := new(mockConn)
connector := newConnector(NewConfig())
buf := newBuffer()
mc := &mysqlConn{
buf: newBuffer(),
buf: buf,
cfg: connector.cfg,
connector: connector,
netConn: conn,
closech: make(chan struct{}),
maxAllowedPacket: defaultMaxAllowedPacket,
sequence: sequence,
readNextFunc: buf.readNext,
readFunc: conn.Read,
}
return conn, mc
}

func TestReadPacketSingleByte(t *testing.T) {
conn := new(mockConn)
buf := newBuffer()
mc := &mysqlConn{
netConn: conn,
buf: newBuffer(),
cfg: NewConfig(),
netConn: conn,
buf: buf,
cfg: NewConfig(),
readNextFunc: buf.readNext,
readFunc: conn.Read,
}

conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
@@ -165,10 +171,13 @@ func TestReadPacketWrongSequenceID(t *testing.T) {

func TestReadPacketSplit(t *testing.T) {
conn := new(mockConn)
buf := newBuffer()
mc := &mysqlConn{
netConn: conn,
buf: newBuffer(),
cfg: NewConfig(),
netConn: conn,
buf: buf,
cfg: NewConfig(),
readNextFunc: buf.readNext,
readFunc: conn.Read,
}

data := make([]byte, maxPacketSize*2+4*3)
@@ -272,11 +281,14 @@

func TestReadPacketFail(t *testing.T) {
conn := new(mockConn)
buf := newBuffer()
mc := &mysqlConn{
netConn: conn,
buf: newBuffer(),
closech: make(chan struct{}),
cfg: NewConfig(),
netConn: conn,
buf: buf,
closech: make(chan struct{}),
cfg: NewConfig(),
readNextFunc: buf.readNext,
readFunc: conn.Read,
}

// illegal empty (stand-alone) packet
@@ -317,12 +329,15 @@
// not-NUL terminated plugin_name in init packet
func TestRegression801(t *testing.T) {
conn := new(mockConn)
buf := newBuffer()
mc := &mysqlConn{
netConn: conn,
buf: newBuffer(),
cfg: new(Config),
sequence: 42,
closech: make(chan struct{}),
netConn: conn,
buf: buf,
cfg: new(Config),
sequence: 42,
closech: make(chan struct{}),
readNextFunc: buf.readNext,
readFunc: conn.Read,
}

conn.data = []byte{72, 0, 0, 42, 10, 53, 46, 53, 46, 56, 0, 165, 0, 0, 0,