Skip to content

Commit 736b6fa

Browse files
authored
Stop ColumnTypeScanType() from returning sql.RawBytes (#1424)
ColumnTypeScanType() returns []byte, string, or sql.NullString. It returned sql.RawBytes but it was dangoerous. Fixes #1423
1 parent 0b40aee commit 736b6fa

File tree

2 files changed

+62
-53
lines changed

2 files changed

+62
-53
lines changed

Diff for: driver_test.go

+34-33
Original file line numberDiff line numberDiff line change
@@ -2778,13 +2778,18 @@ func TestRowsColumnTypes(t *testing.T) {
27782778
nd1 := sql.NullTime{Time: time.Date(2006, 01, 02, 0, 0, 0, 0, time.UTC), Valid: true}
27792779
nd2 := sql.NullTime{Time: time.Date(2006, 03, 04, 0, 0, 0, 0, time.UTC), Valid: true}
27802780
ndNULL := sql.NullTime{Time: time.Time{}, Valid: false}
2781-
rbNULL := sql.RawBytes(nil)
2782-
rb0 := sql.RawBytes("0")
2783-
rb42 := sql.RawBytes("42")
2784-
rbTest := sql.RawBytes("Test")
2785-
rb0pad4 := sql.RawBytes("0\x00\x00\x00") // BINARY right-pads values with 0x00
2786-
rbx0 := sql.RawBytes("\x00")
2787-
rbx42 := sql.RawBytes("\x42")
2781+
bNULL := []byte(nil)
2782+
nsNULL := sql.NullString{String: "", Valid: false}
2783+
// Helper function to build NullString from string literal.
2784+
ns := func(s string) sql.NullString { return sql.NullString{String: s, Valid: true} }
2785+
ns0 := ns("0")
2786+
b0 := []byte("0")
2787+
b42 := []byte("42")
2788+
nsTest := ns("Test")
2789+
bTest := []byte("Test")
2790+
b0pad4 := []byte("0\x00\x00\x00") // BINARY right-pads values with 0x00
2791+
bx0 := []byte("\x00")
2792+
bx42 := []byte("\x42")
27882793

27892794
var columns = []struct {
27902795
name string
@@ -2797,7 +2802,7 @@ func TestRowsColumnTypes(t *testing.T) {
27972802
valuesIn [3]string
27982803
valuesOut [3]interface{}
27992804
}{
2800-
{"bit8null", "BIT(8)", "BIT", scanTypeRawBytes, true, 0, 0, [3]string{"0x0", "NULL", "0x42"}, [3]interface{}{rbx0, rbNULL, rbx42}},
2805+
{"bit8null", "BIT(8)", "BIT", scanTypeBytes, true, 0, 0, [3]string{"0x0", "NULL", "0x42"}, [3]interface{}{bx0, bNULL, bx42}},
28012806
{"boolnull", "BOOL", "TINYINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "true", "0"}, [3]interface{}{niNULL, ni1, ni0}},
28022807
{"bool", "BOOL NOT NULL", "TINYINT", scanTypeInt8, false, 0, 0, [3]string{"1", "0", "FALSE"}, [3]interface{}{int8(1), int8(0), int8(0)}},
28032808
{"intnull", "INTEGER", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}},
@@ -2817,24 +2822,24 @@ func TestRowsColumnTypes(t *testing.T) {
28172822
{"float74null", "FLOAT(7,4)", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, 4, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}},
28182823
{"double", "DOUBLE NOT NULL", "DOUBLE", scanTypeFloat64, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]interface{}{float64(0), float64(42), float64(13.37)}},
28192824
{"doublenull", "DOUBLE", "DOUBLE", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}},
2820-
{"decimal1", "DECIMAL(10,6) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 10, 6, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{sql.RawBytes("0.000000"), sql.RawBytes("13.370000"), sql.RawBytes("1234.123456")}},
2821-
{"decimal1null", "DECIMAL(10,6)", "DECIMAL", scanTypeRawBytes, true, 10, 6, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{sql.RawBytes("0.000000"), rbNULL, sql.RawBytes("1234.123456")}},
2822-
{"decimal2", "DECIMAL(8,4) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 8, 4, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{sql.RawBytes("0.0000"), sql.RawBytes("13.3700"), sql.RawBytes("1234.1235")}},
2823-
{"decimal2null", "DECIMAL(8,4)", "DECIMAL", scanTypeRawBytes, true, 8, 4, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{sql.RawBytes("0.0000"), rbNULL, sql.RawBytes("1234.1235")}},
2824-
{"decimal3", "DECIMAL(5,0) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 5, 0, [3]string{"0", "13.37", "-12345.123456"}, [3]interface{}{rb0, sql.RawBytes("13"), sql.RawBytes("-12345")}},
2825-
{"decimal3null", "DECIMAL(5,0)", "DECIMAL", scanTypeRawBytes, true, 5, 0, [3]string{"0", "NULL", "-12345.123456"}, [3]interface{}{rb0, rbNULL, sql.RawBytes("-12345")}},
2826-
{"char25null", "CHAR(25)", "CHAR", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}},
2827-
{"varchar42", "VARCHAR(42) NOT NULL", "VARCHAR", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
2828-
{"binary4null", "BINARY(4)", "BINARY", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0pad4, rbNULL, rbTest}},
2829-
{"varbinary42", "VARBINARY(42) NOT NULL", "VARBINARY", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
2830-
{"tinyblobnull", "TINYBLOB", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}},
2831-
{"tinytextnull", "TINYTEXT", "TEXT", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}},
2832-
{"blobnull", "BLOB", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}},
2833-
{"textnull", "TEXT", "TEXT", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}},
2834-
{"mediumblob", "MEDIUMBLOB NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
2835-
{"mediumtext", "MEDIUMTEXT NOT NULL", "TEXT", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
2836-
{"longblob", "LONGBLOB NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
2837-
{"longtext", "LONGTEXT NOT NULL", "TEXT", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
2825+
{"decimal1", "DECIMAL(10,6) NOT NULL", "DECIMAL", scanTypeString, false, 10, 6, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{"0.000000", "13.370000", "1234.123456"}},
2826+
{"decimal1null", "DECIMAL(10,6)", "DECIMAL", scanTypeNullString, true, 10, 6, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{ns("0.000000"), nsNULL, ns("1234.123456")}},
2827+
{"decimal2", "DECIMAL(8,4) NOT NULL", "DECIMAL", scanTypeString, false, 8, 4, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{"0.0000", "13.3700", "1234.1235"}},
2828+
{"decimal2null", "DECIMAL(8,4)", "DECIMAL", scanTypeNullString, true, 8, 4, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{ns("0.0000"), nsNULL, ns("1234.1235")}},
2829+
{"decimal3", "DECIMAL(5,0) NOT NULL", "DECIMAL", scanTypeString, false, 5, 0, [3]string{"0", "13.37", "-12345.123456"}, [3]interface{}{"0", "13", "-12345"}},
2830+
{"decimal3null", "DECIMAL(5,0)", "DECIMAL", scanTypeNullString, true, 5, 0, [3]string{"0", "NULL", "-12345.123456"}, [3]interface{}{ns0, nsNULL, ns("-12345")}},
2831+
{"char25null", "CHAR(25)", "CHAR", scanTypeNullString, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{ns0, nsNULL, nsTest}},
2832+
{"varchar42", "VARCHAR(42) NOT NULL", "VARCHAR", scanTypeString, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{"0", "Test", "42"}},
2833+
{"binary4null", "BINARY(4)", "BINARY", scanTypeBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{b0pad4, bNULL, bTest}},
2834+
{"varbinary42", "VARBINARY(42) NOT NULL", "VARBINARY", scanTypeBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{b0, bTest, b42}},
2835+
{"tinyblobnull", "TINYBLOB", "BLOB", scanTypeBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{b0, bNULL, bTest}},
2836+
{"tinytextnull", "TINYTEXT", "TEXT", scanTypeNullString, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{ns0, nsNULL, nsTest}},
2837+
{"blobnull", "BLOB", "BLOB", scanTypeBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{b0, bNULL, bTest}},
2838+
{"textnull", "TEXT", "TEXT", scanTypeNullString, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{ns0, nsNULL, nsTest}},
2839+
{"mediumblob", "MEDIUMBLOB NOT NULL", "BLOB", scanTypeBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{b0, bTest, b42}},
2840+
{"mediumtext", "MEDIUMTEXT NOT NULL", "TEXT", scanTypeString, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{"0", "Test", "42"}},
2841+
{"longblob", "LONGBLOB NOT NULL", "BLOB", scanTypeBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{b0, bTest, b42}},
2842+
{"longtext", "LONGTEXT NOT NULL", "TEXT", scanTypeString, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{"0", "Test", "42"}},
28382843
{"datetime", "DATETIME", "DATETIME", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt0, nt0}},
28392844
{"datetime2", "DATETIME(2)", "DATETIME", scanTypeNullTime, true, 2, 2, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt2}},
28402845
{"datetime6", "DATETIME(6)", "DATETIME", scanTypeNullTime, true, 6, 6, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt6}},
@@ -2959,14 +2964,10 @@ func TestRowsColumnTypes(t *testing.T) {
29592964
if err != nil {
29602965
t.Fatalf("failed to scan values in %v", err)
29612966
}
2962-
for j := range values {
2963-
value := reflect.ValueOf(values[j]).Elem().Interface()
2967+
for j, value := range values {
2968+
value := reflect.ValueOf(value).Elem().Interface()
29642969
if !reflect.DeepEqual(value, columns[j].valuesOut[i]) {
2965-
if columns[j].scanType == scanTypeRawBytes {
2966-
t.Errorf("row %d, column %d: %v != %v", i, j, string(value.(sql.RawBytes)), string(columns[j].valuesOut[i].(sql.RawBytes)))
2967-
} else {
2968-
t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i])
2969-
}
2970+
t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i])
29702971
}
29712972
}
29722973
i++

Diff for: fields.go

+28-20
Original file line numberDiff line numberDiff line change
@@ -110,21 +110,23 @@ func (mf *mysqlField) typeDatabaseName() string {
110110
}
111111

112112
var (
113-
scanTypeFloat32 = reflect.TypeOf(float32(0))
114-
scanTypeFloat64 = reflect.TypeOf(float64(0))
115-
scanTypeInt8 = reflect.TypeOf(int8(0))
116-
scanTypeInt16 = reflect.TypeOf(int16(0))
117-
scanTypeInt32 = reflect.TypeOf(int32(0))
118-
scanTypeInt64 = reflect.TypeOf(int64(0))
119-
scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{})
120-
scanTypeNullInt = reflect.TypeOf(sql.NullInt64{})
121-
scanTypeNullTime = reflect.TypeOf(sql.NullTime{})
122-
scanTypeUint8 = reflect.TypeOf(uint8(0))
123-
scanTypeUint16 = reflect.TypeOf(uint16(0))
124-
scanTypeUint32 = reflect.TypeOf(uint32(0))
125-
scanTypeUint64 = reflect.TypeOf(uint64(0))
126-
scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{})
127-
scanTypeUnknown = reflect.TypeOf(new(interface{}))
113+
scanTypeFloat32 = reflect.TypeOf(float32(0))
114+
scanTypeFloat64 = reflect.TypeOf(float64(0))
115+
scanTypeInt8 = reflect.TypeOf(int8(0))
116+
scanTypeInt16 = reflect.TypeOf(int16(0))
117+
scanTypeInt32 = reflect.TypeOf(int32(0))
118+
scanTypeInt64 = reflect.TypeOf(int64(0))
119+
scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{})
120+
scanTypeNullInt = reflect.TypeOf(sql.NullInt64{})
121+
scanTypeNullTime = reflect.TypeOf(sql.NullTime{})
122+
scanTypeUint8 = reflect.TypeOf(uint8(0))
123+
scanTypeUint16 = reflect.TypeOf(uint16(0))
124+
scanTypeUint32 = reflect.TypeOf(uint32(0))
125+
scanTypeUint64 = reflect.TypeOf(uint64(0))
126+
scanTypeString = reflect.TypeOf("")
127+
scanTypeNullString = reflect.TypeOf(sql.NullString{})
128+
scanTypeBytes = reflect.TypeOf([]byte{})
129+
scanTypeUnknown = reflect.TypeOf(new(interface{}))
128130
)
129131

130132
type mysqlField struct {
@@ -187,12 +189,18 @@ func (mf *mysqlField) scanType() reflect.Type {
187189
}
188190
return scanTypeNullFloat
189191

192+
case fieldTypeBit, fieldTypeTinyBLOB, fieldTypeMediumBLOB, fieldTypeLongBLOB,
193+
fieldTypeBLOB, fieldTypeVarString, fieldTypeString, fieldTypeGeometry:
194+
if mf.charSet == 63 /* binary */ {
195+
return scanTypeBytes
196+
}
197+
fallthrough
190198
case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
191-
fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
192-
fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
193-
fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON,
194-
fieldTypeTime:
195-
return scanTypeRawBytes
199+
fieldTypeEnum, fieldTypeSet, fieldTypeJSON, fieldTypeTime:
200+
if mf.flags&flagNotNULL != 0 {
201+
return scanTypeString
202+
}
203+
return scanTypeNullString
196204

197205
case fieldTypeDate, fieldTypeNewDate,
198206
fieldTypeTimestamp, fieldTypeDateTime:

0 commit comments

Comments
 (0)