diff --git a/jbmc/unit/util/simplify_expr.cpp b/jbmc/unit/util/simplify_expr.cpp index 4d2cebe0907..5e1ec18b059 100644 --- a/jbmc/unit/util/simplify_expr.cpp +++ b/jbmc/unit/util/simplify_expr.cpp @@ -47,6 +47,17 @@ void test_unnecessary_cast(const typet &type) REQUIRE(simplified.type()==java_int_type()); } + // casts from boolean get rewritten to ?: + if(type == java_boolean_type()) + { + const exprt simplified = simplify_expr( + typecast_exprt(symbol_exprt("foo", java_int_type()), type), + namespacet(symbol_tablet())); + + REQUIRE(simplified.id() == ID_if); + REQUIRE(simplified.type() == type); + } + else { const exprt simplified=simplify_expr( typecast_exprt(symbol_exprt("foo", java_int_type()), type), diff --git a/src/util/simplify_expr.cpp b/src/util/simplify_expr.cpp index 266bbc0318d..09870dbdb9c 100644 --- a/src/util/simplify_expr.cpp +++ b/src/util/simplify_expr.cpp @@ -576,6 +576,19 @@ simplify_exprt::simplify_typecast(const typecast_exprt &expr) return std::move(inequality); } + // eliminate casts from proper bool + if( + op_type.id() == ID_bool && + (expr_type.id() == ID_signedbv || expr_type.id() == ID_unsignedbv || + expr_type.id() == ID_c_bool || expr_type.id() == ID_c_bit_field)) + { + // rewrite (T)(bool) to bool?1:0 + auto one = from_integer(1, expr_type); + auto zero = from_integer(0, expr_type); + exprt new_expr = if_exprt(expr.op(), std::move(one), std::move(zero)); + return changed(simplify_rec(new_expr)); // recursive call + } + // circular casts through types shorter than `int` if(op_type == signedbv_typet(32) && expr.op().id() == ID_typecast) { diff --git a/unit/util/simplify_expr.cpp b/unit/util/simplify_expr.cpp index c55f13e52d3..cf837a62109 100644 --- a/unit/util/simplify_expr.cpp +++ b/unit/util/simplify_expr.cpp @@ -244,3 +244,57 @@ TEST_CASE("Simplify pointer_object equality", "[core][util]") REQUIRE(simp.is_true()); } + +TEST_CASE("Simplify cast from bool", "[core][util]") +{ + symbol_tablet symbol_table; + namespacet ns(symbol_table); + + { + // this checks that ((int)B)==1 turns into B + exprt B = symbol_exprt("B", bool_typet()); + exprt comparison = equal_exprt( + typecast_exprt(B, signedbv_typet(32)), + from_integer(1, signedbv_typet(32))); + + exprt simp = simplify_expr(comparison, ns); + + REQUIRE(simp == B); + } + + { + // this checks that ((int)B)==0 turns into !B + exprt B = symbol_exprt("B", bool_typet()); + exprt comparison = equal_exprt( + typecast_exprt(B, signedbv_typet(32)), + from_integer(0, signedbv_typet(32))); + + exprt simp = simplify_expr(comparison, ns); + + REQUIRE(simp == not_exprt(B)); + } + + { + // this checks that ((int)B)!=1 turns into !B + exprt B = symbol_exprt("B", bool_typet()); + exprt comparison = notequal_exprt( + typecast_exprt(B, signedbv_typet(32)), + from_integer(1, signedbv_typet(32))); + + exprt simp = simplify_expr(comparison, ns); + + REQUIRE(simp == not_exprt(B)); + } + + { + // this checks that ((int)B)!=0 turns into B + exprt B = symbol_exprt("B", bool_typet()); + exprt comparison = notequal_exprt( + typecast_exprt(B, signedbv_typet(32)), + from_integer(0, signedbv_typet(32))); + + exprt simp = simplify_expr(comparison, ns); + + REQUIRE(simp == B); + } +}