From 0498ea5849d11973395e134197221f3b3580abb1 Mon Sep 17 00:00:00 2001 From: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> Date: Wed, 2 Aug 2023 16:25:03 -0700 Subject: [PATCH 1/4] Add BeforeConnect callback to configuration object This can be used to alter the connection options for each connection, right before it's established --- connector.go | 12 +++++++++++- driver_test.go | 34 ++++++++++++++++++++++++++++++++++ dsn.go | 3 +++ 3 files changed, 48 insertions(+), 1 deletion(-) diff --git a/connector.go b/connector.go index 7e0b16734..d77dfc80d 100644 --- a/connector.go +++ b/connector.go @@ -66,12 +66,22 @@ func newConnector(cfg *Config) (*connector, error) { func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { var err error + // Invoke BeforeConnect if present, with a copy of the configuration + cfg := c.cfg + if c.cfg.BeforeConnect != nil { + cfg = c.cfg.Clone() + err = c.cfg.BeforeConnect(ctx, cfg) + if err != nil { + return nil, err + } + } + // New mysqlConn mc := &mysqlConn{ maxAllowedPacket: maxPacketSize, maxWriteSize: maxPacketSize - 1, closech: make(chan struct{}), - cfg: c.cfg, + cfg: cfg, connector: c, } mc.parseTime = mc.cfg.ParseTime diff --git a/driver_test.go b/driver_test.go index 2748870b7..66eb12c45 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1965,6 +1965,40 @@ func TestCustomDial(t *testing.T) { } } +func TestBeforeConnect(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + // dbname is set in the BeforeConnect handle + cfg, err := ParseDSN(fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, pass, netAddr, "_")) + if err != nil { + t.Fatalf("error parsing DSN: %v", err) + } + + cfg.BeforeConnect = func(ctx context.Context, c *Config) error { + c.DBName = dbname + return nil + } + + connector, err := NewConnector(cfg) + if err != nil { + t.Fatalf("error creating connector: %v", err) + } + + db := sql.OpenDB(connector) + defer db.Close() + + var connectedDb string + err = db.QueryRow("SELECT DATABASE();").Scan(&connectedDb) + if err != nil { + t.Fatalf("error executing query: %v", err) + } + if connectedDb != dbname { + t.Fatalf("expected to connect to DB %s, but connected to %s instead", dbname, connectedDb) + } +} + func TestSQLInjection(t *testing.T) { createTest := func(arg string) func(dbt *DBTest) { return func(dbt *DBTest) { diff --git a/dsn.go b/dsn.go index 380ca9570..49aab6939 100644 --- a/dsn.go +++ b/dsn.go @@ -10,6 +10,7 @@ package mysql import ( "bytes" + "context" "crypto/rsa" "crypto/tls" "errors" @@ -65,6 +66,8 @@ type Config struct { MultiStatements bool // Allow multiple statements in one query ParseTime bool // Parse time values to time.Time RejectReadOnly bool // Reject read-only connections + + BeforeConnect func(context.Context, *Config) error // Invoked before a connection is established } // NewConfig creates a new Config and sets default values. From 4b862410c834956b1a457066f90fbc67b9f87348 Mon Sep 17 00:00:00 2001 From: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> Date: Wed, 2 Aug 2023 16:26:48 -0700 Subject: [PATCH 2/4] Updated AUTHORS --- AUTHORS | 1 + 1 file changed, 1 insertion(+) diff --git a/AUTHORS b/AUTHORS index 29e08b0ca..40e5f428f 100644 --- a/AUTHORS +++ b/AUTHORS @@ -124,6 +124,7 @@ GitHub Inc. Google Inc. InfoSum Ltd. Keybase Inc. +Microsoft Corp. Multiplay Ltd. Percona LLC Pivotal Inc. From ef007e02702b1605501adc3abfefc73862eec646 Mon Sep 17 00:00:00 2001 From: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> Date: Tue, 12 Dec 2023 07:26:06 -0800 Subject: [PATCH 3/4] Changed per review feedback --- dsn.go | 39 +++++++++++++++++++-------------------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/dsn.go b/dsn.go index 30fc68bc1..a044de577 100644 --- a/dsn.go +++ b/dsn.go @@ -35,24 +35,25 @@ var ( // If a new Config is created instead of being parsed from a DSN string, // the NewConfig function should be used, which sets default values. type Config struct { - User string // Username - Passwd string // Password (requires User) - Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp") - Addr string // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix") - DBName string // Database name - Params map[string]string // Connection parameters - ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs - Collation string // Connection collation - Loc *time.Location // Location for time.Time values - MaxAllowedPacket int // Max packet size allowed - ServerPubKey string // Server public key name - pubKey *rsa.PublicKey // Server public key - TLSConfig string // TLS configuration name - TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig - Timeout time.Duration // Dial timeout - ReadTimeout time.Duration // I/O read timeout - WriteTimeout time.Duration // I/O write timeout - Logger Logger // Logger + User string // Username + Passwd string // Password (requires User) + Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp") + Addr string // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix") + DBName string // Database name + Params map[string]string // Connection parameters + ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs + Collation string // Connection collation + Loc *time.Location // Location for time.Time values + MaxAllowedPacket int // Max packet size allowed + ServerPubKey string // Server public key name + pubKey *rsa.PublicKey // Server public key + TLSConfig string // TLS configuration name + TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig + Timeout time.Duration // Dial timeout + ReadTimeout time.Duration // I/O read timeout + WriteTimeout time.Duration // I/O write timeout + Logger Logger // Logger + BeforeConnect func(context.Context, *Config) error // Invoked before a connection is established AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE AllowCleartextPasswords bool // Allows the cleartext client side plugin @@ -66,8 +67,6 @@ type Config struct { MultiStatements bool // Allow multiple statements in one query ParseTime bool // Parse time values to time.Time RejectReadOnly bool // Reject read-only connections - - BeforeConnect func(context.Context, *Config) error // Invoked before a connection is established } // NewConfig creates a new Config and sets default values. From 6a4e24ee11ef6a81c552d1f4bbbe4a8e374bfb74 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Thu, 7 Mar 2024 00:49:16 +0900 Subject: [PATCH 4/4] Make BeforeConnect a functional option --- connector.go | 6 +++--- driver_test.go | 4 ++-- dsn.go | 48 ++++++++++++++++++++++++++++-------------------- 3 files changed, 33 insertions(+), 25 deletions(-) diff --git a/connector.go b/connector.go index a00bd5828..a0ee62839 100644 --- a/connector.go +++ b/connector.go @@ -66,11 +66,11 @@ func newConnector(cfg *Config) *connector { func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { var err error - // Invoke BeforeConnect if present, with a copy of the configuration + // Invoke beforeConnect if present, with a copy of the configuration cfg := c.cfg - if c.cfg.BeforeConnect != nil { + if c.cfg.beforeConnect != nil { cfg = c.cfg.Clone() - err = c.cfg.BeforeConnect(ctx, cfg) + err = c.cfg.beforeConnect(ctx, cfg) if err != nil { return nil, err } diff --git a/driver_test.go b/driver_test.go index 3ae8b231b..001957244 100644 --- a/driver_test.go +++ b/driver_test.go @@ -2055,10 +2055,10 @@ func TestBeforeConnect(t *testing.T) { t.Fatalf("error parsing DSN: %v", err) } - cfg.BeforeConnect = func(ctx context.Context, c *Config) error { + cfg.Apply(BeforeConnect(func(ctx context.Context, c *Config) error { c.DBName = dbname return nil - } + })) connector, err := NewConnector(cfg) if err != nil { diff --git a/dsn.go b/dsn.go index e7990358e..65f5a0242 100644 --- a/dsn.go +++ b/dsn.go @@ -37,24 +37,23 @@ var ( type Config struct { // non boolean fields - User string // Username - Passwd string // Password (requires User) - Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp") - Addr string // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix") - DBName string // Database name - Params map[string]string // Connection parameters - ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs - Collation string // Connection collation - Loc *time.Location // Location for time.Time values - MaxAllowedPacket int // Max packet size allowed - ServerPubKey string // Server public key name - TLSConfig string // TLS configuration name - TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig - Timeout time.Duration // Dial timeout - ReadTimeout time.Duration // I/O read timeout - WriteTimeout time.Duration // I/O write timeout - Logger Logger // Logger - BeforeConnect func(context.Context, *Config) error // Invoked before a connection is established + User string // Username + Passwd string // Password (requires User) + Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp") + Addr string // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix") + DBName string // Database name + Params map[string]string // Connection parameters + ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs + Collation string // Connection collation + Loc *time.Location // Location for time.Time values + MaxAllowedPacket int // Max packet size allowed + ServerPubKey string // Server public key name + TLSConfig string // TLS configuration name + TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig + Timeout time.Duration // Dial timeout + ReadTimeout time.Duration // I/O read timeout + WriteTimeout time.Duration // I/O write timeout + Logger Logger // Logger // boolean fields @@ -73,8 +72,9 @@ type Config struct { // unexported fields. new options should be come here - pubKey *rsa.PublicKey // Server public key - timeTruncate time.Duration // Truncate time.Time values to the specified duration + beforeConnect func(context.Context, *Config) error // Invoked before a connection is established + pubKey *rsa.PublicKey // Server public key + timeTruncate time.Duration // Truncate time.Time values to the specified duration } // Functional Options Pattern @@ -114,6 +114,14 @@ func TimeTruncate(d time.Duration) Option { } } +// BeforeConnect sets the function to be invoked before a connection is established. +func BeforeConnect(fn func(context.Context, *Config) error) Option { + return func(cfg *Config) error { + cfg.beforeConnect = fn + return nil + } +} + func (cfg *Config) Clone() *Config { cp := *cfg if cp.TLS != nil {