diff --git a/server/stmt.go b/server/stmt.go index 56c80c507..cc99b40ba 100644 --- a/server/stmt.go +++ b/server/stmt.go @@ -155,10 +155,10 @@ func (c *Conn) handleStmtExecute(data []byte) (*mysql.Result, error) { pos += paramNum << 1 paramValues = data[pos:] - } - if err := c.bindStmtArgs(s, nullBitmaps, paramTypes, paramValues); err != nil { - return nil, errors.Trace(err) + if err := c.bindStmtArgs(s, nullBitmaps, paramTypes, paramValues); err != nil { + return nil, errors.Trace(err) + } } } @@ -176,6 +176,14 @@ func (c *Conn) handleStmtExecute(data []byte) (*mysql.Result, error) { func (c *Conn) bindStmtArgs(s *Stmt, nullBitmap, paramTypes, paramValues []byte) error { args := s.Args + // Every param should have a type-and-flag of 2 bytes + // 0xfe80 == Type 0xfe and Flag 0x80 + // The flag only has one bit and that indicates if it is unsigned or not. + // Types are 1 byte, but might grow into the 7 unused bits in the future. + if len(paramTypes)/2 != s.Params { + return mysql.ErrMalformPacket + } + pos := 0 var v []byte @@ -190,7 +198,7 @@ func (c *Conn) bindStmtArgs(s *Stmt, nullBitmap, paramTypes, paramValues []byte) } tp := paramTypes[i<<1] - isUnsigned := (paramTypes[(i<<1)+1] & 0x80) > 0 + isUnsigned := (paramTypes[(i<<1)+1] & mysql.PARAM_UNSIGNED) > 0 switch tp { case mysql.MYSQL_TYPE_NULL: