diff --git a/connection_go18.go b/connection_go18.go new file mode 100644 index 000000000..51e9ce27c --- /dev/null +++ b/connection_go18.go @@ -0,0 +1,28 @@ +// +build go1.8 + +package mysql + +import ( + "context" + "time" +) + +func (mc *mysqlConn) Ping(ctx context.Context) error { + err := mc.writeCommandPacket(comPing) + if err != nil { + return err + } + + ch := make(chan error) + go func() { + _, err := mc.readResultOK() + ch <- err + }() + select { + case <-ctx.Done(): + mc.netConn.SetReadDeadline(time.Now()) + return ctx.Err() + case err := <-ch: + return err + } +} diff --git a/driver_go18_test.go b/driver_go18_test.go index 93918ad46..3792ead38 100644 --- a/driver_go18_test.go +++ b/driver_go18_test.go @@ -3,7 +3,9 @@ package mysql import ( + "context" "database/sql" + "database/sql/driver" "fmt" "reflect" "testing" @@ -188,3 +190,18 @@ func TestSkipResults(t *testing.T) { } }) } + +func TestPing(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + mysqlDriver := dbt.db.Driver().(driver.Driver) + conn, err := mysqlDriver.Open(dsn) + if err != nil { + dbt.Fatalf("error opening conn: %s", err) + } + pinger := conn.(driver.Pinger) + err = pinger.Ping(context.Background()) + if err != nil { + dbt.Fatalf("error on ping: %s", err) + } + }) +}