Skip to content

Commit 31781fc

Browse files
committed
fix: use native ScanType from driver and enhance RowBuffer to understand more types
1 parent 3254d43 commit 31781fc

File tree

4 files changed

+488
-76
lines changed

4 files changed

+488
-76
lines changed

dump.go

+89-40
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@ import (
1616
/*
1717
Data struct to configure dump behavior
1818
19-
Out: Stream to wite to
20-
Connection: Database connection to dump
21-
IgnoreTables: Mark sensitive tables to ignore
22-
MaxAllowedPacket: Sets the largest packet size to use in backups
23-
LockTables: Lock all tables for the duration of the dump
19+
Out: Stream to wite to
20+
Connection: Database connection to dump
21+
IgnoreTables: Mark sensitive tables to ignore
22+
MaxAllowedPacket: Sets the largest packet size to use in backups
23+
LockTables: Lock all tables for the duration of the dump
2424
*/
2525
type Data struct {
2626
Out io.Writer
@@ -68,7 +68,7 @@ const headerTmpl = `-- Go SQL Dump {{ .DumpVersion }}
6868
/*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */;
6969
/*!40101 SET @OLD_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS */;
7070
/*!40101 SET @OLD_COLLATION_CONNECTION=@@COLLATION_CONNECTION */;
71-
SET NAMES utf8mb4 ;
71+
/*!50503 SET NAMES UTF8 */;
7272
/*!40103 SET @OLD_TIME_ZONE=@@TIME_ZONE */;
7373
/*!40103 SET TIME_ZONE='+00:00' */;
7474
/*!40014 SET @OLD_UNIQUE_CHECKS=@@UNIQUE_CHECKS, UNIQUE_CHECKS=0 */;
@@ -99,7 +99,7 @@ const tableTmpl = `
9999
100100
DROP TABLE IF EXISTS {{ .NameEsc }};
101101
/*!40101 SET @saved_cs_client = @@character_set_client */;
102-
SET character_set_client = utf8mb4 ;
102+
/*!50503 SET character_set_client = utf8mb4 */;
103103
{{ .CreateSQL }};
104104
/*!40101 SET character_set_client = @saved_cs_client */;
105105
@@ -296,7 +296,7 @@ func (table *table) CreateSQL() (string, error) {
296296
}
297297

298298
if tableReturn.String != table.Name {
299-
return "", errors.New("Returned table is not the same as requested table")
299+
return "", errors.New("returned table is not the same as requested table")
300300
}
301301

302302
return tableSQL.String, nil
@@ -383,38 +383,11 @@ func (table *table) Init() error {
383383

384384
table.values = make([]interface{}, len(tt))
385385
for i, tp := range tt {
386-
table.values[i] = reflect.New(reflectColumnType(tp)).Interface()
386+
table.values[i] = reflect.New(tp.ScanType()).Interface()
387387
}
388388
return nil
389389
}
390390

391-
func reflectColumnType(tp *sql.ColumnType) reflect.Type {
392-
// reflect for scanable
393-
switch tp.ScanType().Kind() {
394-
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
395-
return reflect.TypeOf(sql.NullInt64{})
396-
case reflect.Float32, reflect.Float64:
397-
return reflect.TypeOf(sql.NullFloat64{})
398-
case reflect.String:
399-
return reflect.TypeOf(sql.NullString{})
400-
}
401-
402-
// determine by name
403-
switch tp.DatabaseTypeName() {
404-
case "BLOB", "BINARY":
405-
return reflect.TypeOf(sql.RawBytes{})
406-
case "VARCHAR", "TEXT", "DECIMAL", "JSON":
407-
return reflect.TypeOf(sql.NullString{})
408-
case "BIGINT", "TINYINT", "INT":
409-
return reflect.TypeOf(sql.NullInt64{})
410-
case "DOUBLE":
411-
return reflect.TypeOf(sql.NullFloat64{})
412-
}
413-
414-
// unknown datatype
415-
return tp.ScanType()
416-
}
417-
418391
func (table *table) Next() bool {
419392
if table.rows == nil {
420393
if err := table.Init(); err != nil {
@@ -443,6 +416,30 @@ func (table *table) RowValues() string {
443416
return table.RowBuffer().String()
444417
}
445418

419+
func writeString(b *bytes.Buffer, s string) {
420+
fmt.Fprintf(b, "'%s'", sanitize(s))
421+
}
422+
423+
func writeBool(b *bytes.Buffer, s bool) {
424+
if s {
425+
fmt.Fprintf(b, "1")
426+
} else {
427+
fmt.Fprintf(b, "0")
428+
}
429+
}
430+
431+
func writeBinary(b *bytes.Buffer, s []byte) {
432+
if len(s) == 0 {
433+
b.WriteString(nullType)
434+
} else {
435+
fmt.Fprintf(b, "_binary '%s'", sanitize(string(s)))
436+
}
437+
}
438+
439+
func writeTime(b *bytes.Buffer, s time.Time) {
440+
fmt.Fprintf(b, "'%s'", sanitize(s.UTC().Format(time.DateTime)))
441+
}
442+
446443
func (table *table) RowBuffer() *bytes.Buffer {
447444
var b bytes.Buffer
448445
b.WriteString("(")
@@ -454,9 +451,51 @@ func (table *table) RowBuffer() *bytes.Buffer {
454451
switch s := value.(type) {
455452
case nil:
456453
b.WriteString(nullType)
454+
case *string:
455+
writeString(&b, *s)
457456
case *sql.NullString:
458457
if s.Valid {
459-
fmt.Fprintf(&b, "'%s'", sanitize(s.String))
458+
writeString(&b, s.String)
459+
} else {
460+
b.WriteString(nullType)
461+
}
462+
case *bool:
463+
writeBool(&b, *s)
464+
case *sql.NullBool:
465+
if s.Valid {
466+
writeBool(&b, s.Bool)
467+
} else {
468+
b.WriteString(nullType)
469+
}
470+
case *uint:
471+
fmt.Fprintf(&b, "%d", *s)
472+
case *uint8:
473+
fmt.Fprintf(&b, "%d", *s)
474+
case *uint16:
475+
fmt.Fprintf(&b, "%d", *s)
476+
case *uint32:
477+
fmt.Fprintf(&b, "%d", *s)
478+
case *uint64:
479+
fmt.Fprintf(&b, "%d", *s)
480+
case *int:
481+
fmt.Fprintf(&b, "%d", *s)
482+
case *int8:
483+
fmt.Fprintf(&b, "%d", *s)
484+
case *int16:
485+
fmt.Fprintf(&b, "%d", *s)
486+
case *int32:
487+
fmt.Fprintf(&b, "%d", *s)
488+
case *int64:
489+
fmt.Fprintf(&b, "%d", *s)
490+
case *sql.NullInt16:
491+
if s.Valid {
492+
fmt.Fprintf(&b, "%d", s.Int16)
493+
} else {
494+
b.WriteString(nullType)
495+
}
496+
case *sql.NullInt32:
497+
if s.Valid {
498+
fmt.Fprintf(&b, "%d", s.Int32)
460499
} else {
461500
b.WriteString(nullType)
462501
}
@@ -466,17 +505,27 @@ func (table *table) RowBuffer() *bytes.Buffer {
466505
} else {
467506
b.WriteString(nullType)
468507
}
508+
case *float32:
509+
fmt.Fprintf(&b, "%f", *s)
510+
case *float64:
511+
fmt.Fprintf(&b, "%f", *s)
469512
case *sql.NullFloat64:
470513
if s.Valid {
471514
fmt.Fprintf(&b, "%f", s.Float64)
472515
} else {
473516
b.WriteString(nullType)
474517
}
518+
case *[]byte:
519+
writeBinary(&b, *s)
475520
case *sql.RawBytes:
476-
if len(*s) == 0 {
477-
b.WriteString(nullType)
521+
writeBinary(&b, *s)
522+
case *time.Time:
523+
writeTime(&b, *s)
524+
case *sql.NullTime:
525+
if s.Valid {
526+
writeTime(&b, s.Time)
478527
} else {
479-
fmt.Fprintf(&b, "_binary '%s'", sanitize(string(*s)))
528+
b.WriteString(nullType)
480529
}
481530
default:
482531
fmt.Fprintf(&b, "'%s'", value)

dump_test.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ func TestCreateTableAllValuesWithNil(t *testing.T) {
228228
AddRow("email", "").
229229
AddRow("name", "")
230230

231-
rows := sqlmock.NewRowsWithColumnDefinition(c("id", 0), c("email", ""), c("name", "")).
231+
rows := sqlmock.NewRowsWithColumnDefinition(c("id", 0), c("email", sql.NullString{}), c("name", "")).
232232
AddRow(1, nil, "Test Name 1").
233233
AddRow(2, "[email protected]", "Test Name 2").
234234
AddRow(3, "", "Test Name 3")
@@ -266,7 +266,7 @@ func TestCreateTableOk(t *testing.T) {
266266
AddRow("email", "").
267267
AddRow("name", "")
268268

269-
createTableValueRows := sqlmock.NewRowsWithColumnDefinition(c("id", 0), c("email", ""), c("name", "")).
269+
createTableValueRows := sqlmock.NewRowsWithColumnDefinition(c("id", 0), c("email", sql.NullString{}), c("name", "")).
270270
AddRow(1, nil, "Test Name 1").
271271
AddRow(2, "[email protected]", "Test Name 2")
272272

@@ -294,7 +294,7 @@ func TestCreateTableOk(t *testing.T) {
294294
295295
DROP TABLE IF EXISTS ~Test_Table~;
296296
/*!40101 SET @saved_cs_client = @@character_set_client */;
297-
SET character_set_client = utf8mb4 ;
297+
/*!50503 SET character_set_client = utf8mb4 */;
298298
CREATE TABLE 'Test_Table' (~id~ int(11) NOT NULL AUTO_INCREMENT,~s~ char(60) DEFAULT NULL, PRIMARY KEY (~id~))ENGINE=InnoDB DEFAULT CHARSET=latin1;
299299
/*!40101 SET character_set_client = @saved_cs_client */;
300300
@@ -325,7 +325,7 @@ func TestCreateTableOkSmallPackets(t *testing.T) {
325325
AddRow("email", "").
326326
AddRow("name", "")
327327

328-
createTableValueRows := sqlmock.NewRowsWithColumnDefinition(c("id", 0), c("email", ""), c("name", "")).
328+
createTableValueRows := sqlmock.NewRowsWithColumnDefinition(c("id", 0), c("email", sql.NullString{}), c("name", "")).
329329
AddRow(1, nil, "Test Name 1").
330330
AddRow(2, "[email protected]", "Test Name 2")
331331

@@ -353,7 +353,7 @@ func TestCreateTableOkSmallPackets(t *testing.T) {
353353
354354
DROP TABLE IF EXISTS ~Test_Table~;
355355
/*!40101 SET @saved_cs_client = @@character_set_client */;
356-
SET character_set_client = utf8mb4 ;
356+
/*!50503 SET character_set_client = utf8mb4 */;
357357
CREATE TABLE 'Test_Table' (~id~ int(11) NOT NULL AUTO_INCREMENT,~s~ char(60) DEFAULT NULL, PRIMARY KEY (~id~))ENGINE=InnoDB DEFAULT CHARSET=latin1;
358358
/*!40101 SET character_set_client = @saved_cs_client */;
359359

mysqldump.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ Register a new dumper.
1818
*/
1919
func Register(db *sql.DB, dir, format string) (*Data, error) {
2020
if !isDir(dir) {
21-
return nil, errors.New("Invalid directory")
21+
return nil, errors.New("invalid directory")
2222
}
2323

2424
name := time.Now().Format(format)

0 commit comments

Comments
 (0)