@@ -126,17 +126,19 @@ func getQueryProcess(db *sql.DB, dbName, query string) (*dbProcess, error) {
126
126
return longProcess , err
127
127
}
128
128
129
- func killQuery (db * sql.DB , dbName , query string , cancel context.CancelFunc ) error {
129
+ var expectedKilledErr = fmt .Errorf ("process expected to be killed" )
130
+
131
+ func killQuery (db * sql.DB , dbName , query string , timeout time.Duration , cancel context.CancelFunc ) error {
130
132
process , err := getQueryProcess (db , dbName , query )
131
133
if err != nil {
132
134
return fmt .Errorf ("failed to get mysql process: %v" , err )
133
135
}
134
136
cancel ()
135
137
136
- end := time .Now ().Add (killTimeout )
138
+ end := time .Now ().Add (timeout )
137
139
for time .Now ().Before (end ) {
138
140
if checkProcessExists (dbName , process .ID , db ) {
139
- err = fmt . Errorf ( "process %d expected to be killed" , process . ID )
141
+ err = expectedKilledErr
140
142
time .Sleep (pollTimeout )
141
143
} else {
142
144
err = nil
@@ -173,7 +175,7 @@ func testCancel(dbt *DBTest, ctx context.Context, cancel context.CancelFunc, que
173
175
}()
174
176
175
177
// it is safe to not use timeouts here since they are inside the killQuery function
176
- err = killQuery (dbt .db , dbname , query , cancel )
178
+ err = killQuery (dbt .db , dbname , query , killTimeout , cancel )
177
179
if err != nil {
178
180
dbt .Error (err )
179
181
return
@@ -195,6 +197,62 @@ func testCancel(dbt *DBTest, ctx context.Context, cancel context.CancelFunc, que
195
197
tx .Commit ()
196
198
}
197
199
200
+ func testCancelNoKill (dbt * DBTest , ctx context.Context , cancel context.CancelFunc , query string , queryFunc func () error ) {
201
+ tx , err := dbt .db .BeginTx (context .Background (), nil )
202
+ if err != nil {
203
+ dbt .Fatal (err )
204
+ return
205
+ }
206
+
207
+ _ , err = tx .Exec ("LOCK TABLES test WRITE" )
208
+ if err != nil {
209
+ tx .Rollback ()
210
+ dbt .Fatal (err )
211
+ }
212
+
213
+ errChan := make (chan error )
214
+ go func () {
215
+ // This query will be canceled.
216
+ err = queryFunc ()
217
+ if err != nil && err != context .Canceled {
218
+ errLog .Print (err )
219
+ }
220
+ if err != context .Canceled && ctx .Err () != context .Canceled {
221
+ errChan <- fmt .Errorf ("expected context.Canceled, got %v" , err )
222
+ return
223
+ }
224
+ errChan <- nil
225
+ }()
226
+
227
+ // it is safe to not use timeouts here since they are inside the killQuery function
228
+ err = killQuery (dbt .db , dbname , query , 500 * time .Millisecond , cancel )
229
+ if err != expectedKilledErr {
230
+ if err == nil {
231
+ dbt .Errorf ("query kill expected to fail" )
232
+ } else {
233
+ dbt .Errorf (fmt .Sprintf ("unexpected error %s" , err ))
234
+ }
235
+ }
236
+
237
+ _ , err = tx .Exec ("UNLOCK TABLES" )
238
+ if err != nil {
239
+ tx .Rollback ()
240
+ dbt .Fatal (err )
241
+ }
242
+ tx .Commit ()
243
+
244
+ <- errChan
245
+ }
246
+
247
+ func getKillDSN () string {
248
+ cfg , err := ParseDSN (dsn )
249
+ if err != nil {
250
+ panic (err )
251
+ }
252
+ cfg .KillQueryOnTimeout = true
253
+ return cfg .FormatDSN ()
254
+ }
255
+
198
256
func TestMultiResultSet (t * testing.T ) {
199
257
type result struct {
200
258
values [][]int
@@ -385,12 +443,62 @@ func TestPingContext(t *testing.T) {
385
443
})
386
444
}
387
445
388
- func TestContextCancelExec (t * testing.T ) {
446
+ func TestContextCancelNoKill (t * testing.T ) {
389
447
runTests (t , dsn , func (dbt * DBTest ) {
390
448
dbt .mustExec ("CREATE TABLE test (v INTEGER)" )
391
449
ctx , cancel := context .WithCancel (context .Background ())
392
450
exec := "INSERT INTO test VALUES(1)"
393
451
452
+ testCancelNoKill (dbt , ctx , cancel , exec , func () error {
453
+ _ , err := dbt .db .ExecContext (ctx , exec )
454
+ return err
455
+ })
456
+
457
+ // Check how many times the query is executed.
458
+ var v int
459
+ var err error
460
+ for i := 0 ; i != 3 ; i ++ {
461
+ err = nil
462
+ if err := dbt .db .QueryRow ("SELECT COUNT(*) FROM test" ).Scan (& v ); err != nil {
463
+ dbt .Fatalf ("%s" , err .Error ())
464
+ return
465
+ }
466
+ if v != 1 {
467
+ err = fmt .Errorf ("expected val to be 1, got %d" , v )
468
+ }
469
+
470
+ if err != nil {
471
+ time .Sleep (100 * time .Millisecond ) // wait while insert is executed after table lock released
472
+ }
473
+ }
474
+ if err != nil {
475
+ dbt .Error (err )
476
+ return
477
+ }
478
+
479
+ // Context is already canceled, so error should come before execution.
480
+ if _ , err := dbt .db .ExecContext (ctx , "INSERT INTO test VALUES (1)" ); err == nil {
481
+ dbt .Error ("expected error" )
482
+ } else if err .Error () != "context canceled" {
483
+ dbt .Fatalf ("unexpected error: %s" , err )
484
+ }
485
+
486
+ // The second insert query will fail, so the table has no changes.
487
+ if err := dbt .db .QueryRow ("SELECT COUNT(*) FROM test" ).Scan (& v ); err != nil {
488
+ dbt .Fatalf ("%s" , err .Error ())
489
+ }
490
+ if v != 1 {
491
+ dbt .Errorf ("expected val to be 1, got %d" , v )
492
+ }
493
+ })
494
+ }
495
+
496
+ func TestContextCancelExec (t * testing.T ) {
497
+ runTests (t , getKillDSN (), func (dbt * DBTest ) {
498
+ dbt .mustExec ("CREATE TABLE test (v INTEGER)" )
499
+ ctx , cancel := context .WithCancel (context .Background ())
500
+ exec := "INSERT INTO test VALUES(1)"
501
+
394
502
testCancel (dbt , ctx , cancel , exec , func () error {
395
503
_ , err := dbt .db .ExecContext (ctx , exec )
396
504
return err
@@ -423,7 +531,7 @@ func TestContextCancelExec(t *testing.T) {
423
531
}
424
532
425
533
func TestContextCancelQuery (t * testing.T ) {
426
- runTests (t , dsn , func (dbt * DBTest ) {
534
+ runTests (t , getKillDSN () , func (dbt * DBTest ) {
427
535
dbt .mustExec ("CREATE TABLE test (v INTEGER)" )
428
536
ctx , cancel := context .WithCancel (context .Background ())
429
537
query := "SELECT 1 FROM test"
@@ -501,7 +609,7 @@ func TestContextCancelPrepare(t *testing.T) {
501
609
}
502
610
503
611
func TestContextCancelStmtExec (t * testing.T ) {
504
- runTests (t , dsn , func (dbt * DBTest ) {
612
+ runTests (t , getKillDSN () , func (dbt * DBTest ) {
505
613
dbt .mustExec ("CREATE TABLE test (v INTEGER)" )
506
614
ctx , cancel := context .WithCancel (context .Background ())
507
615
exec := "INSERT INTO test VALUES(1)"
@@ -528,7 +636,7 @@ func TestContextCancelStmtExec(t *testing.T) {
528
636
}
529
637
530
638
func TestContextCancelStmtQuery (t * testing.T ) {
531
- runTests (t , dsn , func (dbt * DBTest ) {
639
+ runTests (t , getKillDSN () , func (dbt * DBTest ) {
532
640
dbt .mustExec ("CREATE TABLE test (v INTEGER)" )
533
641
ctx , cancel := context .WithCancel (context .Background ())
534
642
query := "SELECT 1 FROM test"
@@ -555,7 +663,7 @@ func TestContextCancelStmtQuery(t *testing.T) {
555
663
}
556
664
557
665
func TestContextCancelBegin (t * testing.T ) {
558
- runTests (t , dsn , func (dbt * DBTest ) {
666
+ runTests (t , getKillDSN () , func (dbt * DBTest ) {
559
667
dbt .mustExec ("CREATE TABLE test (v INTEGER)" )
560
668
ctx , cancel := context .WithCancel (context .Background ())
561
669
query := "SELECT 1 FROM test"
0 commit comments