diff --git a/decimal.go b/decimal.go new file mode 100644 index 000000000..d23023309 --- /dev/null +++ b/decimal.go @@ -0,0 +1,5 @@ +package mysql + +type Decimal interface { + DecimalString() string +} diff --git a/packets.go b/packets.go index 30b3352c2..8b3197550 100644 --- a/packets.go +++ b/packets.go @@ -1102,6 +1102,21 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } } + case Decimal: + paramTypes[i+i] = byte(fieldTypeNewDecimal) + paramTypes[i+i+1] = 0x00 + + if len(v.DecimalString()) < longDataSize { + paramValues = appendLengthEncodedInteger(paramValues, + uint64(len(v.DecimalString())), + ) + paramValues = append(paramValues, v.DecimalString()...) + } else { + if err := stmt.writeCommandLongData(i, []byte(v.DecimalString())); err != nil { + return err + } + } + case time.Time: paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 diff --git a/statement.go b/statement.go index f7e370939..039cfafc9 100644 --- a/statement.go +++ b/statement.go @@ -141,6 +141,10 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { return v, nil } + if _, ok := v.(Decimal); ok { + return v, nil + } + if vr, ok := v.(driver.Valuer); ok { sv, err := callValuerValue(vr) if err != nil {