diff --git a/AUTHORS b/AUTHORS index 37bf11a2b..466ac86ab 100644 --- a/AUTHORS +++ b/AUTHORS @@ -45,6 +45,7 @@ Stan Putrya Stanley Gunawan Xiaobing Jiang Xiuming Chen +Zhenye Xie # Organizations diff --git a/README.md b/README.md index 388632816..c35ba2e92 100644 --- a/README.md +++ b/README.md @@ -299,6 +299,15 @@ Default: 0 I/O write timeout. The value must be a decimal number with an unit suffix ( *"ms"*, *"s"*, *"m"*, *"h"* ), such as *"30s"*, *"0.5m"* or *"1m30s"*. +##### `maxAllowedPacket` +``` +Type: decimal number +Default: 0 +``` + +Max packet size allowed in bytes. Use `maxAllowedPacket=0` to automatically fetch the `max_allowed_packet` variable from server. + + ##### System Variables All other parameters are interpreted as system variables: diff --git a/benchmark_test.go b/benchmark_test.go index 8f721139b..7da833a2a 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -220,7 +220,7 @@ func BenchmarkInterpolation(b *testing.B) { InterpolateParams: true, Loc: time.UTC, }, - maxPacketAllowed: maxPacketSize, + maxAllowedPacket: maxPacketSize, maxWriteSize: maxPacketSize - 1, buf: newBuffer(nil), } diff --git a/connection.go b/connection.go index d37e36dea..d82c728f3 100644 --- a/connection.go +++ b/connection.go @@ -22,7 +22,7 @@ type mysqlConn struct { affectedRows uint64 insertId uint64 cfg *Config - maxPacketAllowed int + maxAllowedPacket int maxWriteSize int writeTimeout time.Duration flags clientFlag @@ -246,7 +246,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin return "", driver.ErrSkip } - if len(buf)+4 > mc.maxPacketAllowed { + if len(buf)+4 > mc.maxAllowedPacket { return "", driver.ErrSkip } } diff --git a/connection_test.go b/connection_test.go index 7111e4a6b..65325f101 100644 --- a/connection_test.go +++ b/connection_test.go @@ -16,7 +16,7 @@ import ( func TestInterpolateParams(t *testing.T) { mc := &mysqlConn{ buf: newBuffer(nil), - maxPacketAllowed: maxPacketSize, + maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, }, @@ -36,7 +36,7 @@ func TestInterpolateParams(t *testing.T) { func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { mc := &mysqlConn{ buf: newBuffer(nil), - maxPacketAllowed: maxPacketSize, + maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, }, @@ -53,7 +53,7 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { func TestInterpolateParamsPlaceholderInString(t *testing.T) { mc := &mysqlConn{ buf: newBuffer(nil), - maxPacketAllowed: maxPacketSize, + maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, }, diff --git a/driver.go b/driver.go index 899f955fb..562ddeffb 100644 --- a/driver.go +++ b/driver.go @@ -50,7 +50,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { // New mysqlConn mc := &mysqlConn{ - maxPacketAllowed: maxPacketSize, + maxAllowedPacket: maxPacketSize, maxWriteSize: maxPacketSize - 1, } mc.cfg, err = ParseDSN(dsn) @@ -109,15 +109,19 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { return nil, err } - // Get max allowed packet size - maxap, err := mc.getSystemVar("max_allowed_packet") - if err != nil { - mc.Close() - return nil, err + if mc.cfg.MaxAllowedPacket > 0 { + mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket + } else { + // Get max allowed packet size + maxap, err := mc.getSystemVar("max_allowed_packet") + if err != nil { + mc.Close() + return nil, err + } + mc.maxAllowedPacket = stringToInt(maxap) - 1 } - mc.maxPacketAllowed = stringToInt(maxap) - 1 - if mc.maxPacketAllowed < maxPacketSize { - mc.maxWriteSize = mc.maxPacketAllowed + if mc.maxAllowedPacket < maxPacketSize { + mc.maxWriteSize = mc.maxAllowedPacket } // Handle DSN Params diff --git a/dsn.go b/dsn.go index 73138bc57..896be9ef5 100644 --- a/dsn.go +++ b/dsn.go @@ -15,6 +15,7 @@ import ( "fmt" "net" "net/url" + "strconv" "strings" "time" ) @@ -28,19 +29,20 @@ var ( // Config is a configuration parsed from a DSN string type Config struct { - User string // Username - Passwd string // Password (requires User) - Net string // Network type - Addr string // Network address (requires Net) - DBName string // Database name - Params map[string]string // Connection parameters - Collation string // Connection collation - Loc *time.Location // Location for time.Time values - TLSConfig string // TLS configuration name - tls *tls.Config // TLS configuration - Timeout time.Duration // Dial timeout - ReadTimeout time.Duration // I/O read timeout - WriteTimeout time.Duration // I/O write timeout + User string // Username + Passwd string // Password (requires User) + Net string // Network type + Addr string // Network address (requires Net) + DBName string // Database name + Params map[string]string // Connection parameters + Collation string // Connection collation + Loc *time.Location // Location for time.Time values + MaxAllowedPacket int // Max packet size allowed + TLSConfig string // TLS configuration name + tls *tls.Config // TLS configuration + Timeout time.Duration // Dial timeout + ReadTimeout time.Duration // I/O read timeout + WriteTimeout time.Duration // I/O write timeout AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE AllowCleartextPasswords bool // Allows the cleartext client side plugin @@ -222,6 +224,17 @@ func (cfg *Config) FormatDSN() string { buf.WriteString(cfg.WriteTimeout.String()) } + if cfg.MaxAllowedPacket > 0 { + if hasParam { + buf.WriteString("&maxAllowedPacket=") + } else { + hasParam = true + buf.WriteString("?maxAllowedPacket=") + } + buf.WriteString(strconv.Itoa(cfg.MaxAllowedPacket)) + + } + // other params if cfg.Params != nil { for param, value := range cfg.Params { @@ -496,7 +509,11 @@ func parseDSNParams(cfg *Config, params string) (err error) { if err != nil { return } - + case "maxAllowedPacket": + cfg.MaxAllowedPacket, err = strconv.Atoi(value) + if err != nil { + return + } default: // lazy init if cfg.Params == nil { diff --git a/dsn_test.go b/dsn_test.go index e6f0f83b1..0693192ad 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -39,8 +39,8 @@ var testDSNs = []struct { "user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, TLSConfig: "skip-verify"}, }, { - "user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, ClientFoundRows: true}, + "user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, ClientFoundRows: true, MaxAllowedPacket: 16777216}, }, { "user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", &Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.Local}, diff --git a/packets.go b/packets.go index 602539942..f06752b02 100644 --- a/packets.go +++ b/packets.go @@ -80,7 +80,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { func (mc *mysqlConn) writePacket(data []byte) error { pktLen := len(data) - 4 - if pktLen > mc.maxPacketAllowed { + if pktLen > mc.maxAllowedPacket { return ErrPktTooLarge } @@ -786,7 +786,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { // http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { - maxLen := stmt.mc.maxPacketAllowed - 1 + maxLen := stmt.mc.maxAllowedPacket - 1 pktLen := maxLen // After the header (bytes 0-3) follows before the data: @@ -977,7 +977,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { paramTypes[i+i] = fieldTypeString paramTypes[i+i+1] = 0x00 - if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 { + if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { paramValues = appendLengthEncodedInteger(paramValues, uint64(len(v)), ) @@ -999,7 +999,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { paramTypes[i+i] = fieldTypeString paramTypes[i+i+1] = 0x00 - if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 { + if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { paramValues = appendLengthEncodedInteger(paramValues, uint64(len(v)), )