diff --git a/src/solvers/flattening/boolbv.cpp b/src/solvers/flattening/boolbv.cpp index 96bc17c42b4..746cf4ae703 100644 --- a/src/solvers/flattening/boolbv.cpp +++ b/src/solvers/flattening/boolbv.cpp @@ -294,7 +294,7 @@ bvt boolbvt::convert_bitvector(const exprt &expr) else if(expr.id()==ID_complex_imag) return convert_complex_imag(to_complex_imag_expr(expr)); else if(expr.id()==ID_lambda) - return convert_lambda(expr); + return convert_lambda(to_lambda_expr(expr)); else if(expr.id()==ID_array_of) return convert_array_of(to_array_of_expr(expr)); else if(expr.id()==ID_let) @@ -316,28 +316,21 @@ bvt boolbvt::convert_bitvector(const exprt &expr) return conversion_failed(expr); } -bvt boolbvt::convert_lambda(const exprt &expr) +bvt boolbvt::convert_lambda(const lambda_exprt &expr) { std::size_t width=boolbv_width(expr.type()); if(width==0) return conversion_failed(expr); - DATA_INVARIANT( - expr.operands().size() == 2, "lambda expression should have two operands"); - - if(expr.type().id()!=ID_array) - return conversion_failed(expr); - - const exprt &array_size= - to_array_type(expr.type()).size(); + const exprt &array_size = expr.type().size(); const auto size = numeric_cast(array_size); if(!size.has_value()) return conversion_failed(expr); - typet counter_type=expr.op0().type(); + typet counter_type = expr.arg().type(); bvt bv; bv.resize(width); @@ -346,10 +339,10 @@ bvt boolbvt::convert_lambda(const exprt &expr) { exprt counter=from_integer(i, counter_type); - exprt expr_op1(expr.op1()); - replace_expr(expr.op0(), counter, expr_op1); + exprt body = expr.body(); + replace_expr(expr.arg(), counter, body); - const bvt &tmp=convert_bv(expr_op1); + const bvt &tmp = convert_bv(body); INVARIANT( *size * tmp.size() == width, diff --git a/src/solvers/flattening/boolbv.h b/src/solvers/flattening/boolbv.h index b2c87f2b881..4bbb9d853e9 100644 --- a/src/solvers/flattening/boolbv.h +++ b/src/solvers/flattening/boolbv.h @@ -28,6 +28,7 @@ Author: Daniel Kroening, kroening@kroening.com class extractbit_exprt; class extractbits_exprt; +class lambda_exprt; class member_exprt; class boolbvt:public arrayst @@ -148,7 +149,7 @@ class boolbvt:public arrayst virtual bvt convert_complex(const complex_exprt &expr); virtual bvt convert_complex_real(const complex_real_exprt &expr); virtual bvt convert_complex_imag(const complex_imag_exprt &expr); - virtual bvt convert_lambda(const exprt &expr); + virtual bvt convert_lambda(const lambda_exprt &expr); virtual bvt convert_let(const let_exprt &); virtual bvt convert_array_of(const array_of_exprt &expr); virtual bvt convert_union(const union_exprt &expr); diff --git a/src/util/simplify_expr_array.cpp b/src/util/simplify_expr_array.cpp index 4ef9a540706..8ade73f117a 100644 --- a/src/util/simplify_expr_array.cpp +++ b/src/util/simplify_expr_array.cpp @@ -48,16 +48,14 @@ bool simplify_exprt::simplify_index(exprt &expr) { // simplify (lambda i: e)(x) to e[i/x] - const exprt &lambda_expr=array; + const lambda_exprt &lambda_expr = to_lambda_expr(array); - if(lambda_expr.operands().size()!=2) - return true; - - if(expr.op1().type()==lambda_expr.op0().type()) + if(expr.op1().type() == lambda_expr.arg().type()) { - exprt tmp=lambda_expr.op1(); - replace_expr(lambda_expr.op0(), expr.op1(), tmp); + exprt tmp = lambda_expr.body(); + replace_expr(lambda_expr.arg(), expr.op1(), tmp); expr.swap(tmp); + simplify_rec(expr); return false; } } diff --git a/src/util/std_expr.h b/src/util/std_expr.h index 4ff695e7a34..929afb7cbd8 100644 --- a/src/util/std_expr.h +++ b/src/util/std_expr.h @@ -4460,4 +4460,79 @@ inline cond_exprt &to_cond_expr(exprt &expr) return ret; } +/// \brief Expression to define a mapping from an argument (index) to elements. +/// This enables constructing an array via an anonymous function. +class lambda_exprt : public binary_exprt +{ +public: + explicit lambda_exprt(symbol_exprt arg, exprt body, array_typet _type) + : binary_exprt(std::move(arg), ID_lambda, std::move(body), std::move(_type)) + { + } + + const array_typet &type() const + { + return static_cast(binary_exprt::type()); + } + + array_typet &type() + { + return static_cast(binary_exprt::type()); + } + + const symbol_exprt &arg() const + { + return static_cast(op0()); + } + + symbol_exprt &arg() + { + return static_cast(op0()); + } + + const exprt &body() const + { + return op1(); + } + + exprt &body() + { + return op1(); + } +}; + +template <> +inline bool can_cast_expr(const exprt &base) +{ + return base.id() == ID_lambda; +} + +inline void validate_expr(const lambda_exprt &value) +{ + validate_operands(value, 2, "'Lambda' must have two operands"); +} + +/// \brief Cast an exprt to a \ref lambda_exprt +/// +/// \a expr must be known to be \ref lambda_exprt. +/// +/// \param expr: Source expression +/// \return Object of type \ref lambda_exprt +inline const lambda_exprt &to_lambda_expr(const exprt &expr) +{ + PRECONDITION(expr.id() == ID_lambda); + const lambda_exprt &ret = static_cast(expr); + validate_expr(ret); + return ret; +} + +/// \copydoc to_lambda_expr(const exprt &) +inline lambda_exprt &to_lambda_expr(exprt &expr) +{ + PRECONDITION(expr.id() == ID_lambda); + lambda_exprt &ret = static_cast(expr); + validate_expr(ret); + return ret; +} + #endif // CPROVER_UTIL_STD_EXPR_H