Skip to content

Export ParseDSN and Config #403

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 18, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,9 @@ func BenchmarkRoundtripBin(b *testing.B) {

func BenchmarkInterpolation(b *testing.B) {
mc := &mysqlConn{
cfg: &config{
interpolateParams: true,
loc: time.UTC,
cfg: &Config{
InterpolateParams: true,
Loc: time.UTC,
},
maxPacketAllowed: maxPacketSize,
maxWriteSize: maxPacketSize - 1,
Expand Down
53 changes: 6 additions & 47 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
package mysql

import (
"crypto/tls"
"database/sql/driver"
"errors"
"net"
"strconv"
"strings"
Expand All @@ -23,7 +21,7 @@ type mysqlConn struct {
netConn net.Conn
affectedRows uint64
insertId uint64
cfg *config
cfg *Config
maxPacketAllowed int
maxWriteSize int
flags clientFlag
Expand All @@ -33,28 +31,9 @@ type mysqlConn struct {
strict bool
}

type config struct {
user string
passwd string
net string
addr string
dbname string
params map[string]string
loc *time.Location
tls *tls.Config
timeout time.Duration
collation uint8
allowAllFiles bool
allowOldPasswords bool
allowCleartextPasswords bool
clientFoundRows bool
columnsWithAlias bool
interpolateParams bool
}

// Handles parameters set in DSN after the connection is established
func (mc *mysqlConn) handleParams() (err error) {
for param, val := range mc.cfg.params {
for param, val := range mc.cfg.Params {
switch param {
// Charset
case "charset":
Expand All @@ -70,27 +49,6 @@ func (mc *mysqlConn) handleParams() (err error) {
return
}

// time.Time parsing
case "parseTime":
var isBool bool
mc.parseTime, isBool = readBool(val)
if !isBool {
return errors.New("Invalid Bool value: " + val)
}

// Strict mode
case "strict":
var isBool bool
mc.strict, isBool = readBool(val)
if !isBool {
return errors.New("Invalid Bool value: " + val)
}

// Compression
case "compress":
err = errors.New("Compression not implemented yet")
return

// System Vars
default:
err = mc.exec("SET " + param + "=" + val + "")
Expand Down Expand Up @@ -217,7 +175,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
if v.IsZero() {
buf = append(buf, "'0000-00-00'"...)
} else {
v := v.In(mc.cfg.loc)
v := v.In(mc.cfg.Loc)
v = v.Add(time.Nanosecond * 500) // To round under microsecond
year := v.Year()
year100 := year / 100
Expand Down Expand Up @@ -298,7 +256,7 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
return nil, driver.ErrBadConn
}
if len(args) != 0 {
if !mc.cfg.interpolateParams {
if !mc.cfg.InterpolateParams {
return nil, driver.ErrSkip
}
// try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
Expand Down Expand Up @@ -349,7 +307,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
return nil, driver.ErrBadConn
}
if len(args) != 0 {
if !mc.cfg.interpolateParams {
if !mc.cfg.InterpolateParams {
return nil, driver.ErrSkip
}
// try client-side prepare to reduce roundtrip
Expand Down Expand Up @@ -395,6 +353,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
if err == nil {
rows := new(textRows)
rows.mc = mc
rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}}

if resLen > 0 {
// Columns
Expand Down
16 changes: 9 additions & 7 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,19 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
maxPacketAllowed: maxPacketSize,
maxWriteSize: maxPacketSize - 1,
}
mc.cfg, err = parseDSN(dsn)
mc.cfg, err = ParseDSN(dsn)
if err != nil {
return nil, err
}
mc.parseTime = mc.cfg.ParseTime
mc.strict = mc.cfg.Strict

// Connect to Server
if dial, ok := dials[mc.cfg.net]; ok {
mc.netConn, err = dial(mc.cfg.addr)
if dial, ok := dials[mc.cfg.Net]; ok {
mc.netConn, err = dial(mc.cfg.Addr)
} else {
nd := net.Dialer{Timeout: mc.cfg.timeout}
mc.netConn, err = nd.Dial(mc.cfg.net, mc.cfg.addr)
nd := net.Dialer{Timeout: mc.cfg.Timeout}
mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr)
}
if err != nil {
return nil, err
Expand Down Expand Up @@ -136,15 +138,15 @@ func handleAuthResult(mc *mysqlConn, cipher []byte) error {
}

// Retry auth if configured to do so.
if mc.cfg.allowOldPasswords && err == ErrOldPassword {
if mc.cfg.AllowOldPasswords && err == ErrOldPassword {
// Retry with old authentication method. Note: there are edge cases
// where this should work but doesn't; this is currently "wontfix":
// https://github.com/go-sql-driver/mysql/issues/184
if err = mc.writeOldAuthPacket(cipher); err != nil {
return err
}
err = mc.readResultOK()
} else if mc.cfg.allowCleartextPasswords && err == ErrCleartextPassword {
} else if mc.cfg.AllowCleartextPasswords && err == ErrCleartextPassword {
// Retry with clear text password for
// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
Expand Down
2 changes: 1 addition & 1 deletion driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {

dsn2 := dsn + "&interpolateParams=true"
var db2 *sql.DB
if _, err := parseDSN(dsn2); err != errInvalidDSNUnsafeCollation {
if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation {
db2, err = sql.Open("mysql", dsn2)
if err != nil {
t.Fatalf("Error connecting: %s", err.Error())
Expand Down
Loading