Skip to content

Commit eabce5d

Browse files
author
Artem Klyukvin
committed
review 1 fix
1 parent 9ed2daa commit eabce5d

File tree

6 files changed

+149
-23
lines changed

6 files changed

+149
-23
lines changed

connection.go

+13-6
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,18 @@ type mysqlConn struct {
5050
closed atomicBool // set when conn is closed, before closech is closed
5151

5252
// for killing query after timeout
53-
id int
54-
d MySQLDriver
55-
dsn string
53+
id int
54+
d MySQLDriver
5655
}
5756

5857
func (mc *mysqlConn) kill() error {
59-
conn, err := mc.d.Open(mc.dsn)
58+
t := 50 * time.Millisecond
59+
killCfg := *mc.cfg
60+
killCfg.Timeout = t
61+
killCfg.ReadTimeout = t
62+
killCfg.WriteTimeout = t
63+
64+
conn, err := mc.d.Open(killCfg.FormatDSN())
6065
if err != nil {
6166
return err
6267
}
@@ -461,8 +466,10 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
461466
// finish is called when the query has canceled.
462467
func (mc *mysqlConn) cancel(err error) {
463468
mc.canceled.Set(err)
464-
// do not put kill to cleanup to prevent cyclic kills
465-
mc.kill()
469+
if mc.cfg.KillQueryOnTimeout {
470+
// do not put kill to cleanup to prevent cyclic kills
471+
mc.kill()
472+
}
466473
mc.cleanup()
467474
}
468475

connection_go18.go

-4
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,6 @@ func (mc *mysqlConn) watchCancel(ctx context.Context) error {
160160
select {
161161
default:
162162
case <-ctx.Done():
163-
killErr := mc.kill()
164-
if killErr != nil {
165-
errLog.Print("failed to kill query: ", killErr)
166-
}
167163
return ctx.Err()
168164
}
169165
if mc.watcher == nil {

driver.go

-2
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,6 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
151151
}
152152

153153
mc.d = d
154-
mc.dsn = dsn
155-
156154
return mc, nil
157155
}
158156

driver_go18_test.go

+117-9
Original file line numberDiff line numberDiff line change
@@ -126,17 +126,19 @@ func getQueryProcess(db *sql.DB, dbName, query string) (*dbProcess, error) {
126126
return longProcess, err
127127
}
128128

129-
func killQuery(db *sql.DB, dbName, query string, cancel context.CancelFunc) error {
129+
var expectedKilledErr = fmt.Errorf("process expected to be killed")
130+
131+
func killQuery(db *sql.DB, dbName, query string, timeout time.Duration, cancel context.CancelFunc) error {
130132
process, err := getQueryProcess(db, dbName, query)
131133
if err != nil {
132134
return fmt.Errorf("failed to get mysql process: %v", err)
133135
}
134136
cancel()
135137

136-
end := time.Now().Add(killTimeout)
138+
end := time.Now().Add(timeout)
137139
for time.Now().Before(end) {
138140
if checkProcessExists(dbName, process.ID, db) {
139-
err = fmt.Errorf("process %d expected to be killed", process.ID)
141+
err = expectedKilledErr
140142
time.Sleep(pollTimeout)
141143
} else {
142144
err = nil
@@ -173,7 +175,7 @@ func testCancel(dbt *DBTest, ctx context.Context, cancel context.CancelFunc, que
173175
}()
174176

175177
// it is safe to not use timeouts here since they are inside the killQuery function
176-
err = killQuery(dbt.db, dbname, query, cancel)
178+
err = killQuery(dbt.db, dbname, query, killTimeout, cancel)
177179
if err != nil {
178180
dbt.Error(err)
179181
return
@@ -195,6 +197,62 @@ func testCancel(dbt *DBTest, ctx context.Context, cancel context.CancelFunc, que
195197
tx.Commit()
196198
}
197199

200+
func testCancelNoKill(dbt *DBTest, ctx context.Context, cancel context.CancelFunc, query string, queryFunc func() error) {
201+
tx, err := dbt.db.BeginTx(context.Background(), nil)
202+
if err != nil {
203+
dbt.Fatal(err)
204+
return
205+
}
206+
207+
_, err = tx.Exec("LOCK TABLES test WRITE")
208+
if err != nil {
209+
tx.Rollback()
210+
dbt.Fatal(err)
211+
}
212+
213+
errChan := make(chan error)
214+
go func() {
215+
// This query will be canceled.
216+
err = queryFunc()
217+
if err != nil && err != context.Canceled {
218+
errLog.Print(err)
219+
}
220+
if err != context.Canceled && ctx.Err() != context.Canceled {
221+
errChan <- fmt.Errorf("expected context.Canceled, got %v", err)
222+
return
223+
}
224+
errChan <- nil
225+
}()
226+
227+
// it is safe to not use timeouts here since they are inside the killQuery function
228+
err = killQuery(dbt.db, dbname, query, 500*time.Millisecond, cancel)
229+
if err != expectedKilledErr {
230+
if err == nil {
231+
dbt.Errorf("query kill expected to fail")
232+
} else {
233+
dbt.Errorf(fmt.Sprintf("unexpected error %s", err))
234+
}
235+
}
236+
237+
_, err = tx.Exec("UNLOCK TABLES")
238+
if err != nil {
239+
tx.Rollback()
240+
dbt.Fatal(err)
241+
}
242+
tx.Commit()
243+
244+
<-errChan
245+
}
246+
247+
func getKillDSN() string {
248+
cfg, err := ParseDSN(dsn)
249+
if err != nil {
250+
panic(err)
251+
}
252+
cfg.KillQueryOnTimeout = true
253+
return cfg.FormatDSN()
254+
}
255+
198256
func TestMultiResultSet(t *testing.T) {
199257
type result struct {
200258
values [][]int
@@ -385,12 +443,62 @@ func TestPingContext(t *testing.T) {
385443
})
386444
}
387445

388-
func TestContextCancelExec(t *testing.T) {
446+
func TestContextCancelNoKill(t *testing.T) {
389447
runTests(t, dsn, func(dbt *DBTest) {
390448
dbt.mustExec("CREATE TABLE test (v INTEGER)")
391449
ctx, cancel := context.WithCancel(context.Background())
392450
exec := "INSERT INTO test VALUES(1)"
393451

452+
testCancelNoKill(dbt, ctx, cancel, exec, func() error {
453+
_, err := dbt.db.ExecContext(ctx, exec)
454+
return err
455+
})
456+
457+
// Check how many times the query is executed.
458+
var v int
459+
var err error
460+
for i := 0; i != 3; i++ {
461+
err = nil
462+
if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil {
463+
dbt.Fatalf("%s", err.Error())
464+
return
465+
}
466+
if v != 1 {
467+
err = fmt.Errorf("expected val to be 1, got %d", v)
468+
}
469+
470+
if err != nil {
471+
time.Sleep(100 * time.Millisecond) // wait while insert is executed after table lock released
472+
}
473+
}
474+
if err != nil {
475+
dbt.Error(err)
476+
return
477+
}
478+
479+
// Context is already canceled, so error should come before execution.
480+
if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (1)"); err == nil {
481+
dbt.Error("expected error")
482+
} else if err.Error() != "context canceled" {
483+
dbt.Fatalf("unexpected error: %s", err)
484+
}
485+
486+
// The second insert query will fail, so the table has no changes.
487+
if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil {
488+
dbt.Fatalf("%s", err.Error())
489+
}
490+
if v != 1 {
491+
dbt.Errorf("expected val to be 1, got %d", v)
492+
}
493+
})
494+
}
495+
496+
func TestContextCancelExec(t *testing.T) {
497+
runTests(t, getKillDSN(), func(dbt *DBTest) {
498+
dbt.mustExec("CREATE TABLE test (v INTEGER)")
499+
ctx, cancel := context.WithCancel(context.Background())
500+
exec := "INSERT INTO test VALUES(1)"
501+
394502
testCancel(dbt, ctx, cancel, exec, func() error {
395503
_, err := dbt.db.ExecContext(ctx, exec)
396504
return err
@@ -423,7 +531,7 @@ func TestContextCancelExec(t *testing.T) {
423531
}
424532

425533
func TestContextCancelQuery(t *testing.T) {
426-
runTests(t, dsn, func(dbt *DBTest) {
534+
runTests(t, getKillDSN(), func(dbt *DBTest) {
427535
dbt.mustExec("CREATE TABLE test (v INTEGER)")
428536
ctx, cancel := context.WithCancel(context.Background())
429537
query := "SELECT 1 FROM test"
@@ -501,7 +609,7 @@ func TestContextCancelPrepare(t *testing.T) {
501609
}
502610

503611
func TestContextCancelStmtExec(t *testing.T) {
504-
runTests(t, dsn, func(dbt *DBTest) {
612+
runTests(t, getKillDSN(), func(dbt *DBTest) {
505613
dbt.mustExec("CREATE TABLE test (v INTEGER)")
506614
ctx, cancel := context.WithCancel(context.Background())
507615
exec := "INSERT INTO test VALUES(1)"
@@ -528,7 +636,7 @@ func TestContextCancelStmtExec(t *testing.T) {
528636
}
529637

530638
func TestContextCancelStmtQuery(t *testing.T) {
531-
runTests(t, dsn, func(dbt *DBTest) {
639+
runTests(t, getKillDSN(), func(dbt *DBTest) {
532640
dbt.mustExec("CREATE TABLE test (v INTEGER)")
533641
ctx, cancel := context.WithCancel(context.Background())
534642
query := "SELECT 1 FROM test"
@@ -555,7 +663,7 @@ func TestContextCancelStmtQuery(t *testing.T) {
555663
}
556664

557665
func TestContextCancelBegin(t *testing.T) {
558-
runTests(t, dsn, func(dbt *DBTest) {
666+
runTests(t, getKillDSN(), func(dbt *DBTest) {
559667
dbt.mustExec("CREATE TABLE test (v INTEGER)")
560668
ctx, cancel := context.WithCancel(context.Background())
561669
query := "SELECT 1 FROM test"

dsn.go

+18
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ type Config struct {
5757
MultiStatements bool // Allow multiple statements in one query
5858
ParseTime bool // Parse time values to time.Time
5959
RejectReadOnly bool // Reject read-only connections
60+
KillQueryOnTimeout bool // kill query on the server side if context timed out
6061
}
6162

6263
// NewConfig creates a new Config and sets default values.
@@ -254,6 +255,15 @@ func (cfg *Config) FormatDSN() string {
254255
}
255256
}
256257

258+
if cfg.KillQueryOnTimeout {
259+
if hasParam {
260+
buf.WriteString("&killQueryOnTimeout=true")
261+
} else {
262+
hasParam = true
263+
buf.WriteString("?killQueryOnTimeout=true")
264+
}
265+
}
266+
257267
if cfg.Timeout > 0 {
258268
if hasParam {
259269
buf.WriteString("&timeout=")
@@ -512,6 +522,14 @@ func parseDSNParams(cfg *Config, params string) (err error) {
512522
return errors.New("invalid bool value: " + value)
513523
}
514524

525+
// Kill queries on context timeout
526+
case "killQueryOnTimeout":
527+
var isBool bool
528+
cfg.KillQueryOnTimeout, isBool = readBool(value)
529+
if !isBool {
530+
return errors.New("invalid bool value: " + value)
531+
}
532+
515533
// Strict mode
516534
case "strict":
517535
panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode")

packets.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,9 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
179179
}
180180

181181
// server version [null terminated string]
182+
// connection id [4 bytes]
182183
idPos := 1 + bytes.IndexByte(data[1:], 0x00) + 1
183184
mc.id = int(binary.LittleEndian.Uint32(data[idPos : idPos+4]))
184-
185-
// connection id [4 bytes]
186185
pos := idPos + 4
187186

188187
// first part of the password cipher [8 bytes]

0 commit comments

Comments
 (0)