@@ -750,6 +750,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
750
750
)
751
751
}
752
752
753
+ const minPktLen = 4 + 1 + 4 + 1 + 4
753
754
mc := stmt .mc
754
755
755
756
// Reset packet-sequence
@@ -758,7 +759,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
758
759
var data []byte
759
760
760
761
if len (args ) == 0 {
761
- data = mc .buf .takeBuffer (4 + 1 + 4 + 1 + 4 )
762
+ data = mc .buf .takeBuffer (minPktLen )
762
763
} else {
763
764
data = mc .buf .takeCompleteBuffer ()
764
765
}
@@ -787,34 +788,50 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
787
788
data [13 ] = 0x00
788
789
789
790
if len (args ) > 0 {
790
- // NULL-bitmap [(len(args)+7)/8 bytes]
791
- nullMask := uint64 (0 )
792
-
793
- pos := 4 + 1 + 4 + 1 + 4 + ((len (args ) + 7 ) >> 3 )
791
+ pos := minPktLen
792
+
793
+ var nullMask []byte
794
+ if maskLen , typesLen := (len (args )+ 7 )/ 8 , 1 + 2 * len (args ); pos + maskLen + typesLen >= len (data ) {
795
+ // buffer has to be extended but we don't know by how much so
796
+ // we depend on append after all data with known sizes fit.
797
+ // We stop at that because we deal with a lot of columns here
798
+ // which makes the required allocation size hard to guess.
799
+ tmp := make ([]byte , pos + maskLen + typesLen )
800
+ copy (tmp [:pos ], data [:pos ])
801
+ data = tmp
802
+ nullMask = data [pos : pos + maskLen ]
803
+ pos += maskLen
804
+ } else {
805
+ nullMask = data [pos : pos + maskLen ]
806
+ for i := 0 ; i < maskLen ; i ++ {
807
+ nullMask [i ] = 0
808
+ }
809
+ pos += maskLen
810
+ }
794
811
795
812
// newParameterBoundFlag 1 [1 byte]
796
813
data [pos ] = 0x01
797
814
pos ++
798
815
799
816
// type of each parameter [len(args)*2 bytes]
800
817
paramTypes := data [pos :]
801
- pos += ( len (args ) << 1 )
818
+ pos += len (args ) * 2
802
819
803
820
// value of each parameter [n bytes]
804
821
paramValues := data [pos :pos ]
805
822
valuesCap := cap (paramValues )
806
823
807
- for i := range args {
824
+ for i , arg := range args {
808
825
// build NULL-bitmap
809
- if args [ i ] == nil {
810
- nullMask |= 1 << uint (i )
826
+ if arg == nil {
827
+ nullMask [ i / 8 ] |= 1 << ( uint (i ) & 7 )
811
828
paramTypes [i + i ] = fieldTypeNULL
812
829
paramTypes [i + i + 1 ] = 0x00
813
830
continue
814
831
}
815
832
816
833
// cache types and values
817
- switch v := args [ i ] .(type ) {
834
+ switch v := arg .(type ) {
818
835
case int64 :
819
836
paramTypes [i + i ] = fieldTypeLongLong
820
837
paramTypes [i + i + 1 ] = 0x00
@@ -877,7 +894,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
877
894
}
878
895
879
896
// Handle []byte(nil) as a NULL value
880
- nullMask |= 1 << uint (i )
897
+ nullMask [ i / 8 ] |= 1 << ( uint (i ) & 7 )
881
898
paramTypes [i + i ] = fieldTypeNULL
882
899
paramTypes [i + i + 1 ] = 0x00
883
900
@@ -913,7 +930,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
913
930
paramValues = append (paramValues , val ... )
914
931
915
932
default :
916
- return fmt .Errorf ("Can't convert type: %T" , args [ i ] )
933
+ return fmt .Errorf ("Can't convert type: %T" , arg )
917
934
}
918
935
}
919
936
@@ -926,11 +943,6 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
926
943
927
944
pos += len (paramValues )
928
945
data = data [:pos ]
929
-
930
- // Convert nullMask to bytes
931
- for i , max := 0 , (stmt .paramCount + 7 )>> 3 ; i < max ; i ++ {
932
- data [i + 14 ] = byte (nullMask >> uint (i << 3 ))
933
- }
934
946
}
935
947
936
948
return mc .writePacket (data )
0 commit comments