Skip to content

Commit 0498ea5

Browse files
committed
Add BeforeConnect callback to configuration object
This can be used to alter the connection options for each connection, right before it's established
1 parent 0b18dac commit 0498ea5

File tree

3 files changed

+48
-1
lines changed

3 files changed

+48
-1
lines changed

connector.go

+11-1
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,22 @@ func newConnector(cfg *Config) (*connector, error) {
6666
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
6767
var err error
6868

69+
// Invoke BeforeConnect if present, with a copy of the configuration
70+
cfg := c.cfg
71+
if c.cfg.BeforeConnect != nil {
72+
cfg = c.cfg.Clone()
73+
err = c.cfg.BeforeConnect(ctx, cfg)
74+
if err != nil {
75+
return nil, err
76+
}
77+
}
78+
6979
// New mysqlConn
7080
mc := &mysqlConn{
7181
maxAllowedPacket: maxPacketSize,
7282
maxWriteSize: maxPacketSize - 1,
7383
closech: make(chan struct{}),
74-
cfg: c.cfg,
84+
cfg: cfg,
7585
connector: c,
7686
}
7787
mc.parseTime = mc.cfg.ParseTime

driver_test.go

+34
Original file line numberDiff line numberDiff line change
@@ -1965,6 +1965,40 @@ func TestCustomDial(t *testing.T) {
19651965
}
19661966
}
19671967

1968+
func TestBeforeConnect(t *testing.T) {
1969+
if !available {
1970+
t.Skipf("MySQL server not running on %s", netAddr)
1971+
}
1972+
1973+
// dbname is set in the BeforeConnect handle
1974+
cfg, err := ParseDSN(fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, pass, netAddr, "_"))
1975+
if err != nil {
1976+
t.Fatalf("error parsing DSN: %v", err)
1977+
}
1978+
1979+
cfg.BeforeConnect = func(ctx context.Context, c *Config) error {
1980+
c.DBName = dbname
1981+
return nil
1982+
}
1983+
1984+
connector, err := NewConnector(cfg)
1985+
if err != nil {
1986+
t.Fatalf("error creating connector: %v", err)
1987+
}
1988+
1989+
db := sql.OpenDB(connector)
1990+
defer db.Close()
1991+
1992+
var connectedDb string
1993+
err = db.QueryRow("SELECT DATABASE();").Scan(&connectedDb)
1994+
if err != nil {
1995+
t.Fatalf("error executing query: %v", err)
1996+
}
1997+
if connectedDb != dbname {
1998+
t.Fatalf("expected to connect to DB %s, but connected to %s instead", dbname, connectedDb)
1999+
}
2000+
}
2001+
19682002
func TestSQLInjection(t *testing.T) {
19692003
createTest := func(arg string) func(dbt *DBTest) {
19702004
return func(dbt *DBTest) {

dsn.go

+3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ package mysql
1010

1111
import (
1212
"bytes"
13+
"context"
1314
"crypto/rsa"
1415
"crypto/tls"
1516
"errors"
@@ -65,6 +66,8 @@ type Config struct {
6566
MultiStatements bool // Allow multiple statements in one query
6667
ParseTime bool // Parse time values to time.Time
6768
RejectReadOnly bool // Reject read-only connections
69+
70+
BeforeConnect func(context.Context, *Config) error // Invoked before a connection is established
6871
}
6972

7073
// NewConfig creates a new Config and sets default values.

0 commit comments

Comments
 (0)