Skip to content

Commit 593ebcf

Browse files
committed
Merge pull request #210 from arnehormann/fix-many-cols
support prepared statements with more than 32 parameters
2 parents e56cf9d + 75c7231 commit 593ebcf

File tree

3 files changed

+58
-17
lines changed

3 files changed

+58
-17
lines changed

CHANGELOG.md

+5
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@ New Features:
44

55
- Logging of critical errors is configurable with `SetLogger`
66

7+
Bugfixes:
8+
9+
- Allow more than 32 parameters in prepared statements
10+
11+
712
## Version 1.1 (2013-11-02)
813

914
Changes:

driver_test.go

+24
Original file line numberDiff line numberDiff line change
@@ -1210,6 +1210,30 @@ func TestStmtMultiRows(t *testing.T) {
12101210
})
12111211
}
12121212

1213+
// Regression test for
1214+
// * more than 32 NULL parameters (issue 209)
1215+
// * more parameters than fit into the buffer (issue 201)
1216+
func TestPreparedManyCols(t *testing.T) {
1217+
const numParams = defaultBufSize
1218+
runTests(t, dsn, func(dbt *DBTest) {
1219+
query := "SELECT ?" + strings.Repeat(",?", numParams-1)
1220+
stmt, err := dbt.db.Prepare(query)
1221+
if err != nil {
1222+
dbt.Fatal(err)
1223+
}
1224+
defer stmt.Close()
1225+
// create more parameters than fit into the buffer
1226+
// which will take nil-values
1227+
params := make([]interface{}, numParams)
1228+
rows, err := stmt.Query(params...)
1229+
if err != nil {
1230+
stmt.Close()
1231+
dbt.Fatal(err)
1232+
}
1233+
defer rows.Close()
1234+
})
1235+
}
1236+
12131237
func TestConcurrent(t *testing.T) {
12141238
if enabled, _ := readBool(os.Getenv("MYSQL_TEST_CONCURRENT")); !enabled {
12151239
t.Skip("MYSQL_TEST_CONCURRENT env var not set")

packets.go

+29-17
Original file line numberDiff line numberDiff line change
@@ -750,6 +750,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
750750
)
751751
}
752752

753+
const minPktLen = 4 + 1 + 4 + 1 + 4
753754
mc := stmt.mc
754755

755756
// Reset packet-sequence
@@ -758,7 +759,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
758759
var data []byte
759760

760761
if len(args) == 0 {
761-
data = mc.buf.takeBuffer(4 + 1 + 4 + 1 + 4)
762+
data = mc.buf.takeBuffer(minPktLen)
762763
} else {
763764
data = mc.buf.takeCompleteBuffer()
764765
}
@@ -787,34 +788,50 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
787788
data[13] = 0x00
788789

789790
if len(args) > 0 {
790-
// NULL-bitmap [(len(args)+7)/8 bytes]
791-
nullMask := uint64(0)
792-
793-
pos := 4 + 1 + 4 + 1 + 4 + ((len(args) + 7) >> 3)
791+
pos := minPktLen
792+
793+
var nullMask []byte
794+
if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) {
795+
// buffer has to be extended but we don't know by how much so
796+
// we depend on append after all data with known sizes fit.
797+
// We stop at that because we deal with a lot of columns here
798+
// which makes the required allocation size hard to guess.
799+
tmp := make([]byte, pos+maskLen+typesLen)
800+
copy(tmp[:pos], data[:pos])
801+
data = tmp
802+
nullMask = data[pos : pos+maskLen]
803+
pos += maskLen
804+
} else {
805+
nullMask = data[pos : pos+maskLen]
806+
for i := 0; i < maskLen; i++ {
807+
nullMask[i] = 0
808+
}
809+
pos += maskLen
810+
}
794811

795812
// newParameterBoundFlag 1 [1 byte]
796813
data[pos] = 0x01
797814
pos++
798815

799816
// type of each parameter [len(args)*2 bytes]
800817
paramTypes := data[pos:]
801-
pos += (len(args) << 1)
818+
pos += len(args) * 2
802819

803820
// value of each parameter [n bytes]
804821
paramValues := data[pos:pos]
805822
valuesCap := cap(paramValues)
806823

807-
for i := range args {
824+
for i, arg := range args {
808825
// build NULL-bitmap
809-
if args[i] == nil {
810-
nullMask |= 1 << uint(i)
826+
if arg == nil {
827+
nullMask[i/8] |= 1 << (uint(i) & 7)
811828
paramTypes[i+i] = fieldTypeNULL
812829
paramTypes[i+i+1] = 0x00
813830
continue
814831
}
815832

816833
// cache types and values
817-
switch v := args[i].(type) {
834+
switch v := arg.(type) {
818835
case int64:
819836
paramTypes[i+i] = fieldTypeLongLong
820837
paramTypes[i+i+1] = 0x00
@@ -877,7 +894,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
877894
}
878895

879896
// Handle []byte(nil) as a NULL value
880-
nullMask |= 1 << uint(i)
897+
nullMask[i/8] |= 1 << (uint(i) & 7)
881898
paramTypes[i+i] = fieldTypeNULL
882899
paramTypes[i+i+1] = 0x00
883900

@@ -913,7 +930,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
913930
paramValues = append(paramValues, val...)
914931

915932
default:
916-
return fmt.Errorf("Can't convert type: %T", args[i])
933+
return fmt.Errorf("Can't convert type: %T", arg)
917934
}
918935
}
919936

@@ -926,11 +943,6 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
926943

927944
pos += len(paramValues)
928945
data = data[:pos]
929-
930-
// Convert nullMask to bytes
931-
for i, max := 0, (stmt.paramCount+7)>>3; i < max; i++ {
932-
data[i+14] = byte(nullMask >> uint(i<<3))
933-
}
934946
}
935947

936948
return mc.writePacket(data)

0 commit comments

Comments
 (0)