diff --git a/driver_test.go b/driver_test.go index a52cc5cd0..f9cf1ebb6 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1245,6 +1245,47 @@ func TestTimezoneConversion(t *testing.T) { } } +func TestSelectFloatToFloat64(t *testing.T) { + createTest := func(query string, args ...interface{}) func(dbt *DBTest) { + return func(dbt *DBTest) { + v := float64(0.050551) + + // Create table + dbt.mustExec("CREATE TABLE test (f FLOAT)") + dbt.mustExec("INSERT INTO test VALUE (?)", v) + + // Retrieve + rows := dbt.mustQuery(query, args...) + if !rows.Next() { + dbt.Fatal("Didn't get any rows out") + } + + var f float64 + err := rows.Scan(&f) + if err != nil { + dbt.Fatal("Err", err) + } + + // Check that dates match + if f != v { + dbt.Errorf("Float values don't match.\n") + dbt.Errorf(" Inserted: %v\n", v) + dbt.Errorf(" Selected: %v\n", f) + } + } + } + + dsns := []string{ + dsn + "&parseTime=true", + dsn + "&parseTime=false", + } + for _, testdsn := range dsns { + runTests(t, testdsn, createTest("SELECT f FROM test")) // not prepared statement + runTests(t, testdsn, createTest("SELECT f FROM test WHERE 1=?", 1)) // prepared statement + runTests(t, testdsn, createTest("SELECT IFNULL(f, 0) f FROM test")) // not prepared statement with IFNULL + } +} + // Special cases func TestRowsClose(t *testing.T) { diff --git a/packets.go b/packets.go index 255301e81..b2cebef3c 100644 --- a/packets.go +++ b/packets.go @@ -16,6 +16,7 @@ import ( "fmt" "io" "math" + "strconv" "time" ) @@ -616,22 +617,31 @@ func (rows *textRows) readRow(dest []driver.Value) error { pos += n if err == nil { if !isNull { - if !mc.parseTime { + // @todo hacky fix, please make it better + if i > len(rows.columns)-1 { continue - } else { - switch rows.columns[i].fieldType { - case fieldTypeTimestamp, fieldTypeDateTime, - fieldTypeDate, fieldTypeNewDate: - dest[i], err = parseDateTime( - string(dest[i].([]byte)), - mc.cfg.loc, - ) - if err == nil { - continue - } - default: + } + switch rows.columns[i].fieldType { + case fieldTypeTimestamp, fieldTypeDateTime, + fieldTypeDate, fieldTypeNewDate: + if !mc.parseTime { + continue + } + dest[i], err = parseDateTime( + string(dest[i].([]byte)), + mc.cfg.loc, + ) + if err == nil { + continue + } + case fieldTypeFloat: + val, err := strconv.ParseFloat(string(dest[i].([]byte)), 32) + dest[i] = float32(val) + if err == nil { continue } + default: + continue } } else { @@ -1037,7 +1047,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeFloat: - dest[i] = float64(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))) + dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])) pos += 4 continue