Skip to content

Commit 6964272

Browse files
authored
Make TimeTruncate functional option (#1552)
1 parent 097fe6e commit 6964272

File tree

5 files changed

+45
-13
lines changed

5 files changed

+45
-13
lines changed

Diff for: connection.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
251251
buf = append(buf, "'0000-00-00'"...)
252252
} else {
253253
buf = append(buf, '\'')
254-
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.TimeTruncate)
254+
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
255255
if err != nil {
256256
return "", err
257257
}

Diff for: dsn.go

+40-7
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ var (
3434
// If a new Config is created instead of being parsed from a DSN string,
3535
// the NewConfig function should be used, which sets default values.
3636
type Config struct {
37+
// non boolean fields
38+
3739
User string // Username
3840
Passwd string // Password (requires User)
3941
Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp")
@@ -45,15 +47,15 @@ type Config struct {
4547
Loc *time.Location // Location for time.Time values
4648
MaxAllowedPacket int // Max packet size allowed
4749
ServerPubKey string // Server public key name
48-
pubKey *rsa.PublicKey // Server public key
4950
TLSConfig string // TLS configuration name
5051
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
51-
TimeTruncate time.Duration // Truncate time.Time values to the specified duration
5252
Timeout time.Duration // Dial timeout
5353
ReadTimeout time.Duration // I/O read timeout
5454
WriteTimeout time.Duration // I/O write timeout
5555
Logger Logger // Logger
5656

57+
// boolean fields
58+
5759
AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
5860
AllowCleartextPasswords bool // Allows the cleartext client side plugin
5961
AllowFallbackToPlaintext bool // Allows fallback to unencrypted connection if server does not support TLS
@@ -66,17 +68,48 @@ type Config struct {
6668
MultiStatements bool // Allow multiple statements in one query
6769
ParseTime bool // Parse time values to time.Time
6870
RejectReadOnly bool // Reject read-only connections
71+
72+
// unexported fields. new options should be come here
73+
74+
pubKey *rsa.PublicKey // Server public key
75+
timeTruncate time.Duration // Truncate time.Time values to the specified duration
6976
}
7077

78+
// Functional Options Pattern
79+
// https://dave.cheney.net/2014/10/17/functional-options-for-friendly-apis
80+
type Option func(*Config) error
81+
7182
// NewConfig creates a new Config and sets default values.
7283
func NewConfig() *Config {
73-
return &Config{
84+
cfg := &Config{
7485
Loc: time.UTC,
7586
MaxAllowedPacket: defaultMaxAllowedPacket,
7687
Logger: defaultLogger,
7788
AllowNativePasswords: true,
7889
CheckConnLiveness: true,
7990
}
91+
92+
return cfg
93+
}
94+
95+
// Apply applies the given options to the Config object.
96+
func (c *Config) Apply(opts ...Option) error {
97+
for _, opt := range opts {
98+
err := opt(c)
99+
if err != nil {
100+
return err
101+
}
102+
}
103+
return nil
104+
}
105+
106+
// TimeTruncate sets the time duration to truncate time.Time values in
107+
// query parameters.
108+
func TimeTruncate(d time.Duration) Option {
109+
return func(cfg *Config) error {
110+
cfg.timeTruncate = d
111+
return nil
112+
}
80113
}
81114

82115
func (cfg *Config) Clone() *Config {
@@ -263,8 +296,8 @@ func (cfg *Config) FormatDSN() string {
263296
writeDSNParam(&buf, &hasParam, "parseTime", "true")
264297
}
265298

266-
if cfg.TimeTruncate > 0 {
267-
writeDSNParam(&buf, &hasParam, "timeTruncate", cfg.TimeTruncate.String())
299+
if cfg.timeTruncate > 0 {
300+
writeDSNParam(&buf, &hasParam, "timeTruncate", cfg.timeTruncate.String())
268301
}
269302

270303
if cfg.ReadTimeout > 0 {
@@ -509,9 +542,9 @@ func parseDSNParams(cfg *Config, params string) (err error) {
509542

510543
// time.Time truncation
511544
case "timeTruncate":
512-
cfg.TimeTruncate, err = time.ParseDuration(value)
545+
cfg.timeTruncate, err = time.ParseDuration(value)
513546
if err != nil {
514-
return
547+
return fmt.Errorf("invalid timeTruncate value: %v, error: %w", value, err)
515548
}
516549

517550
// I/O read Timeout

Diff for: dsn_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ var testDSNs = []struct {
7676
&Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true},
7777
}, {
7878
"user:password@/dbname?loc=UTC&timeout=30s&parseTime=true&timeTruncate=1h",
79-
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Loc: time.UTC, Timeout: 30 * time.Second, ParseTime: true, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, TimeTruncate: time.Hour},
79+
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Loc: time.UTC, Timeout: 30 * time.Second, ParseTime: true, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, timeTruncate: time.Hour},
8080
},
8181
}
8282

Diff for: packets.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1172,7 +1172,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
11721172
if v.IsZero() {
11731173
b = append(b, "0000-00-00"...)
11741174
} else {
1175-
b, err = appendDateTime(b, v.In(mc.cfg.Loc), mc.cfg.TimeTruncate)
1175+
b, err = appendDateTime(b, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
11761176
if err != nil {
11771177
return err
11781178
}

Diff for: result.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@ import "database/sql/driver"
1515
// This is accessible by executing statements using sql.Conn.Raw() and
1616
// downcasting the returned result:
1717
//
18-
// res, err := rawConn.Exec(...)
19-
// res.(mysql.Result).AllRowsAffected()
20-
//
18+
// res, err := rawConn.Exec(...)
19+
// res.(mysql.Result).AllRowsAffected()
2120
type Result interface {
2221
driver.Result
2322
// AllRowsAffected returns a slice containing the affected rows for each

0 commit comments

Comments
 (0)