Skip to content

Commit 401f0ed

Browse files
methaness49919201
authored andcommitted
Parse numbers on text protocol too (go-sql-driver#1452)
1 parent f20b286 commit 401f0ed

File tree

2 files changed

+90
-36
lines changed

2 files changed

+90
-36
lines changed

driver_test.go

Lines changed: 59 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -149,29 +149,18 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
149149
defer db2.Close()
150150
}
151151

152-
dsn3 := dsn + "&multiStatements=true"
153-
var db3 *sql.DB
154-
if _, err := ParseDSN(dsn3); err != errInvalidDSNUnsafeCollation {
155-
db3, err = sql.Open("mysql", dsn3)
156-
if err != nil {
157-
t.Fatalf("error connecting: %s", err.Error())
158-
}
159-
defer db3.Close()
160-
}
161-
162-
dbt := &DBTest{t, db}
163-
dbt2 := &DBTest{t, db2}
164-
dbt3 := &DBTest{t, db3}
165152
for _, test := range tests {
166-
test(dbt)
167-
dbt.db.Exec("DROP TABLE IF EXISTS test")
153+
t.Run("default", func(t *testing.T) {
154+
dbt := &DBTest{t, db}
155+
test(dbt)
156+
dbt.db.Exec("DROP TABLE IF EXISTS test")
157+
})
168158
if db2 != nil {
169-
test(dbt2)
170-
dbt2.db.Exec("DROP TABLE IF EXISTS test")
171-
}
172-
if db3 != nil {
173-
test(dbt3)
174-
dbt3.db.Exec("DROP TABLE IF EXISTS test")
159+
t.Run("interpolateParams", func(t *testing.T) {
160+
dbt2 := &DBTest{t, db2}
161+
test(dbt2)
162+
dbt2.db.Exec("DROP TABLE IF EXISTS test")
163+
})
175164
}
176165
}
177166
}
@@ -317,6 +306,48 @@ func TestCRUD(t *testing.T) {
317306
})
318307
}
319308

309+
// TestNumbers test that selecting numeric columns.
310+
// Both of textRows and binaryRows should return same type and value.
311+
func TestNumbersToAny(t *testing.T) {
312+
runTests(t, dsn, func(dbt *DBTest) {
313+
dbt.mustExec("CREATE TABLE `test` (id INT PRIMARY KEY, b BOOL, i8 TINYINT, " +
314+
"i16 SMALLINT, i32 INT, i64 BIGINT, f32 FLOAT, f64 DOUBLE)")
315+
dbt.mustExec("INSERT INTO `test` VALUES (1, true, 127, 32767, 2147483647, 9223372036854775807, 1.25, 2.5)")
316+
317+
// Use binaryRows for intarpolateParams=false and textRows for intarpolateParams=true.
318+
rows := dbt.mustQuery("SELECT b, i8, i16, i32, i64, f32, f64 FROM `test` WHERE id=?", 1)
319+
if !rows.Next() {
320+
dbt.Fatal("no data")
321+
}
322+
var b, i8, i16, i32, i64, f32, f64 any
323+
err := rows.Scan(&b, &i8, &i16, &i32, &i64, &f32, &f64)
324+
if err != nil {
325+
dbt.Fatal(err)
326+
}
327+
if b.(int64) != 1 {
328+
dbt.Errorf("b != 1")
329+
}
330+
if i8.(int64) != 127 {
331+
dbt.Errorf("i8 != 127")
332+
}
333+
if i16.(int64) != 32767 {
334+
dbt.Errorf("i16 != 32767")
335+
}
336+
if i32.(int64) != 2147483647 {
337+
dbt.Errorf("i32 != 2147483647")
338+
}
339+
if i64.(int64) != 9223372036854775807 {
340+
dbt.Errorf("i64 != 9223372036854775807")
341+
}
342+
if f32.(float32) != 1.25 {
343+
dbt.Errorf("f32 != 1.25")
344+
}
345+
if f64.(float64) != 2.5 {
346+
dbt.Errorf("f64 != 2.5")
347+
}
348+
})
349+
}
350+
320351
func TestMultiQuery(t *testing.T) {
321352
runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
322353
// Create Table
@@ -1807,13 +1838,13 @@ func TestConcurrent(t *testing.T) {
18071838
}
18081839

18091840
runTests(t, dsn, func(dbt *DBTest) {
1810-
var version string
1811-
if err := dbt.db.QueryRow("SELECT @@version").Scan(&version); err != nil {
1812-
dbt.Fatalf("%s", err.Error())
1813-
}
1814-
if strings.Contains(strings.ToLower(version), "mariadb") {
1815-
t.Skip(`TODO: "fix commands out of sync. Did you run multiple statements at once?" on MariaDB`)
1816-
}
1841+
// var version string
1842+
// if err := dbt.db.QueryRow("SELECT @@version").Scan(&version); err != nil {
1843+
// dbt.Fatal(err)
1844+
// }
1845+
// if strings.Contains(strings.ToLower(version), "mariadb") {
1846+
// t.Skip(`TODO: "fix commands out of sync. Did you run multiple statements at once?" on MariaDB`)
1847+
// }
18171848

18181849
var max int
18191850
err := dbt.db.QueryRow("SELECT @@max_connections").Scan(&max)

packets.go

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"fmt"
1919
"io"
2020
"math"
21+
"strconv"
2122
"time"
2223
)
2324

@@ -769,7 +770,8 @@ func (rows *textRows) readRow(dest []driver.Value) error {
769770

770771
for i := range dest {
771772
// Read bytes and convert to string
772-
dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
773+
var buf []byte
774+
buf, isNull, n, err = readLengthEncodedString(data[pos:])
773775
pos += n
774776

775777
if err != nil {
@@ -781,19 +783,40 @@ func (rows *textRows) readRow(dest []driver.Value) error {
781783
continue
782784
}
783785

784-
if !mc.parseTime {
785-
continue
786-
}
787-
788-
// Parse time field
789786
switch rows.rs.columns[i].fieldType {
790787
case fieldTypeTimestamp,
791788
fieldTypeDateTime,
792789
fieldTypeDate,
793790
fieldTypeNewDate:
794-
if dest[i], err = parseDateTime(dest[i].([]byte), mc.cfg.Loc); err != nil {
795-
return err
791+
if mc.parseTime {
792+
dest[i], err = parseDateTime(buf, mc.cfg.Loc)
793+
} else {
794+
dest[i] = buf
795+
}
796+
797+
case fieldTypeTiny, fieldTypeShort, fieldTypeInt24, fieldTypeYear, fieldTypeLong:
798+
dest[i], err = strconv.ParseInt(string(buf), 10, 32)
799+
800+
case fieldTypeLongLong:
801+
if rows.rs.columns[i].flags&flagUnsigned != 0 {
802+
dest[i], err = strconv.ParseUint(string(buf), 10, 64)
803+
} else {
804+
dest[i], err = strconv.ParseInt(string(buf), 10, 64)
796805
}
806+
807+
case fieldTypeFloat:
808+
var d float64
809+
d, err = strconv.ParseFloat(string(buf), 32)
810+
dest[i] = float32(d)
811+
812+
case fieldTypeDouble:
813+
dest[i], err = strconv.ParseFloat(string(buf), 64)
814+
815+
default:
816+
dest[i] = buf
817+
}
818+
if err != nil {
819+
return err
797820
}
798821
}
799822

0 commit comments

Comments
 (0)