From 7e39d04c2cc00031296e69171dcb5617a9abb5c0 Mon Sep 17 00:00:00 2001 From: Daniel Kroening Date: Sun, 1 Dec 2024 13:33:12 +0000 Subject: [PATCH] SMT2: support onehot and onehot0 This adds support for the onehot and onehot0 predicates to the SMT2 backend. --- src/solvers/flattening/boolbv.cpp | 6 ++- src/solvers/smt2/smt2_conv.cpp | 8 ++++ src/util/bitvector_expr.cpp | 34 ++++++++++++++ src/util/bitvector_expr.h | 70 +++++++++++++++++++++++++++ unit/util/bitvector_expr.cpp | 78 +++++++++++++++++++++++++++++++ unit/util/module_dependencies.txt | 2 + 6 files changed, 196 insertions(+), 2 deletions(-) diff --git a/src/solvers/flattening/boolbv.cpp b/src/solvers/flattening/boolbv.cpp index f1f1f7c9de7..38f856898be 100644 --- a/src/solvers/flattening/boolbv.cpp +++ b/src/solvers/flattening/boolbv.cpp @@ -412,8 +412,10 @@ literalt boolbvt::convert_rest(const exprt &expr) expr.id()==ID_reduction_nor || expr.id()==ID_reduction_nand || expr.id()==ID_reduction_xor || expr.id()==ID_reduction_xnor) return convert_reduction(to_unary_expr(expr)); - else if(expr.id()==ID_onehot || expr.id()==ID_onehot0) - return convert_onehot(to_unary_expr(expr)); + else if(expr.id() == ID_onehot) + return convert_onehot(to_onehot_expr(expr)); + else if(expr.id() == ID_onehot0) + return convert_onehot(to_onehot0_expr(expr)); else if( const auto binary_overflow = expr_try_dynamic_cast(expr)) diff --git a/src/solvers/smt2/smt2_conv.cpp b/src/solvers/smt2/smt2_conv.cpp index 2402bb8ca92..5a876f828bf 100644 --- a/src/solvers/smt2/smt2_conv.cpp +++ b/src/solvers/smt2/smt2_conv.cpp @@ -1831,6 +1831,14 @@ void smt2_convt::convert_expr(const exprt &expr) out << ")) #b1)"; // bvlshr, extract, = } } + else if(expr.id() == ID_onehot) + { + convert_expr(to_onehot_expr(expr).lower()); + } + else if(expr.id() == ID_onehot0) + { + convert_expr(to_onehot0_expr(expr).lower()); + } else if(expr.id()==ID_extractbits) { const extractbits_exprt &extractbits_expr = to_extractbits_expr(expr); diff --git a/src/util/bitvector_expr.cpp b/src/util/bitvector_expr.cpp index f36b5f14602..e578f52b6ea 100644 --- a/src/util/bitvector_expr.cpp +++ b/src/util/bitvector_expr.cpp @@ -306,3 +306,37 @@ exprt zero_extend_exprt::lower() const return extractbits_exprt{op(), 0, type()}; } } + +static exprt onehot_lowering(const exprt &expr) +{ + exprt one_seen = false_exprt{}; + const auto width = to_bitvector_type(expr.type()).get_width(); + exprt::operandst more_than_one_seen_disjuncts; + more_than_one_seen_disjuncts.reserve(width); + + for(std::size_t i = 0; i < width; i++) + { + auto bit = extractbit_exprt{expr, i}; + more_than_one_seen_disjuncts.push_back(and_exprt{bit, one_seen}); + one_seen = or_exprt{one_seen, bit}; + } + + auto more_than_one_seen = disjunction(more_than_one_seen_disjuncts); + + return and_exprt{one_seen, not_exprt{more_than_one_seen}}; +} + +exprt onehot_exprt::lower() const +{ + auto symbol = symbol_exprt{"onehot-op", op().type()}; + + return let_exprt{symbol, op(), onehot_lowering(symbol)}; +} + +exprt onehot0_exprt::lower() const +{ + auto symbol = symbol_exprt{"onehot-op", op().type()}; + + // same as onehot, but on flipped operand bits + return let_exprt{symbol, bitnot_exprt{op()}, onehot_lowering(symbol)}; +} diff --git a/src/util/bitvector_expr.h b/src/util/bitvector_expr.h index ca1573928f8..cd2b2a11c15 100644 --- a/src/util/bitvector_expr.h +++ b/src/util/bitvector_expr.h @@ -1742,4 +1742,74 @@ inline zero_extend_exprt &to_zero_extend_expr(exprt &expr) return static_cast(expr); } +/// \brief A Boolean expression returning true iff the given +/// operand consists of exactly one '1' and '0' otherwise. +class onehot_exprt : public unary_predicate_exprt +{ +public: + explicit onehot_exprt(exprt _op) + : unary_predicate_exprt(ID_onehot, std::move(_op)) + { + } + + /// lowering to extractbit + exprt lower() const; +}; + +/// \brief Cast an exprt to a \ref onehot_exprt +/// +/// \a expr must be known to be \ref onehot_exprt. +/// +/// \param expr: Source expression +/// \return Object of type \ref onehot_exprt +inline const onehot_exprt &to_onehot_expr(const exprt &expr) +{ + PRECONDITION(expr.id() == ID_onehot); + onehot_exprt::check(expr); + return static_cast(expr); +} + +/// \copydoc to_onehot_expr(const exprt &) +inline onehot_exprt &to_onehot_expr(exprt &expr) +{ + PRECONDITION(expr.id() == ID_onehot); + onehot_exprt::check(expr); + return static_cast(expr); +} + +/// \brief A Boolean expression returning true iff the given +/// operand consists of exactly one '0' and '1' otherwise. +class onehot0_exprt : public unary_predicate_exprt +{ +public: + explicit onehot0_exprt(exprt _op) + : unary_predicate_exprt(ID_onehot0, std::move(_op)) + { + } + + /// lowering to extractbit + exprt lower() const; +}; + +/// \brief Cast an exprt to a \ref onehot0_exprt +/// +/// \a expr must be known to be \ref onehot0_exprt. +/// +/// \param expr: Source expression +/// \return Object of type \ref onehot0_exprt +inline const onehot0_exprt &to_onehot0_expr(const exprt &expr) +{ + PRECONDITION(expr.id() == ID_onehot0); + onehot0_exprt::check(expr); + return static_cast(expr); +} + +/// \copydoc to_onehot0_expr(const exprt &) +inline onehot0_exprt &to_onehot0_expr(exprt &expr) +{ + PRECONDITION(expr.id() == ID_onehot0); + onehot0_exprt::check(expr); + return static_cast(expr); +} + #endif // CPROVER_UTIL_BITVECTOR_EXPR_H diff --git a/unit/util/bitvector_expr.cpp b/unit/util/bitvector_expr.cpp index 4b191b56212..bf4207fa11f 100644 --- a/unit/util/bitvector_expr.cpp +++ b/unit/util/bitvector_expr.cpp @@ -1,8 +1,15 @@ // Author: Diffblue Ltd. +#include #include #include +#include +#include +#include +#include +#include +#include #include TEST_CASE( @@ -64,3 +71,74 @@ TEMPLATE_TEST_CASE( } } } + +TEST_CASE("onehot expression lowering", "[core][util][expr]") +{ + console_message_handlert message_handler; + message_handler.set_verbosity(0); + satcheckt satcheck{message_handler}; + symbol_tablet symbol_table; + namespacet ns{symbol_table}; + boolbvt boolbv{ns, satcheck, message_handler}; + unsignedbv_typet u8{8}; + + GIVEN("A bit-vector that is one-hot") + { + boolbv << onehot_exprt{from_integer(64, u8)}.lower(); + + THEN("the lowering of onehot is true") + { + REQUIRE(boolbv() == decision_proceduret::resultt::D_SATISFIABLE); + } + } + + GIVEN("A bit-vector that is not one-hot") + { + boolbv << onehot_exprt{from_integer(5, u8)}.lower(); + + THEN("the lowering of onehot is false") + { + REQUIRE(boolbv() == decision_proceduret::resultt::D_UNSATISFIABLE); + } + } + + GIVEN("A bit-vector that is not one-hot") + { + boolbv << onehot_exprt{from_integer(0, u8)}.lower(); + + THEN("the lowering of onehot is false") + { + REQUIRE(boolbv() == decision_proceduret::resultt::D_UNSATISFIABLE); + } + } + + GIVEN("A bit-vector that is one-hot 0") + { + boolbv << onehot0_exprt{from_integer(0xfe, u8)}.lower(); + + THEN("the lowering of onehot0 is true") + { + REQUIRE(boolbv() == decision_proceduret::resultt::D_SATISFIABLE); + } + } + + GIVEN("A bit-vector that is not one-hot 0") + { + boolbv << onehot0_exprt{from_integer(0x7e, u8)}.lower(); + + THEN("the lowering of onehot0 is false") + { + REQUIRE(boolbv() == decision_proceduret::resultt::D_UNSATISFIABLE); + } + } + + GIVEN("A bit-vector that is not one-hot 0") + { + boolbv << onehot0_exprt{from_integer(0xff, u8)}.lower(); + + THEN("the lowering of onehot0 is false") + { + REQUIRE(boolbv() == decision_proceduret::resultt::D_UNSATISFIABLE); + } + } +} diff --git a/unit/util/module_dependencies.txt b/unit/util/module_dependencies.txt index abf92b3762e..4216fe01a4a 100644 --- a/unit/util/module_dependencies.txt +++ b/unit/util/module_dependencies.txt @@ -1,2 +1,4 @@ testing-utils util +solvers/flattening +solvers/sat