From 82e54bda545e5a157b0c1a18fec2fbb28faaf511 Mon Sep 17 00:00:00 2001 From: Ren Hao Date: Thu, 1 Aug 2019 09:12:20 +0800 Subject: [PATCH 1/2] Support decimal for sql execution. --- decimal.go | 9 +++++++++ packets.go | 15 +++++++++++++++ statement.go | 3 +++ 3 files changed, 27 insertions(+) create mode 100644 decimal.go diff --git a/decimal.go b/decimal.go new file mode 100644 index 000000000..b4015fa3c --- /dev/null +++ b/decimal.go @@ -0,0 +1,9 @@ +package mysql + +import "database/sql/driver" + +type Decimal string + +func (d Decimal) Value() (driver.Value, error) { + return d, nil +} diff --git a/packets.go b/packets.go index 30b3352c2..c3b21d63f 100644 --- a/packets.go +++ b/packets.go @@ -1102,6 +1102,21 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } } + case Decimal: + paramTypes[i+i] = byte(fieldTypeNewDecimal) + paramTypes[i+i+1] = 0x00 + + if len(v) < longDataSize { + paramValues = appendLengthEncodedInteger(paramValues, + uint64(len(v)), + ) + paramValues = append(paramValues, v...) + } else { + if err := stmt.writeCommandLongData(i, []byte(v)); err != nil { + return err + } + } + case time.Time: paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 diff --git a/statement.go b/statement.go index f7e370939..3b4845805 100644 --- a/statement.go +++ b/statement.go @@ -146,6 +146,9 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { if err != nil { return nil, err } + if _, ok = sv.(Decimal); ok { + return sv, nil + } if !driver.IsValue(sv) { return nil, fmt.Errorf("non-Value type %T returned from Value", sv) } From e4cdc26b83521d9f768ab7f2b31d696441301b9c Mon Sep 17 00:00:00 2001 From: Ren Hao Date: Thu, 15 Aug 2019 12:14:13 +0800 Subject: [PATCH 2/2] Update decimal type support. --- decimal.go | 8 ++------ packets.go | 8 ++++---- statement.go | 7 ++++--- 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/decimal.go b/decimal.go index b4015fa3c..d23023309 100644 --- a/decimal.go +++ b/decimal.go @@ -1,9 +1,5 @@ package mysql -import "database/sql/driver" - -type Decimal string - -func (d Decimal) Value() (driver.Value, error) { - return d, nil +type Decimal interface { + DecimalString() string } diff --git a/packets.go b/packets.go index c3b21d63f..8b3197550 100644 --- a/packets.go +++ b/packets.go @@ -1106,13 +1106,13 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { paramTypes[i+i] = byte(fieldTypeNewDecimal) paramTypes[i+i+1] = 0x00 - if len(v) < longDataSize { + if len(v.DecimalString()) < longDataSize { paramValues = appendLengthEncodedInteger(paramValues, - uint64(len(v)), + uint64(len(v.DecimalString())), ) - paramValues = append(paramValues, v...) + paramValues = append(paramValues, v.DecimalString()...) } else { - if err := stmt.writeCommandLongData(i, []byte(v)); err != nil { + if err := stmt.writeCommandLongData(i, []byte(v.DecimalString())); err != nil { return err } } diff --git a/statement.go b/statement.go index 3b4845805..039cfafc9 100644 --- a/statement.go +++ b/statement.go @@ -141,14 +141,15 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { return v, nil } + if _, ok := v.(Decimal); ok { + return v, nil + } + if vr, ok := v.(driver.Valuer); ok { sv, err := callValuerValue(vr) if err != nil { return nil, err } - if _, ok = sv.(Decimal); ok { - return sv, nil - } if !driver.IsValue(sv) { return nil, fmt.Errorf("non-Value type %T returned from Value", sv) }