Skip to content

Commit 59d433e

Browse files
committed
Fixed rebinding Bug
+ more clean up
1 parent e4b1048 commit 59d433e

File tree

2 files changed

+77
-99
lines changed

2 files changed

+77
-99
lines changed

packets.go

Lines changed: 75 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ func (mc *mysqlConn) readPacket() (data []byte, e error) {
4949
data = make([]byte, pktLen)
5050
n, e := mc.netConn.Read(data)
5151
if e != nil || n != int(pktLen) {
52-
fmt.Println(e)
5352
e = driver.ErrBadConn
5453
return
5554
}
@@ -77,7 +76,6 @@ func (mc *mysqlConn) writePacket(data []byte) (e error) {
7776
// Write packet
7877
n, e := mc.netConn.Write(pktData)
7978
if e != nil || n != len(pktData) {
80-
fmt.Println("BadConn:", e)
8179
e = driver.ErrBadConn
8280
return
8381
}
@@ -93,7 +91,6 @@ func (mc *mysqlConn) readNumber(n uint8) (num uint64, e error) {
9391

9492
nr, err := io.ReadFull(mc.netConn, buf)
9593
if err != nil || nr != int(n) {
96-
fmt.Println(e)
9794
e = driver.ErrBadConn
9895
return
9996
}
@@ -141,7 +138,10 @@ func (mc *mysqlConn) readInitPacket() (e error) {
141138
// Protocol version [8 bit uint]
142139
mc.server.protocol = data[pos]
143140
if mc.server.protocol < MIN_PROTOCOL_VERSION {
144-
e = errors.New(fmt.Sprintf("Unsupported MySQL Protocol Version %d. Protocol Version %d or higher is required", mc.server.protocol, MIN_PROTOCOL_VERSION))
141+
e = fmt.Errorf(
142+
"Unsupported MySQL Protocol Version %d. Protocol Version %d or higher is required",
143+
mc.server.protocol,
144+
MIN_PROTOCOL_VERSION)
145145
}
146146
pos++
147147

@@ -275,24 +275,24 @@ func (mc *mysqlConn) writeCommandPacket(command commandType, args ...interface{}
275275
// Commands without args
276276
case COM_QUIT, COM_PING:
277277
if len(args) > 0 {
278-
return errors.New(fmt.Sprintf("Too much arguments (Got: %d Has:0)", len(args)))
278+
return fmt.Errorf("Too much arguments (Got: %d Has:0)", len(args))
279279
}
280280

281281
// Commands with 1 arg unterminated string
282282
case COM_QUERY, COM_STMT_PREPARE:
283283
if len(args) != 1 {
284-
return errors.New(fmt.Sprintf("Invalid arguments count (Got:%d Need: 1)", len(args)))
284+
return fmt.Errorf("Invalid arguments count (Got:%d Need: 1)", len(args))
285285
}
286286
data = append(data, []byte(args[0].(string))...)
287287

288288
// Commands with 1 arg 32 bit uint
289289
case COM_STMT_CLOSE:
290290
if len(args) != 1 {
291-
return errors.New(fmt.Sprintf("Invalid arguments count (Got:%d Need: 1)", len(args)))
291+
return fmt.Errorf("Invalid arguments count (Got:%d Need: 1)", len(args))
292292
}
293293
data = append(data, uint32ToBytes(args[0].(uint32))...)
294294
default:
295-
return errors.New(fmt.Sprintf("Unknown command: %d", command))
295+
return fmt.Errorf("Unknown command: %d", command)
296296
}
297297

298298
// Send CMD packet
@@ -448,7 +448,7 @@ func (mc *mysqlConn) readColumns(n int) (columns []*mysqlField, e error) {
448448
// EOF Packet
449449
if data[0] == 254 && len(data) == 5 {
450450
if len(columns) != n {
451-
e = errors.New(fmt.Sprintf("ColumnsCount mismatch n:%d len:%d", n, len(columns)))
451+
e = fmt.Errorf("ColumnsCount mismatch n:%d len:%d", n, len(columns))
452452
}
453453
return
454454
}
@@ -716,7 +716,6 @@ func (stmt mysqlStmt) buildExecutePacket(args *[]driver.Value) (e error) {
716716
// Check for NULL fields
717717
for i = 0; i < stmt.paramCount; i++ {
718718
if (*args)[i] == nil {
719-
fmt.Println("nil", i, (*args)[i])
720719
bitMask += 1 << uint(i)
721720
}
722721
}
@@ -728,100 +727,79 @@ func (stmt mysqlStmt) buildExecutePacket(args *[]driver.Value) (e error) {
728727
// append nullBitMap [(param_count+7)/8 bytes]
729728
data = append(data, nullBitMap...)
730729

731-
// Check for changed Params
732-
newParamsBound := true
733-
if stmt.args != nil {
734-
for i := 0; i < len(*args); i++ {
735-
if (*args)[i] != (*stmt.args)[i] {
736-
fmt.Println((*args)[i], "!=", (*stmt.args)[i])
737-
newParamsBound = false
738-
break
739-
}
730+
// newParameterBoundFlag 1 [1 byte]
731+
data = append(data, byte(1))
732+
733+
// append types and cache values
734+
paramValues := make([]byte, 0)
735+
var pv reflect.Value
736+
for i = 0; i < stmt.paramCount; i++ {
737+
switch (*args)[i].(type) {
738+
case nil:
739+
data = append(data, []byte{
740+
byte(FIELD_TYPE_NULL),
741+
0x0}...)
742+
continue
743+
744+
case []byte:
745+
data = append(data, []byte{
746+
byte(FIELD_TYPE_STRING),
747+
0x0}...)
748+
val := (*args)[i].([]byte)
749+
paramValues = append(paramValues, lengthCodedBinaryToBytes(uint64(len(val)))...)
750+
paramValues = append(paramValues, val...)
751+
continue
752+
753+
case time.Time:
754+
// Format to string for time+date Fields
755+
// Data is packed in case reflect.String below
756+
(*args)[i] = (*args)[i].(time.Time).Format(TIME_FORMAT)
740757
}
741-
}
742758

743-
// No (new) Parameters bound or rebound
744-
if !newParamsBound {
745-
//newParameterBoundFlag 0 [1 byte]
746-
data = append(data, byte(0))
747-
} else {
748-
// newParameterBoundFlag 1 [1 byte]
749-
data = append(data, byte(1))
750-
751-
// append types and cache values
752-
paramValues := make([]byte, 0)
753-
var pv reflect.Value
754-
for i = 0; i < stmt.paramCount; i++ {
755-
switch (*args)[i].(type) {
756-
case nil:
757-
data = append(data, []byte{
758-
byte(FIELD_TYPE_NULL),
759-
0x0}...)
760-
continue
761-
762-
case []byte:
763-
data = append(data, []byte{
764-
byte(FIELD_TYPE_STRING),
765-
0x0}...)
766-
val := (*args)[i].([]byte)
767-
paramValues = append(paramValues, lengthCodedBinaryToBytes(uint64(len(val)))...)
768-
paramValues = append(paramValues, val...)
769-
continue
770-
771-
case time.Time:
772-
// Format to string for time+date Fields
773-
// Data is packed in case reflect.String below
774-
(*args)[i] = (*args)[i].(time.Time).Format(TIME_FORMAT)
775-
}
759+
pv = reflect.ValueOf((*args)[i])
760+
switch pv.Kind() {
761+
case reflect.Int64:
762+
data = append(data, []byte{
763+
byte(FIELD_TYPE_LONGLONG),
764+
0x0}...)
765+
paramValues = append(paramValues, int64ToBytes(pv.Int())...)
766+
continue
776767

777-
pv = reflect.ValueOf((*args)[i])
778-
switch pv.Kind() {
779-
case reflect.Int64:
780-
data = append(data, []byte{
781-
byte(FIELD_TYPE_LONGLONG),
782-
0x0}...)
783-
paramValues = append(paramValues, int64ToBytes(pv.Int())...)
784-
continue
785-
786-
case reflect.Float64:
787-
data = append(data, []byte{
788-
byte(FIELD_TYPE_DOUBLE),
789-
0x0}...)
790-
paramValues = append(paramValues, float64ToBytes(pv.Float())...)
791-
continue
792-
793-
case reflect.Bool:
794-
data = append(data, []byte{
795-
byte(FIELD_TYPE_TINY),
796-
0x0}...)
797-
val := pv.Bool()
798-
if val {
799-
paramValues = append(paramValues, byte(1))
800-
} else {
801-
paramValues = append(paramValues, byte(0))
802-
}
803-
continue
804-
805-
case reflect.String:
806-
data = append(data, []byte{
807-
byte(FIELD_TYPE_STRING),
808-
0x0}...)
809-
val := pv.String()
810-
paramValues = append(paramValues, lengthCodedBinaryToBytes(uint64(len(val)))...)
811-
paramValues = append(paramValues, []byte(val)...)
812-
continue
813-
814-
default:
815-
return fmt.Errorf("Invalid Value: %s", pv.Kind().String())
768+
case reflect.Float64:
769+
data = append(data, []byte{
770+
byte(FIELD_TYPE_DOUBLE),
771+
0x0}...)
772+
paramValues = append(paramValues, float64ToBytes(pv.Float())...)
773+
continue
774+
775+
case reflect.Bool:
776+
data = append(data, []byte{
777+
byte(FIELD_TYPE_TINY),
778+
0x0}...)
779+
val := pv.Bool()
780+
if val {
781+
paramValues = append(paramValues, byte(1))
782+
} else {
783+
paramValues = append(paramValues, byte(0))
816784
}
817-
}
785+
continue
818786

819-
// append cached values
820-
data = append(data, paramValues...)
787+
case reflect.String:
788+
data = append(data, []byte{
789+
byte(FIELD_TYPE_STRING),
790+
0x0}...)
791+
val := pv.String()
792+
paramValues = append(paramValues, lengthCodedBinaryToBytes(uint64(len(val)))...)
793+
paramValues = append(paramValues, []byte(val)...)
794+
continue
795+
796+
default:
797+
return fmt.Errorf("Invalid Value: %s", pv.Kind().String())
798+
}
821799
}
822800

823-
// Save args
824-
stmt.args = args
801+
// append cached values
802+
data = append(data, paramValues...)
825803
}
826804
return stmt.mc.writePacket(data)
827805
}

statement.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ type stmtContent struct {
1818
query string
1919
paramCount int
2020
params []*mysqlField
21-
args *[]driver.Value
22-
newParamsBound bool
2321
}
2422

2523
type mysqlStmt struct {
@@ -55,11 +53,13 @@ func (stmt mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
5553
}
5654

5755
if resLen > 0 {
56+
// Columns
5857
_, e = stmt.mc.readUntilEOF()
5958
if e != nil {
6059
return nil, e
6160
}
6261

62+
// Rows
6363
stmt.mc.affectedRows, e = stmt.mc.readUntilEOF()
6464
if e != nil {
6565
return nil, e

0 commit comments

Comments
 (0)