Skip to content

[RFC] Placeholder substitution. #297

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 176 additions & 48 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"database/sql/driver"
"errors"
"net"
"strconv"
"strings"
"time"
)
Expand All @@ -26,25 +27,27 @@ type mysqlConn struct {
maxPacketAllowed int
maxWriteSize int
flags clientFlag
status statusFlag
sequence uint8
parseTime bool
strict bool
}

type config struct {
user string
passwd string
net string
addr string
dbname string
params map[string]string
loc *time.Location
tls *tls.Config
timeout time.Duration
collation uint8
allowAllFiles bool
allowOldPasswords bool
clientFoundRows bool
user string
passwd string
net string
addr string
dbname string
params map[string]string
loc *time.Location
tls *tls.Config
timeout time.Duration
collation uint8
allowAllFiles bool
allowOldPasswords bool
clientFoundRows bool
substitutePlaceholder bool
}

// Handles parameters set in DSN after the connection is established
Expand Down Expand Up @@ -161,28 +164,146 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
return stmt, err
}

func (mc *mysqlConn) escapeBytes(v []byte) string {
buf := make([]byte, len(v)*2+2)
buf[0] = '\''
pos := 1
if mc.status&statusNoBackslashEscapes == 0 {
for _, c := range v {
switch c {
case '\x00':
buf[pos] = '\\'
buf[pos+1] = '0'
pos += 2
case '\n':
buf[pos] = '\\'
buf[pos+1] = 'n'
pos += 2
case '\r':
buf[pos] = '\\'
buf[pos+1] = 'r'
pos += 2
case '\x1a':
buf[pos] = '\\'
buf[pos+1] = 'Z'
pos += 2
case '\'':
buf[pos] = '\\'
buf[pos+1] = '\''
pos += 2
case '"':
buf[pos] = '\\'
buf[pos+1] = '"'
pos += 2
case '\\':
buf[pos] = '\\'
buf[pos+1] = '\\'
pos += 2
default:
buf[pos] = c
pos += 1
}
}
} else {
for _, c := range v {
if c == '\'' {
buf[pos] = '\''
buf[pos+1] = '\''
pos += 2
} else {
buf[pos] = c
pos++
}
}
}
buf[pos] = '\''
return string(buf[:pos+1])
}

func (mc *mysqlConn) buildQuery(query string, args []driver.Value) (string, error) {
chunks := strings.Split(query, "?")
if len(chunks) != len(args)+1 {
return "", driver.ErrSkip
}

parts := make([]string, len(chunks)+len(args))
parts[0] = chunks[0]

for i, arg := range args {
pos := i*2 + 1
parts[pos+1] = chunks[i+1]
if arg == nil {
parts[pos] = "NULL"
continue
}
switch v := arg.(type) {
case int64:
parts[pos] = strconv.FormatInt(v, 10)
case float64:
parts[pos] = strconv.FormatFloat(v, 'f', -1, 64)
case bool:
if v {
parts[pos] = "1"
} else {
parts[pos] = "0"
}
case time.Time:
if v.IsZero() {
parts[pos] = "'0000-00-00'"
} else {
fmt := "'2006-01-02 15:04:05.999999'"
parts[pos] = v.In(mc.cfg.loc).Format(fmt)
}
case []byte:
if v == nil {
parts[pos] = "NULL"
} else {
parts[pos] = mc.escapeBytes(v)
}
case string:
parts[pos] = mc.escapeBytes([]byte(v))
default:
return "", driver.ErrSkip
}
}
pktSize := len(query) + 4 // 4 bytes for header.
for _, p := range parts {
pktSize += len(p)
}
if pktSize > mc.maxPacketAllowed {
return "", driver.ErrSkip
}
return strings.Join(parts, ""), nil
}

func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if mc.netConn == nil {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) == 0 { // no args, fastpath
mc.affectedRows = 0
mc.insertId = 0

err := mc.exec(query)
if err == nil {
return &mysqlResult{
affectedRows: int64(mc.affectedRows),
insertId: int64(mc.insertId),
}, err
if len(args) != 0 {
if !mc.cfg.substitutePlaceholder {
return nil, driver.ErrSkip
}
return nil, err
// try client-side prepare to reduce roundtrip
prepared, err := mc.buildQuery(query, args)
if err != nil {
return nil, err
}
query = prepared
args = nil
}
mc.affectedRows = 0
mc.insertId = 0

// with args, must use prepared stmt
return nil, driver.ErrSkip

err := mc.exec(query)
if err == nil {
return &mysqlResult{
affectedRows: int64(mc.affectedRows),
insertId: int64(mc.insertId),
}, err
}
return nil, err
}

// Internal function to execute commands
Expand Down Expand Up @@ -211,31 +332,38 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) == 0 { // no args, fastpath
// Send command
err := mc.writeCommandPacketStr(comQuery, query)
if len(args) != 0 {
if !mc.cfg.substitutePlaceholder {
return nil, driver.ErrSkip
}
// try client-side prepare to reduce roundtrip
prepared, err := mc.buildQuery(query, args)
if err != nil {
return nil, err
}
query = prepared
args = nil
}
// Send command
err := mc.writeCommandPacketStr(comQuery, query)
if err == nil {
// Read Result
var resLen int
resLen, err = mc.readResultSetHeaderPacket()
if err == nil {
// Read Result
var resLen int
resLen, err = mc.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc

if resLen == 0 {
// no columns, no more data
return emptyRows{}, nil
}
// Columns
rows.columns, err = mc.readColumns(resLen)
return rows, err
rows := new(textRows)
rows.mc = mc

if resLen == 0 {
// no columns, no more data
return emptyRows{}, nil
}
// Columns
rows.columns, err = mc.readColumns(resLen)
return rows, err
}
return nil, err
}

// with args, must use prepared stmt
return nil, driver.ErrSkip
return nil, err
}

// Gets the value of the given MySQL System Variable
Expand Down
22 changes: 22 additions & 0 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,25 @@ const (
flagUnknown3
flagUnknown4
)

// http://dev.mysql.com/doc/internals/en/status-flags.html

type statusFlag uint16

const (
statusInTrans statusFlag = 1 << iota
statusInAutocommit
statusUnknown1
statusMoreResultsExists
statusNoGoodIndexUsed
statusNoIndexUsed
statusCursorExists
statusLastRowSent
statusDbDropped
statusNoBackslashEscapes
statusMetadataChanged
statusQueryWasSlow
statusPsOutParams
statusInTransReadonly
statusSessionStateChanged
)
9 changes: 9 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,19 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {

db.Exec("DROP TABLE IF EXISTS test")

dbp, err := sql.Open("mysql", dsn+"&substitutePlaceholder=true")
if err != nil {
t.Fatalf("Error connecting: %s", err.Error())
}
defer dbp.Close()

dbt := &DBTest{t, db}
dbtp := &DBTest{t, dbp}
for _, test := range tests {
test(dbt)
dbt.db.Exec("DROP TABLE IF EXISTS test")
test(dbtp)
dbtp.db.Exec("DROP TABLE IF EXISTS test")
}
}

Expand Down
1 change: 1 addition & 0 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])

// server_status [2 bytes]
mc.status = statusFlag(data[1+n+m]) | statusFlag(data[1+n+m+1])<<8

// warning count [2 bytes]
if !mc.strict {
Expand Down
8 changes: 8 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,14 @@ func parseDSNParams(cfg *config, params string) (err error) {
// cfg params
switch value := param[1]; param[0] {

// Enable client side placeholder substitution
case "substitutePlaceholder":
var isBool bool
cfg.substitutePlaceholder, isBool = readBool(value)
if !isBool {
return fmt.Errorf("Invalid Bool value: %s", value)
}

// Disable INFILE whitelist / enable all files
case "allowAllFiles":
var isBool bool
Expand Down
24 changes: 12 additions & 12 deletions utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,18 @@ var testDSNs = []struct {
out string
loc *time.Location
}{
{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls:<nil> timeout:30000000000 collation:224 allowAllFiles:true allowOldPasswords:true clientFoundRows:true}", time.UTC},
{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.Local},
{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC},
{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC},
{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC},
{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC},
{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls:<nil> timeout:30000000000 collation:224 allowAllFiles:true allowOldPasswords:true clientFoundRows:true substitutePlaceholder:false}", time.UTC},
{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.Local},
{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC},
{"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC},
{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC},
{"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC},
{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC},
{"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false substitutePlaceholder:false}", time.UTC},
}

func TestDSNParser(t *testing.T) {
Expand Down