diff --git a/driver_test.go b/driver_test.go index 224a24c53..7877aa979 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1669,8 +1669,9 @@ func TestStmtMultiRows(t *testing.T) { // Regression test for // * more than 32 NULL parameters (issue 209) // * more parameters than fit into the buffer (issue 201) +// * parameters * 64 > max_allowed_packet (issue 734) func TestPreparedManyCols(t *testing.T) { - const numParams = defaultBufSize + numParams := 65535 runTests(t, dsn, func(dbt *DBTest) { query := "SELECT ?" + strings.Repeat(",?", numParams-1) stmt, err := dbt.db.Prepare(query) @@ -1678,15 +1679,25 @@ func TestPreparedManyCols(t *testing.T) { dbt.Fatal(err) } defer stmt.Close() + // create more parameters than fit into the buffer // which will take nil-values params := make([]interface{}, numParams) rows, err := stmt.Query(params...) if err != nil { - stmt.Close() dbt.Fatal(err) } - defer rows.Close() + rows.Close() + + // Create 0byte string which we can't send via STMT_LONG_DATA. + for i := 0; i < numParams; i++ { + params[i] = "" + } + rows, err = stmt.Query(params...) + if err != nil { + dbt.Fatal(err) + } + rows.Close() }) } diff --git a/packets.go b/packets.go index f63d25072..afd131424 100644 --- a/packets.go +++ b/packets.go @@ -912,6 +912,12 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { const minPktLen = 4 + 1 + 4 + 1 + 4 mc := stmt.mc + // Determine threshould dynamically to avoid packet size shortage. + longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1) + if longDataSize < 64 { + longDataSize = 64 + } + // Reset packet-sequence mc.sequence = 0 @@ -1039,7 +1045,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 - if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { + if len(v) < longDataSize { paramValues = appendLengthEncodedInteger(paramValues, uint64(len(v)), ) @@ -1061,7 +1067,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 - if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { + if len(v) < longDataSize { paramValues = appendLengthEncodedInteger(paramValues, uint64(len(v)), )