Skip to content

Placeholder interpolation #309

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 38 commits into from
Feb 14, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
e35fa00
Implement placeholder substitution.
methane Dec 31, 2014
c8c9bb1
Query() uses client-side placeholder substitution.
methane Dec 31, 2014
cac6129
Don't send text query larger than maxPacketAllowed
methane Jan 20, 2015
b7c2c47
Add substitutePlaceholder option to DSN
methane Jan 20, 2015
f3b82fd
Merge remote-tracking branch 'upstream/pr/297'
arvenil Jan 31, 2015
058ce87
Move escape funcs to utils.go, export them, add references to mysql s…
arvenil Feb 1, 2015
42956fa
Add tests for escaping functions
arvenil Feb 1, 2015
e6bf23a
Add basic SQL injection tests, including NO_BACKSLASH_ESCAPES sql_mode
arvenil Feb 1, 2015
b473259
Test if inserted data is correctly retrieved after being escaped
arvenil Feb 7, 2015
42a1efd
Don't stop test on MySQLWarnings
methane Feb 8, 2015
3c8fa90
substitutePlaceholder -> interpolateParams
methane Feb 8, 2015
6c8484b
Add interpolateParams document to README
methane Feb 8, 2015
04866ee
Fix nits pointed in pull request.
methane Feb 8, 2015
dd7b87c
Add benchmark for interpolateParams()
methane Feb 8, 2015
9faabe5
Don't write microseconds when Time.Nanosecond() == 0
methane Feb 8, 2015
468b9e5
Fix benchmark
methane Feb 8, 2015
0297315
Reduce allocs in interpolateParams.
methane Feb 8, 2015
0b75396
Inline datetime formatting
methane Feb 8, 2015
9f84dfb
Remove one more allocation
methane Feb 8, 2015
8826242
More acculate estimation of upper bound
methane Feb 8, 2015
916a1f2
escapeString -> escapeBackslash
methane Feb 9, 2015
88aeb98
append string... to []byte without cast.
methane Feb 10, 2015
43536c7
Specialize escape functions for string
methane Feb 10, 2015
0c7ae46
test for escapeString*
methane Feb 10, 2015
c285e39
Use digits10 and digits01 to format datetime.
methane Feb 10, 2015
fcea447
Round under microsecond
methane Feb 10, 2015
bfbe6c5
travis: Drop Go 1.1 and add Go 1.4
methane Feb 11, 2015
d65f96a
Fix typo
methane Feb 12, 2015
e11c825
Inlining mysqlConn.escapeBytes and mysqlConn.escapeString
methane Feb 12, 2015
b4f0315
Bit detailed info about vulnerability when using multibyte encoding.
methane Feb 12, 2015
1fd0514
Add link to StackOverflow describe vulnerability using multibyte enco…
methane Feb 12, 2015
20b75cd
Fix comment
methane Feb 12, 2015
e517683
Allow interpolateParams only with ascii, latin1 and utf8 collations
methane Feb 12, 2015
0f22bc2
extract function to reserve buffer
methane Feb 12, 2015
52a5860
Fix missing db.Close()
methane Feb 12, 2015
2a634df
Fix sentence in interpolateParams document.
methane Feb 12, 2015
90cb6c3
Use blacklist to avoid vulnerability with interpolation
methane Feb 12, 2015
9437b61
Adding myself to AUTHORS (however, 99% work done by @methane ;))
arvenil Feb 13, 2015
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
language: go
go:
- 1.1
- 1.2
- 1.3
- 1.4
- tip

before_script:
Expand Down
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ INADA Naoki <songofacandy at gmail.com>
James Harr <james.harr at gmail.com>
Jian Zhen <zhenjl at gmail.com>
Julien Schmidt <go-sql-driver at julienschmidt.com>
Kamil Dziedzic <kamil at klecza.pl>
Leonardo YongUk Kim <dalinaum at gmail.com>
Lucas Liu <extrafliu at gmail.com>
Luke Scott <luke at webconnex.com>
Expand Down
13 changes: 13 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,19 @@ SELECT u.id FROM users as u

will return `u.id` instead of just `id` if `columnsWithAlias=true`.

##### `interpolateParams`

```
Type: bool
Valid Values: true, false
Default: false
```

If `interpolateParams` is true, placeholders (`?`) in calls to `db.Query()` and `db.Exec()` are interpolated into a single query string with given parameters. This reduces the number of roundtrips, since the driver has to prepare a statement, execute it with given parameters and close the statement again with `interpolateParams=false`.

NOTE: *This may introduce a SQL injection vulnerability when connection encoding is multibyte encoding except for UTF-8 (e.g. CP932)!*
(See http://stackoverflow.com/a/12118602/3430118)

##### `loc`

```
Expand Down
39 changes: 38 additions & 1 deletion benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@ package mysql
import (
"bytes"
"database/sql"
"database/sql/driver"
"math"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)

type TB testing.B
Expand Down Expand Up @@ -45,7 +48,11 @@ func initDB(b *testing.B, queries ...string) *sql.DB {
db := tb.checkDB(sql.Open("mysql", dsn))
for _, query := range queries {
if _, err := db.Exec(query); err != nil {
b.Fatalf("Error on %q: %v", query, err)
if w, ok := err.(MySQLWarnings); ok {
b.Logf("Warning on %q: %v", query, w)
} else {
b.Fatalf("Error on %q: %v", query, err)
}
}
}
return db
Expand Down Expand Up @@ -206,3 +213,33 @@ func BenchmarkRoundtripBin(b *testing.B) {
rows.Close()
}
}

func BenchmarkInterpolation(b *testing.B) {
mc := &mysqlConn{
cfg: &config{
interpolateParams: true,
loc: time.UTC,
},
maxPacketAllowed: maxPacketSize,
maxWriteSize: maxPacketSize - 1,
}

args := []driver.Value{
int64(42424242),
float64(math.Pi),
false,
time.Unix(1423411542, 807015000),
[]byte("bytes containing special chars ' \" \a \x00"),
"string containing special chars ' \" \a \x00",
}
q := "SELECT ?, ?, ?, ?, ?, ?"

b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := mc.interpolateParams(q, args)
if err != nil {
b.Fatal(err)
}
}
}
14 changes: 14 additions & 0 deletions collations.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,17 @@ var collations = map[string]byte{
"utf8mb4_unicode_520_ci": 246,
"utf8mb4_vietnamese_ci": 247,
}

// A blacklist of collations which is unsafe to interpolate parameters.
// These multibyte encodings may contains 0x5c (`\`) in their trailing bytes.
var unsafeCollations = map[byte]bool{
1: true, // big5_chinese_ci
13: true, // sjis_japanese_ci
28: true, // gbk_chinese_ci
84: true, // big5_bin
86: true, // gb2312_bin
87: true, // gbk_bin
88: true, // sjis_bin
95: true, // cp932_japanese_ci
96: true, // cp932_bin
}
226 changes: 191 additions & 35 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"database/sql/driver"
"errors"
"net"
"strconv"
"strings"
"time"
)
Expand All @@ -26,6 +27,7 @@ type mysqlConn struct {
maxPacketAllowed int
maxWriteSize int
flags clientFlag
status statusFlag
sequence uint8
parseTime bool
strict bool
Expand All @@ -46,6 +48,7 @@ type config struct {
allowOldPasswords bool
clientFoundRows bool
columnsWithAlias bool
interpolateParams bool
}

// Handles parameters set in DSN after the connection is established
Expand Down Expand Up @@ -162,28 +165,174 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
return stmt, err
}

// estimateParamLength calculates upper bound of string length from types.
func estimateParamLength(args []driver.Value) (int, bool) {
l := 0
for _, a := range args {
switch v := a.(type) {
case int64, float64:
// 24 (-1.7976931348623157e+308) may be upper bound. But I'm not sure.
l += 25
case bool:
l += 1 // 0 or 1
case time.Time:
l += 30 // '1234-12-23 12:34:56.777777'
case string:
l += len(v)*2 + 2
case []byte:
l += len(v)*2 + 2
default:
return 0, false
}
}
return l, true
}

func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
estimated, ok := estimateParamLength(args)
if !ok {
return "", driver.ErrSkip
}
estimated += len(query)

buf := make([]byte, 0, estimated)
argPos := 0

for i := 0; i < len(query); i++ {
q := strings.IndexByte(query[i:], '?')
if q == -1 {
buf = append(buf, query[i:]...)
break
}
buf = append(buf, query[i:i+q]...)
i += q

arg := args[argPos]
argPos++

if arg == nil {
buf = append(buf, "NULL"...)
continue
}

switch v := arg.(type) {
case int64:
buf = strconv.AppendInt(buf, v, 10)
case float64:
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
case bool:
if v {
buf = append(buf, '1')
} else {
buf = append(buf, '0')
}
case time.Time:
if v.IsZero() {
buf = append(buf, "'0000-00-00'"...)
} else {
v := v.In(mc.cfg.loc)
v = v.Add(time.Nanosecond * 500) // To round under microsecond
year := v.Year()
year100 := year / 100
year1 := year % 100
month := v.Month()
day := v.Day()
hour := v.Hour()
minute := v.Minute()
second := v.Second()
micro := v.Nanosecond() / 1000

buf = append(buf, []byte{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we benchmarked and checked https://github.com/go-sql-driver/mysql/blob/master/utils.go#L535 pretty extensively. Maybe you can get inspiration from there, especially the const digits10 and digits01 part without division and modulus was faster than this approach. Maybe the code can be unified (I have doubts).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've did.
The effect is in error span.

'\'',
digits10[year100], digits01[year100],
digits10[year1], digits01[year1],
'-',
digits10[month], digits01[month],
'-',
digits10[day], digits01[day],
' ',
digits10[hour], digits01[hour],
':',
digits10[minute], digits01[minute],
':',
digits10[second], digits01[second],
}...)

if micro != 0 {
micro10000 := micro / 10000
micro100 := micro / 100 % 100
micro1 := micro % 100
buf = append(buf, []byte{
'.',
digits10[micro10000], digits01[micro10000],
digits10[micro100], digits01[micro100],
digits10[micro1], digits01[micro1],
}...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MySQL 5.6 rounds under 1 micro second.
Should I add 500 nanosecond before formatting?

}
buf = append(buf, '\'')
}
case []byte:
if v == nil {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not needed, handled in L191

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Required since nil and []byte(nil) is different.

buf = append(buf, "NULL"...)
} else {
buf = append(buf, '\'')
if mc.status&statusNoBackslashEscapes == 0 {
buf = escapeBytesBackslash(buf, v)
} else {
buf = escapeBytesQuotes(buf, v)
}
buf = append(buf, '\'')
}
case string:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When []byte and string are essentially the same code, can they be in one case []byte, string:? Can escapeBytes be inlined?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I don't know that I can use case []byte, string: in type switch switch v := arg.(type).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't...
./connection.go:214: cannot convert v (type driver.Value) to type []byte: need type assertion

buf = append(buf, '\'')
if mc.status&statusNoBackslashEscapes == 0 {
buf = escapeStringBackslash(buf, v)
} else {
buf = escapeStringQuotes(buf, v)
}
buf = append(buf, '\'')
default:
return "", driver.ErrSkip
}

if len(buf)+4 > mc.maxPacketAllowed {
return "", driver.ErrSkip
}
}
if argPos != len(args) {
return "", driver.ErrSkip
}
return string(buf), nil
}

func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if mc.netConn == nil {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) == 0 { // no args, fastpath
mc.affectedRows = 0
mc.insertId = 0

err := mc.exec(query)
if err == nil {
return &mysqlResult{
affectedRows: int64(mc.affectedRows),
insertId: int64(mc.insertId),
}, err
if len(args) != 0 {
if !mc.cfg.interpolateParams {
return nil, driver.ErrSkip
}
return nil, err
// try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
prepared, err := mc.interpolateParams(query, args)
if err != nil {
return nil, err
}
query = prepared
args = nil
}
mc.affectedRows = 0
mc.insertId = 0

// with args, must use prepared stmt
return nil, driver.ErrSkip

err := mc.exec(query)
if err == nil {
return &mysqlResult{
affectedRows: int64(mc.affectedRows),
insertId: int64(mc.insertId),
}, err
}
return nil, err
}

// Internal function to execute commands
Expand Down Expand Up @@ -212,31 +361,38 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) == 0 { // no args, fastpath
// Send command
err := mc.writeCommandPacketStr(comQuery, query)
if len(args) != 0 {
if !mc.cfg.interpolateParams {
return nil, driver.ErrSkip
}
// try client-side prepare to reduce roundtrip
prepared, err := mc.interpolateParams(query, args)
if err != nil {
return nil, err
}
query = prepared
args = nil
}
// Send command
err := mc.writeCommandPacketStr(comQuery, query)
if err == nil {
// Read Result
var resLen int
resLen, err = mc.readResultSetHeaderPacket()
if err == nil {
// Read Result
var resLen int
resLen, err = mc.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc

if resLen == 0 {
// no columns, no more data
return emptyRows{}, nil
}
// Columns
rows.columns, err = mc.readColumns(resLen)
return rows, err
rows := new(textRows)
rows.mc = mc

if resLen == 0 {
// no columns, no more data
return emptyRows{}, nil
}
// Columns
rows.columns, err = mc.readColumns(resLen)
return rows, err
}
return nil, err
}

// with args, must use prepared stmt
return nil, driver.ErrSkip
return nil, err
}

// Gets the value of the given MySQL System Variable
Expand Down
Loading