Skip to content

Commit 6a4e24e

Browse files
committed
Make BeforeConnect a functional option
1 parent 078d1fc commit 6a4e24e

File tree

3 files changed

+33
-25
lines changed

3 files changed

+33
-25
lines changed

Diff for: connector.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,11 @@ func newConnector(cfg *Config) *connector {
6666
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
6767
var err error
6868

69-
// Invoke BeforeConnect if present, with a copy of the configuration
69+
// Invoke beforeConnect if present, with a copy of the configuration
7070
cfg := c.cfg
71-
if c.cfg.BeforeConnect != nil {
71+
if c.cfg.beforeConnect != nil {
7272
cfg = c.cfg.Clone()
73-
err = c.cfg.BeforeConnect(ctx, cfg)
73+
err = c.cfg.beforeConnect(ctx, cfg)
7474
if err != nil {
7575
return nil, err
7676
}

Diff for: driver_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -2055,10 +2055,10 @@ func TestBeforeConnect(t *testing.T) {
20552055
t.Fatalf("error parsing DSN: %v", err)
20562056
}
20572057

2058-
cfg.BeforeConnect = func(ctx context.Context, c *Config) error {
2058+
cfg.Apply(BeforeConnect(func(ctx context.Context, c *Config) error {
20592059
c.DBName = dbname
20602060
return nil
2061-
}
2061+
}))
20622062

20632063
connector, err := NewConnector(cfg)
20642064
if err != nil {

Diff for: dsn.go

+28-20
Original file line numberDiff line numberDiff line change
@@ -37,24 +37,23 @@ var (
3737
type Config struct {
3838
// non boolean fields
3939

40-
User string // Username
41-
Passwd string // Password (requires User)
42-
Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp")
43-
Addr string // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix")
44-
DBName string // Database name
45-
Params map[string]string // Connection parameters
46-
ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs
47-
Collation string // Connection collation
48-
Loc *time.Location // Location for time.Time values
49-
MaxAllowedPacket int // Max packet size allowed
50-
ServerPubKey string // Server public key name
51-
TLSConfig string // TLS configuration name
52-
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
53-
Timeout time.Duration // Dial timeout
54-
ReadTimeout time.Duration // I/O read timeout
55-
WriteTimeout time.Duration // I/O write timeout
56-
Logger Logger // Logger
57-
BeforeConnect func(context.Context, *Config) error // Invoked before a connection is established
40+
User string // Username
41+
Passwd string // Password (requires User)
42+
Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp")
43+
Addr string // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix")
44+
DBName string // Database name
45+
Params map[string]string // Connection parameters
46+
ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs
47+
Collation string // Connection collation
48+
Loc *time.Location // Location for time.Time values
49+
MaxAllowedPacket int // Max packet size allowed
50+
ServerPubKey string // Server public key name
51+
TLSConfig string // TLS configuration name
52+
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
53+
Timeout time.Duration // Dial timeout
54+
ReadTimeout time.Duration // I/O read timeout
55+
WriteTimeout time.Duration // I/O write timeout
56+
Logger Logger // Logger
5857

5958
// boolean fields
6059

@@ -73,8 +72,9 @@ type Config struct {
7372

7473
// unexported fields. new options should be come here
7574

76-
pubKey *rsa.PublicKey // Server public key
77-
timeTruncate time.Duration // Truncate time.Time values to the specified duration
75+
beforeConnect func(context.Context, *Config) error // Invoked before a connection is established
76+
pubKey *rsa.PublicKey // Server public key
77+
timeTruncate time.Duration // Truncate time.Time values to the specified duration
7878
}
7979

8080
// Functional Options Pattern
@@ -114,6 +114,14 @@ func TimeTruncate(d time.Duration) Option {
114114
}
115115
}
116116

117+
// BeforeConnect sets the function to be invoked before a connection is established.
118+
func BeforeConnect(fn func(context.Context, *Config) error) Option {
119+
return func(cfg *Config) error {
120+
cfg.beforeConnect = fn
121+
return nil
122+
}
123+
}
124+
117125
func (cfg *Config) Clone() *Config {
118126
cp := *cfg
119127
if cp.TLS != nil {

0 commit comments

Comments
 (0)