Skip to content

Commit e4b1048

Browse files
committed
Added type-conversion for time.Time, []byte and float64 in stmt-params
+ even more clean up
1 parent 51f0124 commit e4b1048

File tree

6 files changed

+81
-55
lines changed

6 files changed

+81
-55
lines changed

connection.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,19 +157,18 @@ func (mc *mysqlConn) Prepare(query string) (ds driver.Stmt, e error) {
157157
stmt.mc = mc
158158

159159
// Read Result
160-
var columnCount, paramCount uint16
161-
stmt.id, columnCount, paramCount, e = mc.readPrepareResultPacket()
160+
var columnCount uint16
161+
columnCount, e = stmt.readPrepareResultPacket()
162162
if e != nil {
163163
return
164164
}
165165

166-
if paramCount > 0 {
167-
stmt.params, e = stmt.mc.readColumns(int(paramCount))
166+
if stmt.paramCount > 0 {
167+
stmt.params, e = stmt.mc.readColumns(stmt.paramCount)
168168
if e != nil {
169169
return
170170
}
171171
}
172-
stmt.paramCount = int(paramCount)
173172

174173
if columnCount > 0 {
175174
_, e = stmt.mc.readColumns(int(columnCount))

const.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ package mysql
1414
const (
1515
MIN_PROTOCOL_VERSION = 10
1616
MAX_PACKET_SIZE = 1<<24 - 1
17-
TIME_FORMAT = "2006-01-02 15:04:05.000000000"
17+
TIME_FORMAT = "2006-01-02 15:04:05"
1818
)
1919

2020
type ClientFlag uint32

packets.go

Lines changed: 50 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ import (
1717
"time"
1818
)
1919

20+
// Packets documentation:
21+
// http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol
22+
2023
// Read packet to buffer 'data'
2124
func (mc *mysqlConn) readPacket() (data []byte, e error) {
2225
// Packet Length
@@ -366,7 +369,7 @@ n (until end of packet) message
366369
*/
367370
func (mc *mysqlConn) handleOkPacket(data []byte) (e error) {
368371
if data[0] != 0 {
369-
e = errors.New("Wrong Packet-Type: Not a OK-Packet")
372+
e = errors.New("Wrong Packet-Type: Not an OK-Packet")
370373
return
371374
}
372375

@@ -451,32 +454,37 @@ func (mc *mysqlConn) readColumns(n int) (columns []*mysqlField, e error) {
451454
}
452455

453456
var pos, n int
454-
var catalog, database, table, orgTable, name, orgName []byte
455-
var defaultVal uint64
457+
var name []byte
458+
//var catalog, database, table, orgTable, name, orgName []byte
459+
//var defaultVal uint64
456460

457461
// Catalog
458-
catalog, n, _, e = readLengthCodedBinary(data)
462+
//catalog, n, _, e = readLengthCodedBinary(data)
463+
n, e = readAndDropLengthCodedBinary(data)
459464
if e != nil {
460465
return
461466
}
462467
pos += n
463468

464469
// Database [len coded string]
465-
database, n, _, e = readLengthCodedBinary(data[pos:])
470+
//database, n, _, e = readLengthCodedBinary(data[pos:])
471+
n, e = readAndDropLengthCodedBinary(data[pos:])
466472
if e != nil {
467473
return
468474
}
469475
pos += n
470476

471477
// Table [len coded string]
472-
table, n, _, e = readLengthCodedBinary(data[pos:])
478+
//table, n, _, e = readLengthCodedBinary(data[pos:])
479+
n, e = readAndDropLengthCodedBinary(data[pos:])
473480
if e != nil {
474481
return
475482
}
476483
pos += n
477484

478485
// Original table [len coded string]
479-
orgTable, n, _, e = readLengthCodedBinary(data[pos:])
486+
//orgTable, n, _, e = readLengthCodedBinary(data[pos:])
487+
n, e = readAndDropLengthCodedBinary(data[pos:])
480488
if e != nil {
481489
return
482490
}
@@ -490,7 +498,8 @@ func (mc *mysqlConn) readColumns(n int) (columns []*mysqlField, e error) {
490498
pos += n
491499

492500
// Original name [len coded string]
493-
orgName, n, _, e = readLengthCodedBinary(data[pos:])
501+
//orgName, n, _, e = readLengthCodedBinary(data[pos:])
502+
n, e = readAndDropLengthCodedBinary(data[pos:])
494503
if e != nil {
495504
return
496505
}
@@ -500,11 +509,11 @@ func (mc *mysqlConn) readColumns(n int) (columns []*mysqlField, e error) {
500509
pos++
501510

502511
// Charset [16 bit uint]
503-
charsetNumber := bytesToUint16(data[pos : pos+2])
512+
//charsetNumber := bytesToUint16(data[pos : pos+2])
504513
pos += 2
505514

506515
// Length [32 bit uint]
507-
length := bytesToUint32(data[pos : pos+4])
516+
//length := bytesToUint32(data[pos : pos+4])
508517
pos += 4
509518

510519
// Field type [byte]
@@ -513,18 +522,16 @@ func (mc *mysqlConn) readColumns(n int) (columns []*mysqlField, e error) {
513522

514523
// Flags [16 bit uint]
515524
flags := FieldFlag(bytesToUint16(data[pos : pos+2]))
516-
pos += 2
525+
//pos += 2
517526

518527
// Decimals [8 bit uint]
519-
decimals := data[pos]
520-
pos++
528+
//decimals := data[pos]
529+
//pos++
521530

522531
// Default value [len coded binary]
523-
if pos < len(data) {
524-
defaultVal, _, e = bytesToLengthCodedBinary(data[pos:])
525-
}
526-
527-
fmt.Printf("catalog=%s database=%s table=%s orgTable=%s name=%s orgName=%s charsetNumber=%d length=%d fieldType=%d flags=%d decimals=%d defaultVal=%d \n", catalog, database, table, orgTable, name, orgName, charsetNumber, length, fieldType, flags, decimals, defaultVal)
532+
//if pos < len(data) {
533+
// defaultVal, _, e = bytesToLengthCodedBinary(data[pos:])
534+
//}
528535

529536
columns = append(columns, &mysqlField{name: string(name), fieldType: fieldType, flags: flags})
530537
}
@@ -628,8 +635,8 @@ Prepare OK Packet
628635
(EOF packet)
629636
630637
*/
631-
func (mc *mysqlConn) readPrepareResultPacket() (stmtID uint32, columnCount uint16, paramCount uint16, e error) {
632-
data, e := mc.readPacket()
638+
func (stmt mysqlStmt) readPrepareResultPacket() (columnCount uint16, e error) {
639+
data, e := stmt.mc.readPacket()
633640
if e != nil {
634641
return
635642
}
@@ -638,20 +645,20 @@ func (mc *mysqlConn) readPrepareResultPacket() (stmtID uint32, columnCount uint1
638645
pos := 0
639646

640647
if data[pos] != 0 {
641-
e = mc.handleErrorPacket(data)
648+
e = stmt.mc.handleErrorPacket(data)
642649
return
643650
}
644651
pos++
645652

646-
stmtID = bytesToUint32(data[pos : pos+4])
653+
stmt.id = bytesToUint32(data[pos : pos+4])
647654
pos += 4
648655

649656
// Column count [16 bit uint]
650657
columnCount = bytesToUint16(data[pos : pos+2])
651658
pos += 2
652659

653660
// Param count [16 bit uint]
654-
paramCount = bytesToUint16(data[pos : pos+2])
661+
stmt.paramCount = int(bytesToUint16(data[pos : pos+2]))
655662
pos += 2
656663

657664
// Warning count [16 bit uint]
@@ -751,10 +758,20 @@ func (stmt mysqlStmt) buildExecutePacket(args *[]driver.Value) (e error) {
751758
byte(FIELD_TYPE_NULL),
752759
0x0}...)
753760
continue
761+
754762
case []byte:
755-
fmt.Println("[]byte", (*args)[i])
763+
data = append(data, []byte{
764+
byte(FIELD_TYPE_STRING),
765+
0x0}...)
766+
val := (*args)[i].([]byte)
767+
paramValues = append(paramValues, lengthCodedBinaryToBytes(uint64(len(val)))...)
768+
paramValues = append(paramValues, val...)
769+
continue
770+
756771
case time.Time:
757-
fmt.Println("time.Time", (*args)[i])
772+
// Format to string for time+date Fields
773+
// Data is packed in case reflect.String below
774+
(*args)[i] = (*args)[i].(time.Time).Format(TIME_FORMAT)
758775
}
759776

760777
pv = reflect.ValueOf((*args)[i])
@@ -764,10 +781,14 @@ func (stmt mysqlStmt) buildExecutePacket(args *[]driver.Value) (e error) {
764781
byte(FIELD_TYPE_LONGLONG),
765782
0x0}...)
766783
paramValues = append(paramValues, int64ToBytes(pv.Int())...)
767-
fmt.Println("int64", (*args)[i])
784+
continue
768785

769786
case reflect.Float64:
770-
fmt.Println("float64", (*args)[i])
787+
data = append(data, []byte{
788+
byte(FIELD_TYPE_DOUBLE),
789+
0x0}...)
790+
paramValues = append(paramValues, float64ToBytes(pv.Float())...)
791+
continue
771792

772793
case reflect.Bool:
773794
data = append(data, []byte{
@@ -779,7 +800,7 @@ func (stmt mysqlStmt) buildExecutePacket(args *[]driver.Value) (e error) {
779800
} else {
780801
paramValues = append(paramValues, byte(0))
781802
}
782-
fmt.Println("bool", (*args)[i])
803+
continue
783804

784805
case reflect.String:
785806
data = append(data, []byte{
@@ -788,7 +809,7 @@ func (stmt mysqlStmt) buildExecutePacket(args *[]driver.Value) (e error) {
788809
val := pv.String()
789810
paramValues = append(paramValues, lengthCodedBinaryToBytes(uint64(len(val)))...)
790811
paramValues = append(paramValues, []byte(val)...)
791-
fmt.Println("string", string([]byte(val)))
812+
continue
792813

793814
default:
794815
return fmt.Errorf("Invalid Value: %s", pv.Kind().String())
@@ -797,7 +818,6 @@ func (stmt mysqlStmt) buildExecutePacket(args *[]driver.Value) (e error) {
797818

798819
// append cached values
799820
data = append(data, paramValues...)
800-
fmt.Println("data", string(data))
801821
}
802822

803823
// Save args
@@ -855,7 +875,6 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
855875
row[i] = intToByteStr(int64(int8(byteToUint8(data[pos]))))
856876
}
857877
pos++
858-
fmt.Println("TINY", string(*row[i]))
859878

860879
case FIELD_TYPE_SHORT, FIELD_TYPE_YEAR:
861880
if unsigned {
@@ -864,7 +883,6 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
864883
row[i] = intToByteStr(int64(int16(bytesToUint16(data[pos : pos+2]))))
865884
}
866885
pos += 2
867-
fmt.Println("SHORT", string(*row[i]))
868886

869887
case FIELD_TYPE_INT24, FIELD_TYPE_LONG:
870888
if unsigned {
@@ -873,7 +891,6 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
873891
row[i] = intToByteStr(int64(int32(bytesToUint32(data[pos : pos+4]))))
874892
}
875893
pos += 4
876-
fmt.Println("LONG", string(*row[i]))
877894

878895
case FIELD_TYPE_LONGLONG:
879896
if unsigned {
@@ -882,17 +899,14 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
882899
row[i] = intToByteStr(int64(bytesToUint64(data[pos : pos+8])))
883900
}
884901
pos += 8
885-
fmt.Println("LONGLONG", string(*row[i]))
886902

887903
case FIELD_TYPE_FLOAT:
888904
row[i] = float32ToByteStr(bytesToFloat32(data[pos : pos+4]))
889905
pos += 4
890-
fmt.Println("FLOAT", string(*row[i]))
891906

892907
case FIELD_TYPE_DOUBLE:
893908
row[i] = float64ToByteStr(bytesToFloat64(data[pos : pos+8]))
894909
pos += 8
895-
fmt.Println("DOUBLE", string(*row[i]))
896910

897911
case FIELD_TYPE_DECIMAL, FIELD_TYPE_NEWDECIMAL:
898912
var tmp []byte
@@ -903,10 +917,8 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
903917

904918
if isNull && rc.columns[i].flags&FLAG_NOT_NULL == 0 {
905919
row[i] = nil
906-
fmt.Println("DECIMAL", nil)
907920
} else {
908921
row[i] = &tmp
909-
fmt.Println("DECIMAL", string(tmp))
910922
}
911923
pos += n
912924

@@ -923,10 +935,8 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
923935

924936
if isNull && rc.columns[i].flags&FLAG_NOT_NULL == 0 {
925937
row[i] = nil
926-
fmt.Println("STRING", nil)
927938
} else {
928939
row[i] = &tmp
929-
fmt.Println("STRING", string(tmp))
930940
}
931941
pos += n
932942

@@ -950,7 +960,6 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
950960
}
951961
row[i] = &tmp
952962
pos += int(num)
953-
fmt.Println("DATE", string(*row[i]))
954963

955964
// Time HH:MM:SS
956965
case FIELD_TYPE_TIME:
@@ -971,7 +980,6 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
971980
}
972981
row[i] = &tmp
973982
pos += n + int(num)
974-
fmt.Println("TIME", string(*row[i]))
975983

976984
// Timestamp YYYY-MM-DD HH:MM:SS
977985
case FIELD_TYPE_TIMESTAMP, FIELD_TYPE_DATETIME:
@@ -996,7 +1004,6 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
9961004
}
9971005
row[i] = &tmp
9981006
pos += int(num)
999-
fmt.Println("DATE", string(*row[i]))
10001007

10011008
// Please report if this happens!
10021009
default:

rows.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func (rows mysqlRows) Close() error {
4242
}
4343

4444
// Next returns []driver.Value filled with either nil values for NULL entries
45-
// or []byte for every other entries. Type conversion is done on rows.scan(),
45+
// or []byte's for all other entries. Type conversion is done on rows.scan(),
4646
// when the dest. type is know, which makes type conversion easier and avoids
4747
// unnecessary conversions.
4848
func (rows mysqlRows) Next(dest []driver.Value) error {

statement.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ package mysql
1010

1111
import (
1212
"database/sql/driver"
13-
"fmt"
1413
)
1514

1615
type stmtContent struct {
@@ -118,7 +117,7 @@ func (stmt mysqlStmt) Query(args []driver.Value) (dr driver.Rows, e error) {
118117
// column index. If the type of a specific column isn't known
119118
// or shouldn't be handled specially, DefaultValueConverter
120119
// can be returned.
121-
func (stmt mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
122-
debug(fmt.Sprintf("ColumnConverter(%d)", idx))
123-
return driver.DefaultParameterConverter
124-
}
120+
//func (stmt mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
121+
// debug(fmt.Sprintf("ColumnConverter(%d)", idx))
122+
// return driver.DefaultParameterConverter
123+
//}

utils.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,23 @@ func readLengthCodedBinary(data []byte) (b []byte, n int, isNull bool, e error)
145145
return
146146
}
147147

148+
func readAndDropLengthCodedBinary(data []byte) (n int, e error) {
149+
// Get length
150+
num, n, e := bytesToLengthCodedBinary(data)
151+
if e != nil {
152+
return
153+
}
154+
155+
// Check data length
156+
if len(data) < n+int(num) {
157+
e = io.EOF
158+
return
159+
}
160+
161+
n += int(num)
162+
return
163+
}
164+
148165
/******************************************************************************
149166
* Convert from and to bytes *
150167
******************************************************************************/
@@ -210,6 +227,10 @@ func bytesToFloat64(b []byte) float64 {
210227
return math.Float64frombits(bytesToUint64(b))
211228
}
212229

230+
func float64ToBytes(f float64) []byte {
231+
return uint64ToBytes(math.Float64bits(f))
232+
}
233+
213234
func bytesToLengthCodedBinary(b []byte) (length uint64, n int, e error) {
214235
switch {
215236

0 commit comments

Comments
 (0)