diff --git a/src/ansi-c/c_typecheck_expr.cpp b/src/ansi-c/c_typecheck_expr.cpp index 17db22bb082..e99d3a1d136 100644 --- a/src/ansi-c/c_typecheck_expr.cpp +++ b/src/ansi-c/c_typecheck_expr.cpp @@ -290,30 +290,37 @@ void c_typecheck_baset::typecheck_expr_main(exprt &expr) else if(expr.id()==ID_forall || expr.id()==ID_exists) { - // op0 is a declaration, - // op1 the bound expression - expr.type()=bool_typet(); + // These have two operands. + // op0 is a tuple with declarations, + // op1 is the bound expression auto &binary_expr = to_binary_expr(expr); + auto &bindings = binary_expr.op0().operands(); + auto &where = binary_expr.op1(); - if(binary_expr.op0().get(ID_statement) != ID_decl) + for(const auto &binding : bindings) { - error().source_location = expr.source_location(); - error() << "expected declaration as operand of quantifier" << eom; - throw 0; + if(binding.get(ID_statement) != ID_decl) + { + error().source_location = expr.source_location(); + error() << "expected declaration as operand of quantifier" << eom; + throw 0; + } } - if(has_subexpr(binary_expr.op1(), ID_side_effect)) + if(has_subexpr(where, ID_side_effect)) { error().source_location = expr.source_location(); error() << "quantifier must not contain side effects" << eom; throw 0; } - // replace declaration by symbol expression - symbol_exprt bound = to_code_decl(to_code(binary_expr.op0())).symbol(); - binary_expr.op0().swap(bound); + expr.type() = bool_typet(); - implicit_typecast_bool(binary_expr.op1()); + // replace declarations by symbol expressions + for(auto &binding : bindings) + binding = to_code_decl(to_code(binding)).symbol(); + + implicit_typecast_bool(where); } else if(expr.id()==ID_label) { @@ -710,48 +717,55 @@ void c_typecheck_baset::typecheck_expr_operands(exprt &expr) } else if(expr.id()==ID_forall || expr.id()==ID_exists) { + // These introduce new symbols, which need to be added to the symbol table + // before the second operand is typechecked. + auto &binary_expr = to_binary_expr(expr); + auto &bindings = binary_expr.op0().operands(); - ansi_c_declarationt &declaration = to_ansi_c_declaration(binary_expr.op0()); + for(auto &binding : bindings) + { + ansi_c_declarationt &declaration = to_ansi_c_declaration(binding); - typecheck_declaration(declaration); + typecheck_declaration(declaration); - if(declaration.declarators().size()!=1) - { - error().source_location = expr.source_location(); - error() << "expected one declarator exactly" << eom; - throw 0; - } + if(declaration.declarators().size() != 1) + { + error().source_location = expr.source_location(); + error() << "forall/exists expects one declarator exactly" << eom; + throw 0; + } - irep_idt identifier= - declaration.declarators().front().get_name(); + irep_idt identifier = declaration.declarators().front().get_name(); - // look it up - symbol_tablet::symbolst::const_iterator s_it= - symbol_table.symbols.find(identifier); + // look it up + symbol_tablet::symbolst::const_iterator s_it = + symbol_table.symbols.find(identifier); - if(s_it==symbol_table.symbols.end()) - { - error().source_location = expr.source_location(); - error() << "failed to find decl symbol '" << identifier - << "' in symbol table" << eom; - throw 0; - } + if(s_it == symbol_table.symbols.end()) + { + error().source_location = expr.source_location(); + error() << "failed to find bound symbol `" << identifier + << "' in symbol table" << eom; + throw 0; + } - const symbolt &symbol=s_it->second; + const symbolt &symbol = s_it->second; - if(symbol.is_type || symbol.is_extern || symbol.is_static_lifetime || - !is_complete_type(symbol.type) || symbol.type.id()==ID_code) - { - error().source_location = expr.source_location(); - error() << "unexpected quantified symbol" << eom; - throw 0; - } + if( + symbol.is_type || symbol.is_extern || symbol.is_static_lifetime || + !is_complete_type(symbol.type) || symbol.type.id() == ID_code) + { + error().source_location = expr.source_location(); + error() << "unexpected quantified symbol" << eom; + throw 0; + } - code_declt decl(symbol.symbol_expr()); - decl.add_source_location()=declaration.source_location(); + code_declt decl(symbol.symbol_expr()); + decl.add_source_location() = declaration.source_location(); - binary_expr.op0() = decl; + binding = decl; + } typecheck_expr(binary_expr.op1()); } diff --git a/src/ansi-c/expr2c.cpp b/src/ansi-c/expr2c.cpp index bbd33ae2833..672d4616450 100644 --- a/src/ansi-c/expr2c.cpp +++ b/src/ansi-c/expr2c.cpp @@ -791,20 +791,21 @@ std::string expr2ct::convert_trinary( } std::string expr2ct::convert_quantifier( - const exprt &src, + const quantifier_exprt &src, const std::string &symbol, unsigned precedence) { - if(src.operands().size()!=2) + // our made-up syntax can only do one symbol + if(src.op0().operands().size() != 1) return convert_norep(src, precedence); unsigned p0, p1; - std::string op0=convert_with_precedence(src.op0(), p0); - std::string op1=convert_with_precedence(src.op1(), p1); + std::string op0 = convert_with_precedence(src.symbol(), p0); + std::string op1 = convert_with_precedence(src.where(), p1); std::string dest=symbol+" { "; - dest+=convert(src.op0().type()); + dest += convert(src.symbol().type()); dest+=" "+op0+"; "; dest+=op1; dest+=" }"; @@ -3720,13 +3721,16 @@ std::string expr2ct::convert_with_precedence( return convert_trinary(to_if_expr(src), "?", ":", precedence = 3); else if(src.id()==ID_forall) - return convert_quantifier(src, "forall", precedence=2); + return convert_quantifier( + to_quantifier_expr(src), "forall", precedence = 2); else if(src.id()==ID_exists) - return convert_quantifier(src, "exists", precedence=2); + return convert_quantifier( + to_quantifier_expr(src), "exists", precedence = 2); else if(src.id()==ID_lambda) - return convert_quantifier(src, "LAMBDA", precedence=2); + return convert_quantifier( + to_quantifier_expr(src), "LAMBDA", precedence = 2); else if(src.id()==ID_with) return convert_with(src, precedence=16); diff --git a/src/ansi-c/expr2c_class.h b/src/ansi-c/expr2c_class.h index bfa8233ba2d..a28502504e7 100644 --- a/src/ansi-c/expr2c_class.h +++ b/src/ansi-c/expr2c_class.h @@ -128,7 +128,8 @@ class expr2ct const exprt &src, unsigned &precedence); std::string convert_quantifier( - const exprt &src, const std::string &symbol, + const quantifier_exprt &, + const std::string &symbol, unsigned precedence); std::string convert_with( diff --git a/src/ansi-c/parser.y b/src/ansi-c/parser.y index 0c970cbfd36..403194d3010 100644 --- a/src/ansi-c/parser.y +++ b/src/ansi-c/parser.y @@ -26,6 +26,8 @@ extern char *yyansi_ctext; #include "ansi_c_y.tab.h" +#include + #ifdef _MSC_VER // possible loss of data #pragma warning(disable:4242) @@ -469,7 +471,7 @@ quantifier_expression: { $$=$1; set($$, ID_forall); - mto($$, $4); + parser_stack($$).add_to_operands(tuple_exprt( { std::move(parser_stack($4)) } )); mto($$, $5); PARSER.pop_scope(); } @@ -477,7 +479,7 @@ quantifier_expression: { $$=$1; set($$, ID_exists); - mto($$, $4); + parser_stack($$).add_to_operands(tuple_exprt( { std::move(parser_stack($4)) } )); mto($$, $5); PARSER.pop_scope(); } @@ -810,7 +812,7 @@ ACSL_binding_expression: { $$=$1; set($$, ID_forall); - mto($$, $3); + parser_stack($$).add_to_operands(tuple_exprt( { std::move(parser_stack($3)) } )); mto($$, $4); PARSER.pop_scope(); } @@ -818,7 +820,7 @@ ACSL_binding_expression: { $$=$1; set($$, ID_exists); - mto($$, $3); + parser_stack($$).add_to_operands(tuple_exprt( { std::move(parser_stack($3)) } )); mto($$, $4); PARSER.pop_scope(); } diff --git a/src/solvers/flattening/boolbv_quantifier.cpp b/src/solvers/flattening/boolbv_quantifier.cpp index aff9ca767a8..d28a3d9f66a 100644 --- a/src/solvers/flattening/boolbv_quantifier.cpp +++ b/src/solvers/flattening/boolbv_quantifier.cpp @@ -140,8 +140,8 @@ instantiate_quantifier(const quantifier_exprt &expr, const namespacet &ns) return re; } - const auto min_i = get_quantifier_var_min(var_expr, re); - const auto max_i = get_quantifier_var_max(var_expr, re); + const optionalt min_i = get_quantifier_var_min(var_expr, re); + const optionalt max_i = get_quantifier_var_max(var_expr, re); if(!min_i.has_value() || !max_i.has_value()) return nullopt; diff --git a/src/solvers/smt2/smt2_conv.cpp b/src/solvers/smt2/smt2_conv.cpp index 96e0e998a1a..c68ff60fc2d 100644 --- a/src/solvers/smt2/smt2_conv.cpp +++ b/src/solvers/smt2/smt2_conv.cpp @@ -1807,7 +1807,7 @@ void smt2_convt::convert_expr(const exprt &expr) else if(quantifier_expr.id() == ID_exists) out << "(exists "; - exprt bound=expr.op0(); + exprt bound = quantifier_expr.symbol(); out << "(("; convert_expr(bound); diff --git a/src/solvers/smt2/smt2_parser.cpp b/src/solvers/smt2/smt2_parser.cpp index ad422f5b8ea..dc79a2dcf4a 100644 --- a/src/solvers/smt2/smt2_parser.cpp +++ b/src/solvers/smt2/smt2_parser.cpp @@ -283,7 +283,7 @@ exprt smt2_parsert::quantifier_expression(irep_idt id) // go backwards, build quantified expression for(auto r_it=bindings.rbegin(); r_it!=bindings.rend(); r_it++) { - binary_predicate_exprt quantifier(*r_it, id, result); + quantifier_exprt quantifier(id, *r_it, result); result=quantifier; } diff --git a/src/util/mathematical_expr.h b/src/util/mathematical_expr.h index 3751ab9560f..591f6f468ce 100644 --- a/src/util/mathematical_expr.h +++ b/src/util/mathematical_expr.h @@ -271,26 +271,34 @@ inline function_application_exprt &to_function_application_expr(exprt &expr) return ret; } -/// \brief A base class for quantifier expressions -class quantifier_exprt : public binary_predicate_exprt +/// \brief A base class for variable bindings (quantifiers, let, lambda) +class binding_exprt : public binary_exprt { public: - quantifier_exprt( - const irep_idt &_id, - const symbol_exprt &_symbol, - const exprt &_where) - : binary_predicate_exprt(_symbol, _id, _where) + using variablest = std::vector; + + /// construct the binding expression + binding_exprt( + irep_idt _id, + const variablest &_variables, + exprt _where, + typet _type) + : binary_exprt( + tuple_exprt((const operandst &)_variables), + _id, + std::move(_where), + std::move(_type)) { } - symbol_exprt &symbol() + variablest &variables() { - return static_cast(op0()); + return (variablest &)static_cast(op0()).operands(); } - const symbol_exprt &symbol() const + const variablest &variables() const { - return static_cast(op0()); + return (variablest &)static_cast(op0()).operands(); } exprt &where() @@ -304,6 +312,39 @@ class quantifier_exprt : public binary_predicate_exprt } }; +/// \brief A base class for quantifier expressions +class quantifier_exprt : public binding_exprt +{ +public: + /// constructor for single variable + quantifier_exprt(irep_idt _id, symbol_exprt _symbol, exprt _where) + : binding_exprt(_id, {std::move(_symbol)}, std::move(_where), bool_typet()) + { + } + + /// constructor for multiple variables + quantifier_exprt(irep_idt _id, const variablest &_variables, exprt _where) + : binding_exprt(_id, _variables, std::move(_where), bool_typet()) + { + } + + // for the special case of one variable + symbol_exprt &symbol() + { + auto &variables = this->variables(); + PRECONDITION(variables.size() == 1); + return variables.front(); + } + + // for the special case of one variable + const symbol_exprt &symbol() const + { + auto &variables = this->variables(); + PRECONDITION(variables.size() == 1); + return variables.front(); + } +}; + template <> inline bool can_cast_expr(const exprt &base) { @@ -313,8 +354,9 @@ inline bool can_cast_expr(const exprt &base) inline void validate_expr(const quantifier_exprt &value) { validate_operands(value, 2, "quantifier expressions must have two operands"); - DATA_INVARIANT( - value.op0().id() == ID_symbol, "quantified variable shall be a symbol"); + for(auto &op : value.variables()) + DATA_INVARIANT( + op.id() == ID_symbol, "quantified variable shall be a symbol"); } /// \brief Cast an exprt to a \ref quantifier_exprt