Skip to content

Commit 200c80b

Browse files
committed
Merge pull request go-sql-driver#309 from arvenil/placeholder
Placeholder interpolation
2 parents 27633f0 + 9437b61 commit 200c80b

11 files changed

+631
-58
lines changed

.travis.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
sudo: false
22
language: go
33
go:
4-
- 1.1
54
- 1.2
65
- 1.3
6+
- 1.4
77
- tip
88

99
before_script:

AUTHORS

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ INADA Naoki <songofacandy at gmail.com>
2424
James Harr <james.harr at gmail.com>
2525
Jian Zhen <zhenjl at gmail.com>
2626
Julien Schmidt <go-sql-driver at julienschmidt.com>
27+
Kamil Dziedzic <kamil at klecza.pl>
2728
Leonardo YongUk Kim <dalinaum at gmail.com>
2829
Lucas Liu <extrafliu at gmail.com>
2930
Luke Scott <luke at webconnex.com>

README.md

+13
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,19 @@ SELECT u.id FROM users as u
182182

183183
will return `u.id` instead of just `id` if `columnsWithAlias=true`.
184184

185+
##### `interpolateParams`
186+
187+
```
188+
Type: bool
189+
Valid Values: true, false
190+
Default: false
191+
```
192+
193+
If `interpolateParams` is true, placeholders (`?`) in calls to `db.Query()` and `db.Exec()` are interpolated into a single query string with given parameters. This reduces the number of roundtrips, since the driver has to prepare a statement, execute it with given parameters and close the statement again with `interpolateParams=false`.
194+
195+
NOTE: *This may introduce a SQL injection vulnerability when connection encoding is multibyte encoding except for UTF-8 (e.g. CP932)!*
196+
(See http://stackoverflow.com/a/12118602/3430118)
197+
185198
##### `loc`
186199

187200
```

benchmark_test.go

+38-1
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@ package mysql
1111
import (
1212
"bytes"
1313
"database/sql"
14+
"database/sql/driver"
15+
"math"
1416
"strings"
1517
"sync"
1618
"sync/atomic"
1719
"testing"
20+
"time"
1821
)
1922

2023
type TB testing.B
@@ -45,7 +48,11 @@ func initDB(b *testing.B, queries ...string) *sql.DB {
4548
db := tb.checkDB(sql.Open("mysql", dsn))
4649
for _, query := range queries {
4750
if _, err := db.Exec(query); err != nil {
48-
b.Fatalf("Error on %q: %v", query, err)
51+
if w, ok := err.(MySQLWarnings); ok {
52+
b.Logf("Warning on %q: %v", query, w)
53+
} else {
54+
b.Fatalf("Error on %q: %v", query, err)
55+
}
4956
}
5057
}
5158
return db
@@ -206,3 +213,33 @@ func BenchmarkRoundtripBin(b *testing.B) {
206213
rows.Close()
207214
}
208215
}
216+
217+
func BenchmarkInterpolation(b *testing.B) {
218+
mc := &mysqlConn{
219+
cfg: &config{
220+
interpolateParams: true,
221+
loc: time.UTC,
222+
},
223+
maxPacketAllowed: maxPacketSize,
224+
maxWriteSize: maxPacketSize - 1,
225+
}
226+
227+
args := []driver.Value{
228+
int64(42424242),
229+
float64(math.Pi),
230+
false,
231+
time.Unix(1423411542, 807015000),
232+
[]byte("bytes containing special chars ' \" \a \x00"),
233+
"string containing special chars ' \" \a \x00",
234+
}
235+
q := "SELECT ?, ?, ?, ?, ?, ?"
236+
237+
b.ReportAllocs()
238+
b.ResetTimer()
239+
for i := 0; i < b.N; i++ {
240+
_, err := mc.interpolateParams(q, args)
241+
if err != nil {
242+
b.Fatal(err)
243+
}
244+
}
245+
}

collations.go

+14
Original file line numberDiff line numberDiff line change
@@ -234,3 +234,17 @@ var collations = map[string]byte{
234234
"utf8mb4_unicode_520_ci": 246,
235235
"utf8mb4_vietnamese_ci": 247,
236236
}
237+
238+
// A blacklist of collations which is unsafe to interpolate parameters.
239+
// These multibyte encodings may contains 0x5c (`\`) in their trailing bytes.
240+
var unsafeCollations = map[byte]bool{
241+
1: true, // big5_chinese_ci
242+
13: true, // sjis_japanese_ci
243+
28: true, // gbk_chinese_ci
244+
84: true, // big5_bin
245+
86: true, // gb2312_bin
246+
87: true, // gbk_bin
247+
88: true, // sjis_bin
248+
95: true, // cp932_japanese_ci
249+
96: true, // cp932_bin
250+
}

connection.go

+191-35
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"database/sql/driver"
1414
"errors"
1515
"net"
16+
"strconv"
1617
"strings"
1718
"time"
1819
)
@@ -26,6 +27,7 @@ type mysqlConn struct {
2627
maxPacketAllowed int
2728
maxWriteSize int
2829
flags clientFlag
30+
status statusFlag
2931
sequence uint8
3032
parseTime bool
3133
strict bool
@@ -46,6 +48,7 @@ type config struct {
4648
allowOldPasswords bool
4749
clientFoundRows bool
4850
columnsWithAlias bool
51+
interpolateParams bool
4952
}
5053

5154
// Handles parameters set in DSN after the connection is established
@@ -162,28 +165,174 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
162165
return stmt, err
163166
}
164167

168+
// estimateParamLength calculates upper bound of string length from types.
169+
func estimateParamLength(args []driver.Value) (int, bool) {
170+
l := 0
171+
for _, a := range args {
172+
switch v := a.(type) {
173+
case int64, float64:
174+
// 24 (-1.7976931348623157e+308) may be upper bound. But I'm not sure.
175+
l += 25
176+
case bool:
177+
l += 1 // 0 or 1
178+
case time.Time:
179+
l += 30 // '1234-12-23 12:34:56.777777'
180+
case string:
181+
l += len(v)*2 + 2
182+
case []byte:
183+
l += len(v)*2 + 2
184+
default:
185+
return 0, false
186+
}
187+
}
188+
return l, true
189+
}
190+
191+
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
192+
estimated, ok := estimateParamLength(args)
193+
if !ok {
194+
return "", driver.ErrSkip
195+
}
196+
estimated += len(query)
197+
198+
buf := make([]byte, 0, estimated)
199+
argPos := 0
200+
201+
for i := 0; i < len(query); i++ {
202+
q := strings.IndexByte(query[i:], '?')
203+
if q == -1 {
204+
buf = append(buf, query[i:]...)
205+
break
206+
}
207+
buf = append(buf, query[i:i+q]...)
208+
i += q
209+
210+
arg := args[argPos]
211+
argPos++
212+
213+
if arg == nil {
214+
buf = append(buf, "NULL"...)
215+
continue
216+
}
217+
218+
switch v := arg.(type) {
219+
case int64:
220+
buf = strconv.AppendInt(buf, v, 10)
221+
case float64:
222+
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
223+
case bool:
224+
if v {
225+
buf = append(buf, '1')
226+
} else {
227+
buf = append(buf, '0')
228+
}
229+
case time.Time:
230+
if v.IsZero() {
231+
buf = append(buf, "'0000-00-00'"...)
232+
} else {
233+
v := v.In(mc.cfg.loc)
234+
v = v.Add(time.Nanosecond * 500) // To round under microsecond
235+
year := v.Year()
236+
year100 := year / 100
237+
year1 := year % 100
238+
month := v.Month()
239+
day := v.Day()
240+
hour := v.Hour()
241+
minute := v.Minute()
242+
second := v.Second()
243+
micro := v.Nanosecond() / 1000
244+
245+
buf = append(buf, []byte{
246+
'\'',
247+
digits10[year100], digits01[year100],
248+
digits10[year1], digits01[year1],
249+
'-',
250+
digits10[month], digits01[month],
251+
'-',
252+
digits10[day], digits01[day],
253+
' ',
254+
digits10[hour], digits01[hour],
255+
':',
256+
digits10[minute], digits01[minute],
257+
':',
258+
digits10[second], digits01[second],
259+
}...)
260+
261+
if micro != 0 {
262+
micro10000 := micro / 10000
263+
micro100 := micro / 100 % 100
264+
micro1 := micro % 100
265+
buf = append(buf, []byte{
266+
'.',
267+
digits10[micro10000], digits01[micro10000],
268+
digits10[micro100], digits01[micro100],
269+
digits10[micro1], digits01[micro1],
270+
}...)
271+
}
272+
buf = append(buf, '\'')
273+
}
274+
case []byte:
275+
if v == nil {
276+
buf = append(buf, "NULL"...)
277+
} else {
278+
buf = append(buf, '\'')
279+
if mc.status&statusNoBackslashEscapes == 0 {
280+
buf = escapeBytesBackslash(buf, v)
281+
} else {
282+
buf = escapeBytesQuotes(buf, v)
283+
}
284+
buf = append(buf, '\'')
285+
}
286+
case string:
287+
buf = append(buf, '\'')
288+
if mc.status&statusNoBackslashEscapes == 0 {
289+
buf = escapeStringBackslash(buf, v)
290+
} else {
291+
buf = escapeStringQuotes(buf, v)
292+
}
293+
buf = append(buf, '\'')
294+
default:
295+
return "", driver.ErrSkip
296+
}
297+
298+
if len(buf)+4 > mc.maxPacketAllowed {
299+
return "", driver.ErrSkip
300+
}
301+
}
302+
if argPos != len(args) {
303+
return "", driver.ErrSkip
304+
}
305+
return string(buf), nil
306+
}
307+
165308
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
166309
if mc.netConn == nil {
167310
errLog.Print(ErrInvalidConn)
168311
return nil, driver.ErrBadConn
169312
}
170-
if len(args) == 0 { // no args, fastpath
171-
mc.affectedRows = 0
172-
mc.insertId = 0
173-
174-
err := mc.exec(query)
175-
if err == nil {
176-
return &mysqlResult{
177-
affectedRows: int64(mc.affectedRows),
178-
insertId: int64(mc.insertId),
179-
}, err
313+
if len(args) != 0 {
314+
if !mc.cfg.interpolateParams {
315+
return nil, driver.ErrSkip
180316
}
181-
return nil, err
317+
// try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
318+
prepared, err := mc.interpolateParams(query, args)
319+
if err != nil {
320+
return nil, err
321+
}
322+
query = prepared
323+
args = nil
182324
}
325+
mc.affectedRows = 0
326+
mc.insertId = 0
183327

184-
// with args, must use prepared stmt
185-
return nil, driver.ErrSkip
186-
328+
err := mc.exec(query)
329+
if err == nil {
330+
return &mysqlResult{
331+
affectedRows: int64(mc.affectedRows),
332+
insertId: int64(mc.insertId),
333+
}, err
334+
}
335+
return nil, err
187336
}
188337

189338
// Internal function to execute commands
@@ -212,31 +361,38 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
212361
errLog.Print(ErrInvalidConn)
213362
return nil, driver.ErrBadConn
214363
}
215-
if len(args) == 0 { // no args, fastpath
216-
// Send command
217-
err := mc.writeCommandPacketStr(comQuery, query)
364+
if len(args) != 0 {
365+
if !mc.cfg.interpolateParams {
366+
return nil, driver.ErrSkip
367+
}
368+
// try client-side prepare to reduce roundtrip
369+
prepared, err := mc.interpolateParams(query, args)
370+
if err != nil {
371+
return nil, err
372+
}
373+
query = prepared
374+
args = nil
375+
}
376+
// Send command
377+
err := mc.writeCommandPacketStr(comQuery, query)
378+
if err == nil {
379+
// Read Result
380+
var resLen int
381+
resLen, err = mc.readResultSetHeaderPacket()
218382
if err == nil {
219-
// Read Result
220-
var resLen int
221-
resLen, err = mc.readResultSetHeaderPacket()
222-
if err == nil {
223-
rows := new(textRows)
224-
rows.mc = mc
225-
226-
if resLen == 0 {
227-
// no columns, no more data
228-
return emptyRows{}, nil
229-
}
230-
// Columns
231-
rows.columns, err = mc.readColumns(resLen)
232-
return rows, err
383+
rows := new(textRows)
384+
rows.mc = mc
385+
386+
if resLen == 0 {
387+
// no columns, no more data
388+
return emptyRows{}, nil
233389
}
390+
// Columns
391+
rows.columns, err = mc.readColumns(resLen)
392+
return rows, err
234393
}
235-
return nil, err
236394
}
237-
238-
// with args, must use prepared stmt
239-
return nil, driver.ErrSkip
395+
return nil, err
240396
}
241397

242398
// Gets the value of the given MySQL System Variable

0 commit comments

Comments
 (0)