@@ -13,6 +13,7 @@ import (
13
13
"database/sql/driver"
14
14
"errors"
15
15
"net"
16
+ "strconv"
16
17
"strings"
17
18
"time"
18
19
)
@@ -26,6 +27,7 @@ type mysqlConn struct {
26
27
maxPacketAllowed int
27
28
maxWriteSize int
28
29
flags clientFlag
30
+ status statusFlag
29
31
sequence uint8
30
32
parseTime bool
31
33
strict bool
@@ -46,6 +48,7 @@ type config struct {
46
48
allowOldPasswords bool
47
49
clientFoundRows bool
48
50
columnsWithAlias bool
51
+ interpolateParams bool
49
52
}
50
53
51
54
// Handles parameters set in DSN after the connection is established
@@ -162,28 +165,174 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
162
165
return stmt , err
163
166
}
164
167
168
+ // estimateParamLength calculates upper bound of string length from types.
169
+ func estimateParamLength (args []driver.Value ) (int , bool ) {
170
+ l := 0
171
+ for _ , a := range args {
172
+ switch v := a .(type ) {
173
+ case int64 , float64 :
174
+ // 24 (-1.7976931348623157e+308) may be upper bound. But I'm not sure.
175
+ l += 25
176
+ case bool :
177
+ l += 1 // 0 or 1
178
+ case time.Time :
179
+ l += 30 // '1234-12-23 12:34:56.777777'
180
+ case string :
181
+ l += len (v )* 2 + 2
182
+ case []byte :
183
+ l += len (v )* 2 + 2
184
+ default :
185
+ return 0 , false
186
+ }
187
+ }
188
+ return l , true
189
+ }
190
+
191
+ func (mc * mysqlConn ) interpolateParams (query string , args []driver.Value ) (string , error ) {
192
+ estimated , ok := estimateParamLength (args )
193
+ if ! ok {
194
+ return "" , driver .ErrSkip
195
+ }
196
+ estimated += len (query )
197
+
198
+ buf := make ([]byte , 0 , estimated )
199
+ argPos := 0
200
+
201
+ for i := 0 ; i < len (query ); i ++ {
202
+ q := strings .IndexByte (query [i :], '?' )
203
+ if q == - 1 {
204
+ buf = append (buf , query [i :]... )
205
+ break
206
+ }
207
+ buf = append (buf , query [i :i + q ]... )
208
+ i += q
209
+
210
+ arg := args [argPos ]
211
+ argPos ++
212
+
213
+ if arg == nil {
214
+ buf = append (buf , "NULL" ... )
215
+ continue
216
+ }
217
+
218
+ switch v := arg .(type ) {
219
+ case int64 :
220
+ buf = strconv .AppendInt (buf , v , 10 )
221
+ case float64 :
222
+ buf = strconv .AppendFloat (buf , v , 'g' , - 1 , 64 )
223
+ case bool :
224
+ if v {
225
+ buf = append (buf , '1' )
226
+ } else {
227
+ buf = append (buf , '0' )
228
+ }
229
+ case time.Time :
230
+ if v .IsZero () {
231
+ buf = append (buf , "'0000-00-00'" ... )
232
+ } else {
233
+ v := v .In (mc .cfg .loc )
234
+ v = v .Add (time .Nanosecond * 500 ) // To round under microsecond
235
+ year := v .Year ()
236
+ year100 := year / 100
237
+ year1 := year % 100
238
+ month := v .Month ()
239
+ day := v .Day ()
240
+ hour := v .Hour ()
241
+ minute := v .Minute ()
242
+ second := v .Second ()
243
+ micro := v .Nanosecond () / 1000
244
+
245
+ buf = append (buf , []byte {
246
+ '\'' ,
247
+ digits10 [year100 ], digits01 [year100 ],
248
+ digits10 [year1 ], digits01 [year1 ],
249
+ '-' ,
250
+ digits10 [month ], digits01 [month ],
251
+ '-' ,
252
+ digits10 [day ], digits01 [day ],
253
+ ' ' ,
254
+ digits10 [hour ], digits01 [hour ],
255
+ ':' ,
256
+ digits10 [minute ], digits01 [minute ],
257
+ ':' ,
258
+ digits10 [second ], digits01 [second ],
259
+ }... )
260
+
261
+ if micro != 0 {
262
+ micro10000 := micro / 10000
263
+ micro100 := micro / 100 % 100
264
+ micro1 := micro % 100
265
+ buf = append (buf , []byte {
266
+ '.' ,
267
+ digits10 [micro10000 ], digits01 [micro10000 ],
268
+ digits10 [micro100 ], digits01 [micro100 ],
269
+ digits10 [micro1 ], digits01 [micro1 ],
270
+ }... )
271
+ }
272
+ buf = append (buf , '\'' )
273
+ }
274
+ case []byte :
275
+ if v == nil {
276
+ buf = append (buf , "NULL" ... )
277
+ } else {
278
+ buf = append (buf , '\'' )
279
+ if mc .status & statusNoBackslashEscapes == 0 {
280
+ buf = escapeBytesBackslash (buf , v )
281
+ } else {
282
+ buf = escapeBytesQuotes (buf , v )
283
+ }
284
+ buf = append (buf , '\'' )
285
+ }
286
+ case string :
287
+ buf = append (buf , '\'' )
288
+ if mc .status & statusNoBackslashEscapes == 0 {
289
+ buf = escapeStringBackslash (buf , v )
290
+ } else {
291
+ buf = escapeStringQuotes (buf , v )
292
+ }
293
+ buf = append (buf , '\'' )
294
+ default :
295
+ return "" , driver .ErrSkip
296
+ }
297
+
298
+ if len (buf )+ 4 > mc .maxPacketAllowed {
299
+ return "" , driver .ErrSkip
300
+ }
301
+ }
302
+ if argPos != len (args ) {
303
+ return "" , driver .ErrSkip
304
+ }
305
+ return string (buf ), nil
306
+ }
307
+
165
308
func (mc * mysqlConn ) Exec (query string , args []driver.Value ) (driver.Result , error ) {
166
309
if mc .netConn == nil {
167
310
errLog .Print (ErrInvalidConn )
168
311
return nil , driver .ErrBadConn
169
312
}
170
- if len (args ) == 0 { // no args, fastpath
171
- mc .affectedRows = 0
172
- mc .insertId = 0
173
-
174
- err := mc .exec (query )
175
- if err == nil {
176
- return & mysqlResult {
177
- affectedRows : int64 (mc .affectedRows ),
178
- insertId : int64 (mc .insertId ),
179
- }, err
313
+ if len (args ) != 0 {
314
+ if ! mc .cfg .interpolateParams {
315
+ return nil , driver .ErrSkip
180
316
}
181
- return nil , err
317
+ // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
318
+ prepared , err := mc .interpolateParams (query , args )
319
+ if err != nil {
320
+ return nil , err
321
+ }
322
+ query = prepared
323
+ args = nil
182
324
}
325
+ mc .affectedRows = 0
326
+ mc .insertId = 0
183
327
184
- // with args, must use prepared stmt
185
- return nil , driver .ErrSkip
186
-
328
+ err := mc .exec (query )
329
+ if err == nil {
330
+ return & mysqlResult {
331
+ affectedRows : int64 (mc .affectedRows ),
332
+ insertId : int64 (mc .insertId ),
333
+ }, err
334
+ }
335
+ return nil , err
187
336
}
188
337
189
338
// Internal function to execute commands
@@ -212,31 +361,38 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
212
361
errLog .Print (ErrInvalidConn )
213
362
return nil , driver .ErrBadConn
214
363
}
215
- if len (args ) == 0 { // no args, fastpath
216
- // Send command
217
- err := mc .writeCommandPacketStr (comQuery , query )
364
+ if len (args ) != 0 {
365
+ if ! mc .cfg .interpolateParams {
366
+ return nil , driver .ErrSkip
367
+ }
368
+ // try client-side prepare to reduce roundtrip
369
+ prepared , err := mc .interpolateParams (query , args )
370
+ if err != nil {
371
+ return nil , err
372
+ }
373
+ query = prepared
374
+ args = nil
375
+ }
376
+ // Send command
377
+ err := mc .writeCommandPacketStr (comQuery , query )
378
+ if err == nil {
379
+ // Read Result
380
+ var resLen int
381
+ resLen , err = mc .readResultSetHeaderPacket ()
218
382
if err == nil {
219
- // Read Result
220
- var resLen int
221
- resLen , err = mc .readResultSetHeaderPacket ()
222
- if err == nil {
223
- rows := new (textRows )
224
- rows .mc = mc
225
-
226
- if resLen == 0 {
227
- // no columns, no more data
228
- return emptyRows {}, nil
229
- }
230
- // Columns
231
- rows .columns , err = mc .readColumns (resLen )
232
- return rows , err
383
+ rows := new (textRows )
384
+ rows .mc = mc
385
+
386
+ if resLen == 0 {
387
+ // no columns, no more data
388
+ return emptyRows {}, nil
233
389
}
390
+ // Columns
391
+ rows .columns , err = mc .readColumns (resLen )
392
+ return rows , err
234
393
}
235
- return nil , err
236
394
}
237
-
238
- // with args, must use prepared stmt
239
- return nil , driver .ErrSkip
395
+ return nil , err
240
396
}
241
397
242
398
// Gets the value of the given MySQL System Variable
0 commit comments