diff --git a/src/solvers/smt2_incremental/convert_expr_to_smt.cpp b/src/solvers/smt2_incremental/convert_expr_to_smt.cpp index 18ddc576f12..6d5fb84fd93 100644 --- a/src/solvers/smt2_incremental/convert_expr_to_smt.cpp +++ b/src/solvers/smt2_incremental/convert_expr_to_smt.cpp @@ -121,6 +121,16 @@ static smt_termt convert_c_bool_cast( smt_bit_vector_constant_termt{0, c_bool_width}); } +static std::function(std::size_t)> +extension_for_type(const typet &type) +{ + if(can_cast_type(type)) + return smt_bit_vector_theoryt::sign_extend; + if(can_cast_type(type)) + return smt_bit_vector_theoryt::zero_extend; + UNREACHABLE; +} + static smt_termt make_bitvector_resize_cast( const smt_termt &from_term, const bitvector_typet &from_type, @@ -147,10 +157,7 @@ static smt_termt make_bitvector_resize_cast( if(to_width < from_width) return smt_bit_vector_theoryt::extract(to_width - 1, 0)(from_term); const std::size_t extension_size = to_width - from_width; - if(can_cast_type(from_type)) - return smt_bit_vector_theoryt::sign_extend(extension_size)(from_term); - else - return smt_bit_vector_theoryt::zero_extend(extension_size)(from_term); + return extension_for_type(from_type)(extension_size)(from_term); } struct sort_based_cast_to_bit_vector_convertert final @@ -652,34 +659,61 @@ static smt_termt convert_expr_to_smt(const index_exprt &index) "Generation of SMT formula for index expression: " + index.pretty()); } +template +static smt_termt +convert_to_smt_shift(const factoryt &factory, const shiftt &shift) +{ + const smt_termt first_operand = convert_expr_to_smt(shift.op0()); + const smt_termt second_operand = convert_expr_to_smt(shift.op1()); + const auto first_bit_vector_sort = + first_operand.get_sort().cast(); + const auto second_bit_vector_sort = + second_operand.get_sort().cast(); + INVARIANT( + first_bit_vector_sort && second_bit_vector_sort, + "Shift expressions are expected to have bit vector operands."); + const std::size_t first_width = first_bit_vector_sort->bit_width(); + const std::size_t second_width = second_bit_vector_sort->bit_width(); + if(first_width > second_width) + { + return factory( + first_operand, + extension_for_type(shift.op1().type())(first_width - second_width)( + second_operand)); + } + else if(first_width < second_width) + { + return factory( + extension_for_type(shift.op0().type())(second_width - first_width)( + first_operand), + second_operand); + } + else + { + return factory(first_operand, second_operand); + } +} + static smt_termt convert_expr_to_smt(const shift_exprt &shift) { - // TODO: Dispatch into different types of shifting - const auto &first_operand = shift.op0(); - const auto &second_operand = shift.op1(); - + // TODO: Dispatch for rotation expressions. A `shift_exprt` can be a rotation. if(const auto left_shift = expr_try_dynamic_cast(shift)) { - return smt_bit_vector_theoryt::shift_left( - convert_expr_to_smt(first_operand), convert_expr_to_smt(second_operand)); + return convert_to_smt_shift( + smt_bit_vector_theoryt::shift_left, *left_shift); } - else if( - const auto right_logical_shift = expr_try_dynamic_cast(shift)) + if(const auto right_logical_shift = expr_try_dynamic_cast(shift)) { - return smt_bit_vector_theoryt::logical_shift_right( - convert_expr_to_smt(first_operand), convert_expr_to_smt(second_operand)); + return convert_to_smt_shift( + smt_bit_vector_theoryt::logical_shift_right, *right_logical_shift); } - else if( - const auto right_arith_shift = expr_try_dynamic_cast(shift)) + if(const auto right_arith_shift = expr_try_dynamic_cast(shift)) { - return smt_bit_vector_theoryt::arithmetic_shift_right( - convert_expr_to_smt(first_operand), convert_expr_to_smt(second_operand)); - } - else - { - UNIMPLEMENTED_FEATURE( - "Generation of SMT formula for shift expression: " + shift.pretty()); + return convert_to_smt_shift( + smt_bit_vector_theoryt::arithmetic_shift_right, *right_arith_shift); } + UNIMPLEMENTED_FEATURE( + "Generation of SMT formula for shift expression: " + shift.pretty()); } static smt_termt convert_expr_to_smt(const with_exprt &with) @@ -733,6 +767,11 @@ static smt_termt convert_expr_to_smt(const extractbit_exprt &extract_bit) static smt_termt convert_expr_to_smt(const extractbits_exprt &extract_bits) { + const smt_termt from = convert_expr_to_smt(extract_bits.src()); + const auto upper_value = numeric_cast(extract_bits.upper()); + const auto lower_value = numeric_cast(extract_bits.lower()); + if(upper_value && lower_value) + return smt_bit_vector_theoryt::extract(*upper_value, *lower_value)(from); UNIMPLEMENTED_FEATURE( "Generation of SMT formula for extract bits expression: " + extract_bits.pretty()); diff --git a/unit/solvers/smt2_incremental/convert_expr_to_smt.cpp b/unit/solvers/smt2_incremental/convert_expr_to_smt.cpp index 3aec501617d..3b7393f5ce5 100644 --- a/unit/solvers/smt2_incremental/convert_expr_to_smt.cpp +++ b/unit/solvers/smt2_incremental/convert_expr_to_smt.cpp @@ -819,6 +819,105 @@ SCENARIO( } } +TEST_CASE( + "expr to smt conversion for shifts of mismatched operands.", + "[core][smt2_incremental]") +{ + using make_typet = std::function; + const make_typet make_unsigned = constructor_oft{}; + const make_typet make_signed = constructor_oft{}; + using make_extensiont = + std::function(std::size_t)>; + const make_extensiont zero_extend = smt_bit_vector_theoryt::zero_extend; + const make_extensiont sign_extend = smt_bit_vector_theoryt::sign_extend; + std::string type_description; + make_typet make_type; + make_extensiont make_extension; + using type_rowt = std::tuple; + std::tie(type_description, make_type, make_extension) = GENERATE_REF( + type_rowt{"Unsigned operands.", make_unsigned, zero_extend}, + type_rowt{"Signed operands.", make_signed, sign_extend}); + SECTION(type_description) + { + using make_shift_exprt = std::function; + const make_shift_exprt left_shift_expr = constructor_of(); + const make_shift_exprt arithmetic_right_shift_expr = + constructor_of(); + const make_shift_exprt logical_right_shift_expr = + constructor_of(); + using make_shift_termt = std::function; + const make_shift_termt left_shift_term = smt_bit_vector_theoryt::shift_left; + const make_shift_termt arithmetic_right_shift_term = + smt_bit_vector_theoryt::arithmetic_shift_right; + const make_shift_termt logical_right_shift_term = + smt_bit_vector_theoryt::logical_shift_right; + std::string shift_description; + make_shift_exprt make_shift_expr; + make_shift_termt make_shift_term; + using shift_rowt = + std::tuple; + std::tie(shift_description, make_shift_expr, make_shift_term) = + GENERATE_REF( + shift_rowt{"Left shift.", left_shift_expr, left_shift_term}, + shift_rowt{ + "Arithmetic right shift.", + arithmetic_right_shift_expr, + arithmetic_right_shift_term}, + shift_rowt{ + "Logical right shift.", + logical_right_shift_expr, + logical_right_shift_term}); + SECTION(shift_description) + { + SECTION("Wider left hand side") + { + const exprt input = make_shift_expr( + symbol_exprt{"foo", make_type(32)}, + symbol_exprt{"bar", make_type(8)}); + INFO("Input expr: " + input.pretty(2, 0)); + const smt_termt expected_result = make_shift_term( + smt_identifier_termt{"foo", smt_bit_vector_sortt{32}}, + make_extension(24)( + smt_identifier_termt{"bar", smt_bit_vector_sortt{8}})); + CHECK(convert_expr_to_smt(input) == expected_result); + } + SECTION("Wider right hand side") + { + const exprt input = make_shift_expr( + symbol_exprt{"foo", make_type(8)}, + symbol_exprt{"bar", make_type(32)}); + INFO("Input expr: " + input.pretty(2, 0)); + const smt_termt expected_result = make_shift_term( + make_extension(24)( + smt_identifier_termt{"foo", smt_bit_vector_sortt{8}}), + smt_identifier_termt{"bar", smt_bit_vector_sortt{32}}); + CHECK(convert_expr_to_smt(input) == expected_result); + } + } + } +} + +TEST_CASE( + "expr to smt conversion for extract bits expressions", + "[core][smt2_incremental]") +{ + const typet operand_type = unsignedbv_typet{8}; + const exprt input = extractbits_exprt{ + symbol_exprt{"foo", operand_type}, + from_integer(4, operand_type), + from_integer(2, operand_type), + unsignedbv_typet{3}}; + const smt_termt expected_result = smt_bit_vector_theoryt::extract(4, 2)( + smt_identifier_termt{"foo", smt_bit_vector_sortt{8}}); + CHECK(convert_expr_to_smt(input) == expected_result); + const cbmc_invariants_should_throwt invariants_throw; + CHECK_THROWS(convert_expr_to_smt(extractbits_exprt{ + symbol_exprt{"foo", operand_type}, + symbol_exprt{"bar", operand_type}, + symbol_exprt{"bar", operand_type}, + unsignedbv_typet{3}})); +} + TEST_CASE("expr to smt conversion for type casts", "[core][smt2_incremental]") { const symbol_exprt bool_expr{"foo", bool_typet{}};