Skip to content

Commit b8fac96

Browse files
committed
Allow to change (or disable) the default driver name for registration
A link variable now allows to change or disable the name of the driver that is automatically registered with database/sql: Change: go build "-ldflags=-X github.com/go-sql-driver/mysql.driverName=custom" Disable: go build "-ldflags=-X github.com/go-sql-driver/mysql.driverName=" In the same way, a variable overridable at link time is also provided to override the driver name used in the test suite. This allows to run our test suite on another driver. go build "-ldflags=-X github.com/go-sql-driver/mysql.driverNameTest=custom"
1 parent 278a0b9 commit b8fac96

File tree

3 files changed

+30
-14
lines changed

3 files changed

+30
-14
lines changed

Diff for: benchmark_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt {
4848

4949
func initDB(b *testing.B, queries ...string) *sql.DB {
5050
tb := (*TB)(b)
51-
db := tb.checkDB(sql.Open("mysql", dsn))
51+
db := tb.checkDB(sql.Open(driverNameTest, dsn))
5252
for _, query := range queries {
5353
if _, err := db.Exec(query); err != nil {
5454
b.Fatalf("error on %q: %v", query, err)
@@ -105,7 +105,7 @@ func BenchmarkExec(b *testing.B) {
105105
tb := (*TB)(b)
106106
b.StopTimer()
107107
b.ReportAllocs()
108-
db := tb.checkDB(sql.Open("mysql", dsn))
108+
db := tb.checkDB(sql.Open(driverNameTest, dsn))
109109
db.SetMaxIdleConns(concurrencyLevel)
110110
defer db.Close()
111111

@@ -151,7 +151,7 @@ func BenchmarkRoundtripTxt(b *testing.B) {
151151
sampleString := string(sample)
152152
b.ReportAllocs()
153153
tb := (*TB)(b)
154-
db := tb.checkDB(sql.Open("mysql", dsn))
154+
db := tb.checkDB(sql.Open(driverNameTest, dsn))
155155
defer db.Close()
156156
b.StartTimer()
157157
var result string
@@ -184,7 +184,7 @@ func BenchmarkRoundtripBin(b *testing.B) {
184184
sample, min, max := initRoundtripBenchmarks()
185185
b.ReportAllocs()
186186
tb := (*TB)(b)
187-
db := tb.checkDB(sql.Open("mysql", dsn))
187+
db := tb.checkDB(sql.Open(driverNameTest, dsn))
188188
defer db.Close()
189189
stmt := tb.checkStmt(db.Prepare("SELECT ?"))
190190
defer stmt.Close()

Diff for: driver.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,14 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
9090
return c.Connect(context.Background())
9191
}
9292

93+
// This variable can be replaced with -ldflags like below:
94+
// go build "-ldflags=-X github.com/go-sql-driver/mysql.driverName=custom"
95+
var driverName = "mysql"
96+
9397
func init() {
94-
sql.Register("mysql", &MySQLDriver{})
98+
if driverName != "" {
99+
sql.Register(driverName, &MySQLDriver{})
100+
}
95101
}
96102

97103
// NewConnector returns new driver.Connector.

Diff for: driver_test.go

+19-9
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,16 @@ import (
3131
"time"
3232
)
3333

34+
// This variable can be replaced with -ldflags like below:
35+
// go test "-ldflags=-X github.com/go-sql-driver/mysql_test.driverNameTest=custom"
36+
var driverNameTest string
37+
38+
func init() {
39+
if driverNameTest == "" {
40+
driverNameTest = driverName
41+
}
42+
}
43+
3444
// Ensure that all the driver interfaces are implemented
3545
var (
3646
_ driver.Rows = &binaryRows{}
@@ -111,7 +121,7 @@ func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBT
111121
dsn += "&multiStatements=true"
112122
var db *sql.DB
113123
if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation {
114-
db, err = sql.Open("mysql", dsn)
124+
db, err = sql.Open(driverNameTest, dsn)
115125
if err != nil {
116126
t.Fatalf("error connecting: %s", err.Error())
117127
}
@@ -130,7 +140,7 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
130140
t.Skipf("MySQL server not running on %s", netAddr)
131141
}
132142

133-
db, err := sql.Open("mysql", dsn)
143+
db, err := sql.Open(driverNameTest, dsn)
134144
if err != nil {
135145
t.Fatalf("error connecting: %s", err.Error())
136146
}
@@ -141,7 +151,7 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
141151
dsn2 := dsn + "&interpolateParams=true"
142152
var db2 *sql.DB
143153
if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation {
144-
db2, err = sql.Open("mysql", dsn2)
154+
db2, err = sql.Open(driverNameTest, dsn2)
145155
if err != nil {
146156
t.Fatalf("error connecting: %s", err.Error())
147157
}
@@ -1917,7 +1927,7 @@ func testDialError(t *testing.T, dialErr error, expectErr error) {
19171927
return nil, dialErr
19181928
})
19191929

1920-
db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
1930+
db, err := sql.Open(driverNameTest, fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
19211931
if err != nil {
19221932
t.Fatalf("error connecting: %s", err.Error())
19231933
}
@@ -1956,7 +1966,7 @@ func TestCustomDial(t *testing.T) {
19561966
return d.DialContext(ctx, prot, addr)
19571967
})
19581968

1959-
db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
1969+
db, err := sql.Open(driverNameTest, fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
19601970
if err != nil {
19611971
t.Fatalf("error connecting: %s", err.Error())
19621972
}
@@ -2054,7 +2064,7 @@ func TestUnixSocketAuthFail(t *testing.T) {
20542064
}
20552065
t.Logf("socket: %s", socket)
20562066
badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s", user, badPass, socket, dbname)
2057-
db, err := sql.Open("mysql", badDSN)
2067+
db, err := sql.Open(driverNameTest, badDSN)
20582068
if err != nil {
20592069
t.Fatalf("error connecting: %s", err.Error())
20602070
}
@@ -2243,7 +2253,7 @@ func TestEmptyPassword(t *testing.T) {
22432253
}
22442254

22452255
dsn := fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, "", netAddr, dbname)
2246-
db, err := sql.Open("mysql", dsn)
2256+
db, err := sql.Open(driverNameTest, dsn)
22472257
if err == nil {
22482258
defer db.Close()
22492259
err = db.Ping()
@@ -3210,7 +3220,7 @@ func TestConnectorObeysDialTimeouts(t *testing.T) {
32103220
return d.DialContext(ctx, prot, addr)
32113221
})
32123222

3213-
db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@dialctxtest(%s)/%s?timeout=30s", user, pass, addr, dbname))
3223+
db, err := sql.Open(driverNameTest, fmt.Sprintf("%s:%s@dialctxtest(%s)/%s?timeout=30s", user, pass, addr, dbname))
32143224
if err != nil {
32153225
t.Fatalf("error connecting: %s", err.Error())
32163226
}
@@ -3375,7 +3385,7 @@ func TestConnectionAttributes(t *testing.T) {
33753385

33763386
var db *sql.DB
33773387
if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation {
3378-
db, err = sql.Open("mysql", dsn)
3388+
db, err = sql.Open(driverNameTest, dsn)
33793389
if err != nil {
33803390
t.Fatalf("error connecting: %s", err.Error())
33813391
}

0 commit comments

Comments
 (0)