Skip to content

Commit e167565

Browse files
committed
Use the default value converter rather than copy it. (#739)
This change simplifies ConvertValue to only handle the case of uint64 with the high bit set. Other conversions return ErrSkip causing the database/sql default converter to run. As a consequence #739 is fixed
1 parent bc14601 commit e167565

File tree

3 files changed

+97
-109
lines changed

3 files changed

+97
-109
lines changed

driver_test.go

+5
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,7 @@ func TestValuerWithValidation(t *testing.T) {
547547
var out string
548548
var rows *sql.Rows
549549

550+
dbt.mustExec("DROP TABLE IF EXISTS testValuer")
550551
dbt.mustExec("CREATE TABLE testValuer (value VARCHAR(255)) CHARACTER SET utf8")
551552
dbt.mustExec("INSERT INTO testValuer VALUES (?)", in)
552553

@@ -570,6 +571,10 @@ func TestValuerWithValidation(t *testing.T) {
570571
dbt.Errorf("Failed to check nil")
571572
}
572573

574+
if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", (*testValuerWithValidation)(nil)); err != nil {
575+
dbt.Errorf("Failed to check typed nil")
576+
}
577+
573578
if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", map[string]bool{}); err == nil {
574579
dbt.Errorf("Failed to check not valuer")
575580
}

statement.go

+9-24
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ package mysql
1010

1111
import (
1212
"database/sql/driver"
13-
"fmt"
1413
"io"
1514
"reflect"
1615
"strconv"
@@ -132,47 +131,33 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
132131

133132
type converter struct{}
134133

134+
// ConvertValue differs from defaultConverter.ConverValue for uint64 with the high bit set only
135+
// all other conversion requests return driver.ErrSkip to defer to the default converter
135136
func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
136137
if driver.IsValue(v) {
137138
return v, nil
138139
}
139140

140-
if v != nil {
141-
if valuer, ok := v.(driver.Valuer); ok {
142-
return valuer.Value()
143-
}
141+
// even when uint64 is the underlying type, a custom Valuer should take precedence
142+
if _, ok := v.(driver.Valuer); ok {
143+
return v, driver.ErrSkip
144144
}
145145

146146
rv := reflect.ValueOf(v)
147147
switch rv.Kind() {
148148
case reflect.Ptr:
149-
// indirect pointers
150149
if rv.IsNil() {
151150
return nil, nil
152151
}
152+
// recursively handle *uint64, **uint64 etc
153153
return c.ConvertValue(rv.Elem().Interface())
154-
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
155-
return rv.Int(), nil
156-
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
157-
return int64(rv.Uint()), nil
158154
case reflect.Uint64:
159155
u64 := rv.Uint()
160156
if u64 >= 1<<63 {
157+
// The defaultConverter errors in this case - we convert to a string
161158
return strconv.FormatUint(u64, 10), nil
162159
}
163-
return int64(u64), nil
164-
case reflect.Float32, reflect.Float64:
165-
return rv.Float(), nil
166-
case reflect.Bool:
167-
return rv.Bool(), nil
168-
case reflect.Slice:
169-
ek := rv.Type().Elem().Kind()
170-
if ek == reflect.Uint8 {
171-
return rv.Bytes(), nil
172-
}
173-
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek)
174-
case reflect.String:
175-
return rv.String(), nil
176160
}
177-
return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
161+
162+
return v, driver.ErrSkip
178163
}

statement_test.go

+83-85
Original file line numberDiff line numberDiff line change
@@ -9,118 +9,116 @@
99
package mysql
1010

1111
import (
12-
"bytes"
12+
"database/sql/driver"
1313
"testing"
14+
"time"
1415
)
1516

16-
func TestConvertDerivedString(t *testing.T) {
17-
type derived string
17+
func TestValueThatIsValue(t *testing.T) {
18+
now := time.Now()
19+
inputs := []interface{}{nil, float64(1.0), int64(17), "ABC", now}
1820

19-
output, err := converter{}.ConvertValue(derived("value"))
20-
if err != nil {
21-
t.Fatal("Derived string type not convertible", err)
22-
}
23-
24-
if output != "value" {
25-
t.Fatalf("Derived string type not converted, got %#v %T", output, output)
21+
for _, in := range inputs {
22+
out, err := converter{}.ConvertValue(in)
23+
if err != nil {
24+
t.Fatalf("Value %#v %T not needing conversion caused error: %s", in, in, err)
25+
}
26+
if out != in {
27+
t.Fatalf("Value %#v %T altered in conversion got %#v %T", in, in, out, out)
28+
}
2629
}
2730
}
2831

29-
func TestConvertDerivedByteSlice(t *testing.T) {
30-
type derived []uint8
32+
func TestValueThatIsPtrToValue(t *testing.T) {
33+
w := "ABC"
34+
x := &w
35+
y := &x
36+
inputs := []interface{}{x, y}
3137

32-
output, err := converter{}.ConvertValue(derived("value"))
33-
if err != nil {
34-
t.Fatal("Byte slice not convertible", err)
35-
}
36-
37-
if bytes.Compare(output.([]byte), []byte("value")) != 0 {
38-
t.Fatalf("Byte slice not converted, got %#v %T", output, output)
38+
for _, in := range inputs {
39+
out, err := converter{}.ConvertValue(in)
40+
if err != nil {
41+
t.Fatalf("Pointer %#v %T to value not needing conversion caused error: %s", in, in, err)
42+
}
43+
if out != w {
44+
t.Fatalf("Value %#v %T not resolved to string in conversion (got %#v %T)", in, in, out, out)
45+
}
3946
}
4047
}
4148

42-
func TestConvertDerivedUnsupportedSlice(t *testing.T) {
43-
type derived []int
49+
func TestValueThatIsTypedPtrToNil(t *testing.T) {
50+
var w *string
51+
x := &w
52+
y := &x
53+
inputs := []interface{}{x, y}
4454

45-
_, err := converter{}.ConvertValue(derived{1})
46-
if err == nil || err.Error() != "unsupported type mysql.derived, a slice of int" {
47-
t.Fatal("Unexpected error", err)
55+
for _, in := range inputs {
56+
out, err := converter{}.ConvertValue(in)
57+
if err != nil {
58+
t.Fatalf("Pointer %#v %T to nil value caused error: %s", in, in, err)
59+
}
60+
if out != nil {
61+
t.Fatalf("Pointer to nil did not Value as nil")
62+
}
4863
}
4964
}
5065

51-
func TestConvertDerivedBool(t *testing.T) {
52-
type derived bool
66+
type implementsValuer uint64
5367

54-
output, err := converter{}.ConvertValue(derived(true))
55-
if err != nil {
56-
t.Fatal("Derived bool type not convertible", err)
57-
}
58-
59-
if output != true {
60-
t.Fatalf("Derived bool type not converted, got %#v %T", output, output)
61-
}
68+
func (me implementsValuer) Value() (driver.Value, error) {
69+
return string(me), nil
6270
}
63-
64-
func TestConvertPointer(t *testing.T) {
65-
str := "value"
66-
67-
output, err := converter{}.ConvertValue(&str)
68-
if err != nil {
69-
t.Fatal("Pointer type not convertible", err)
70-
}
71-
72-
if output != "value" {
73-
t.Fatalf("Pointer type not converted, got %#v %T", output, output)
71+
func TestTypesThatImplementValuerAreSkipped(t *testing.T) {
72+
// Have to test on a uint64 with high bit set - as we skip everything else anyhow
73+
x := implementsValuer(^uint64(0))
74+
y := &x
75+
z := &y
76+
var a *implementsValuer
77+
b := &a
78+
c := &b
79+
inputs := []interface{}{x, y, z, a, b, c}
80+
81+
for _, in := range inputs {
82+
_, err := converter{}.ConvertValue(in)
83+
if err != driver.ErrSkip {
84+
t.Fatalf("Conversion of Valuer implementing type %T not skipped", in)
85+
}
7486
}
7587
}
7688

77-
func TestConvertSignedIntegers(t *testing.T) {
78-
values := []interface{}{
79-
int8(-42),
80-
int16(-42),
81-
int32(-42),
82-
int64(-42),
83-
int(-42),
84-
}
85-
86-
for _, value := range values {
87-
output, err := converter{}.ConvertValue(value)
88-
if err != nil {
89-
t.Fatalf("%T type not convertible %s", value, err)
90-
}
91-
92-
if output != int64(-42) {
93-
t.Fatalf("%T type not converted, got %#v %T", value, output, output)
89+
func TestTypesThatAreNotValuesAreSkipped(t *testing.T) {
90+
type derived1 string // convertable
91+
type derived2 []uint8 // convertable
92+
type derived3 []int // not convertable
93+
type derived4 uint64 // without the high bit set
94+
inputs := []interface{}{derived1("ABC"), derived2([]uint8{'A', 'B'}), derived3([]int{17, 32}), derived3(nil), derived4(26)}
95+
96+
for _, in := range inputs {
97+
_, err := converter{}.ConvertValue(in)
98+
if err != driver.ErrSkip {
99+
t.Fatalf("Conversion of non-value value %#v %T not skipped", in, in)
94100
}
95101
}
96102
}
97103

98-
func TestConvertUnsignedIntegers(t *testing.T) {
99-
values := []interface{}{
100-
uint8(42),
101-
uint16(42),
102-
uint32(42),
103-
uint64(42),
104-
uint(42),
105-
}
104+
func TestConvertLargeUnsignedIntegers(t *testing.T) {
105+
type derived uint64
106+
type derived2 *uint64
107+
v := ^uint64(0)
108+
w := &v
109+
x := derived(v)
110+
y := &x
111+
z := derived2(w)
106112

107-
for _, value := range values {
108-
output, err := converter{}.ConvertValue(value)
113+
inputs := []interface{}{v, w, x, y, z}
114+
115+
for _, in := range inputs {
116+
out, err := converter{}.ConvertValue(in)
109117
if err != nil {
110-
t.Fatalf("%T type not convertible %s", value, err)
118+
t.Fatalf("uint64 high-bit not convertible for type %T", in)
111119
}
112-
113-
if output != int64(42) {
114-
t.Fatalf("%T type not converted, got %#v %T", value, output, output)
120+
if out != "18446744073709551615" {
121+
t.Fatalf("uint64 high-bit not converted, got %#v %T", out, out)
115122
}
116123
}
117-
118-
output, err := converter{}.ConvertValue(^uint64(0))
119-
if err != nil {
120-
t.Fatal("uint64 high-bit not convertible", err)
121-
}
122-
123-
if output != "18446744073709551615" {
124-
t.Fatalf("uint64 high-bit not converted, got %#v %T", output, output)
125-
}
126124
}

0 commit comments

Comments
 (0)