diff --git a/AUTHORS b/AUTHORS index 301ce573f..14e8398fd 100644 --- a/AUTHORS +++ b/AUTHORS @@ -14,6 +14,7 @@ Aaron Hopkins Achille Roussel Alexey Palazhchenko +Andrew Reid Arne Hormann Asta Xie Bulat Gaifullin diff --git a/driver_go18_test.go b/driver_go18_test.go index e461455dd..afd5694ec 100644 --- a/driver_go18_test.go +++ b/driver_go18_test.go @@ -796,3 +796,11 @@ func TestRowsColumnTypes(t *testing.T) { }) } } + +func TestValuerWithValueReceiverGivenNilValue(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (value VARCHAR(255))") + dbt.db.Exec("INSERT INTO test VALUES (?)", (*testValuer)(nil)) + // This test will panic on the INSERT if ConvertValue() does not check for typed nil before calling Value() + }) +} diff --git a/statement.go b/statement.go index 98e57bcd8..ce7fe4cd0 100644 --- a/statement.go +++ b/statement.go @@ -132,15 +132,25 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { type converter struct{} +// ConvertValue mirrors the reference/default converter in database/sql/driver +// with _one_ exception. We support uint64 with their high bit and the default +// implementation does not. This function should be kept in sync with +// database/sql/driver defaultConverter.ConvertValue() except for that +// deliberate difference. func (c converter) ConvertValue(v interface{}) (driver.Value, error) { if driver.IsValue(v) { return v, nil } - if v != nil { - if valuer, ok := v.(driver.Valuer); ok { - return valuer.Value() + if vr, ok := v.(driver.Valuer); ok { + sv, err := callValuerValue(vr) + if err != nil { + return nil, err + } + if !driver.IsValue(sv) { + return nil, fmt.Errorf("non-Value type %T returned from Value", sv) } + return sv, nil } rv := reflect.ValueOf(v) @@ -149,8 +159,9 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { // indirect pointers if rv.IsNil() { return nil, nil + } else { + return c.ConvertValue(rv.Elem().Interface()) } - return c.ConvertValue(rv.Elem().Interface()) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return rv.Int(), nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: @@ -176,3 +187,25 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { } return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) } + +var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + +// callValuerValue returns vr.Value(), with one exception: +// If vr.Value is an auto-generated method on a pointer type and the +// pointer is nil, it would panic at runtime in the panicwrap +// method. Treat it like nil instead. +// +// This is so people can implement driver.Value on value types and +// still use nil pointers to those types to mean nil/NULL, just like +// string/*string. +// +// This is an exact copy of the same-named unexported function from the +// database/sql package. +func callValuerValue(vr driver.Valuer) (v driver.Value, err error) { + if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr && + rv.IsNil() && + rv.Type().Elem().Implements(valuerReflectType) { + return nil, nil + } + return vr.Value() +}