Skip to content

Implement conversion to SMT for more bit wise expressions #6722

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 62 additions & 23 deletions src/solvers/smt2_incremental/convert_expr_to_smt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,16 @@ static smt_termt convert_c_bool_cast(
smt_bit_vector_constant_termt{0, c_bool_width});
}

static std::function<std::function<smt_termt(smt_termt)>(std::size_t)>
extension_for_type(const typet &type)
{
if(can_cast_type<signedbv_typet>(type))
return smt_bit_vector_theoryt::sign_extend;
if(can_cast_type<unsignedbv_typet>(type))
return smt_bit_vector_theoryt::zero_extend;
UNREACHABLE;
}

static smt_termt make_bitvector_resize_cast(
const smt_termt &from_term,
const bitvector_typet &from_type,
Expand All @@ -147,10 +157,7 @@ static smt_termt make_bitvector_resize_cast(
if(to_width < from_width)
return smt_bit_vector_theoryt::extract(to_width - 1, 0)(from_term);
const std::size_t extension_size = to_width - from_width;
if(can_cast_type<signedbv_typet>(from_type))
return smt_bit_vector_theoryt::sign_extend(extension_size)(from_term);
else
return smt_bit_vector_theoryt::zero_extend(extension_size)(from_term);
return extension_for_type(from_type)(extension_size)(from_term);
}

struct sort_based_cast_to_bit_vector_convertert final
Expand Down Expand Up @@ -652,34 +659,61 @@ static smt_termt convert_expr_to_smt(const index_exprt &index)
"Generation of SMT formula for index expression: " + index.pretty());
}

template <typename factoryt, typename shiftt>
static smt_termt
convert_to_smt_shift(const factoryt &factory, const shiftt &shift)
{
const smt_termt first_operand = convert_expr_to_smt(shift.op0());
const smt_termt second_operand = convert_expr_to_smt(shift.op1());
const auto first_bit_vector_sort =
first_operand.get_sort().cast<smt_bit_vector_sortt>();
const auto second_bit_vector_sort =
second_operand.get_sort().cast<smt_bit_vector_sortt>();
INVARIANT(
first_bit_vector_sort && second_bit_vector_sort,
"Shift expressions are expected to have bit vector operands.");
const std::size_t first_width = first_bit_vector_sort->bit_width();
const std::size_t second_width = second_bit_vector_sort->bit_width();
if(first_width > second_width)
{
return factory(
first_operand,
extension_for_type(shift.op1().type())(first_width - second_width)(
second_operand));
}
else if(first_width < second_width)
{
return factory(
extension_for_type(shift.op0().type())(second_width - first_width)(
first_operand),
second_operand);
}
else
{
return factory(first_operand, second_operand);
}
}

static smt_termt convert_expr_to_smt(const shift_exprt &shift)
{
// TODO: Dispatch into different types of shifting
const auto &first_operand = shift.op0();
const auto &second_operand = shift.op1();

// TODO: Dispatch for rotation expressions. A `shift_exprt` can be a rotation.
if(const auto left_shift = expr_try_dynamic_cast<shl_exprt>(shift))
{
return smt_bit_vector_theoryt::shift_left(
convert_expr_to_smt(first_operand), convert_expr_to_smt(second_operand));
return convert_to_smt_shift(
smt_bit_vector_theoryt::shift_left, *left_shift);
}
else if(
const auto right_logical_shift = expr_try_dynamic_cast<lshr_exprt>(shift))
if(const auto right_logical_shift = expr_try_dynamic_cast<lshr_exprt>(shift))
{
return smt_bit_vector_theoryt::logical_shift_right(
convert_expr_to_smt(first_operand), convert_expr_to_smt(second_operand));
return convert_to_smt_shift(
smt_bit_vector_theoryt::logical_shift_right, *right_logical_shift);
}
else if(
const auto right_arith_shift = expr_try_dynamic_cast<ashr_exprt>(shift))
if(const auto right_arith_shift = expr_try_dynamic_cast<ashr_exprt>(shift))
{
return smt_bit_vector_theoryt::arithmetic_shift_right(
convert_expr_to_smt(first_operand), convert_expr_to_smt(second_operand));
}
else
{
UNIMPLEMENTED_FEATURE(
"Generation of SMT formula for shift expression: " + shift.pretty());
return convert_to_smt_shift(
smt_bit_vector_theoryt::arithmetic_shift_right, *right_arith_shift);
}
UNIMPLEMENTED_FEATURE(
"Generation of SMT formula for shift expression: " + shift.pretty());
}

static smt_termt convert_expr_to_smt(const with_exprt &with)
Expand Down Expand Up @@ -733,6 +767,11 @@ static smt_termt convert_expr_to_smt(const extractbit_exprt &extract_bit)

static smt_termt convert_expr_to_smt(const extractbits_exprt &extract_bits)
{
const smt_termt from = convert_expr_to_smt(extract_bits.src());
const auto upper_value = numeric_cast<std::size_t>(extract_bits.upper());
const auto lower_value = numeric_cast<std::size_t>(extract_bits.lower());
if(upper_value && lower_value)
return smt_bit_vector_theoryt::extract(*upper_value, *lower_value)(from);
UNIMPLEMENTED_FEATURE(
"Generation of SMT formula for extract bits expression: " +
extract_bits.pretty());
Expand Down
99 changes: 99 additions & 0 deletions unit/solvers/smt2_incremental/convert_expr_to_smt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,105 @@ SCENARIO(
}
}

TEST_CASE(
"expr to smt conversion for shifts of mismatched operands.",
"[core][smt2_incremental]")
{
using make_typet = std::function<typet(std::size_t)>;
const make_typet make_unsigned = constructor_oft<unsignedbv_typet>{};
const make_typet make_signed = constructor_oft<signedbv_typet>{};
using make_extensiont =
std::function<std::function<smt_termt(smt_termt)>(std::size_t)>;
const make_extensiont zero_extend = smt_bit_vector_theoryt::zero_extend;
const make_extensiont sign_extend = smt_bit_vector_theoryt::sign_extend;
std::string type_description;
make_typet make_type;
make_extensiont make_extension;
using type_rowt = std::tuple<std::string, make_typet, make_extensiont>;
std::tie(type_description, make_type, make_extension) = GENERATE_REF(
type_rowt{"Unsigned operands.", make_unsigned, zero_extend},
type_rowt{"Signed operands.", make_signed, sign_extend});
SECTION(type_description)
{
using make_shift_exprt = std::function<exprt(exprt, exprt)>;
const make_shift_exprt left_shift_expr = constructor_of<shl_exprt>();
const make_shift_exprt arithmetic_right_shift_expr =
constructor_of<ashr_exprt>();
const make_shift_exprt logical_right_shift_expr =
constructor_of<lshr_exprt>();
using make_shift_termt = std::function<smt_termt(smt_termt, smt_termt)>;
const make_shift_termt left_shift_term = smt_bit_vector_theoryt::shift_left;
const make_shift_termt arithmetic_right_shift_term =
smt_bit_vector_theoryt::arithmetic_shift_right;
const make_shift_termt logical_right_shift_term =
smt_bit_vector_theoryt::logical_shift_right;
std::string shift_description;
make_shift_exprt make_shift_expr;
make_shift_termt make_shift_term;
using shift_rowt =
std::tuple<std::string, make_shift_exprt, make_shift_termt>;
std::tie(shift_description, make_shift_expr, make_shift_term) =
GENERATE_REF(
shift_rowt{"Left shift.", left_shift_expr, left_shift_term},
shift_rowt{
"Arithmetic right shift.",
arithmetic_right_shift_expr,
arithmetic_right_shift_term},
shift_rowt{
"Logical right shift.",
logical_right_shift_expr,
logical_right_shift_term});
SECTION(shift_description)
{
SECTION("Wider left hand side")
{
const exprt input = make_shift_expr(
symbol_exprt{"foo", make_type(32)},
symbol_exprt{"bar", make_type(8)});
INFO("Input expr: " + input.pretty(2, 0));
const smt_termt expected_result = make_shift_term(
smt_identifier_termt{"foo", smt_bit_vector_sortt{32}},
make_extension(24)(
smt_identifier_termt{"bar", smt_bit_vector_sortt{8}}));
CHECK(convert_expr_to_smt(input) == expected_result);
}
SECTION("Wider right hand side")
{
const exprt input = make_shift_expr(
symbol_exprt{"foo", make_type(8)},
symbol_exprt{"bar", make_type(32)});
INFO("Input expr: " + input.pretty(2, 0));
const smt_termt expected_result = make_shift_term(
make_extension(24)(
smt_identifier_termt{"foo", smt_bit_vector_sortt{8}}),
smt_identifier_termt{"bar", smt_bit_vector_sortt{32}});
CHECK(convert_expr_to_smt(input) == expected_result);
}
}
}
}

TEST_CASE(
"expr to smt conversion for extract bits expressions",
"[core][smt2_incremental]")
{
const typet operand_type = unsignedbv_typet{8};
const exprt input = extractbits_exprt{
symbol_exprt{"foo", operand_type},
from_integer(4, operand_type),
from_integer(2, operand_type),
unsignedbv_typet{3}};
const smt_termt expected_result = smt_bit_vector_theoryt::extract(4, 2)(
smt_identifier_termt{"foo", smt_bit_vector_sortt{8}});
CHECK(convert_expr_to_smt(input) == expected_result);
const cbmc_invariants_should_throwt invariants_throw;
CHECK_THROWS(convert_expr_to_smt(extractbits_exprt{
symbol_exprt{"foo", operand_type},
symbol_exprt{"bar", operand_type},
symbol_exprt{"bar", operand_type},
unsignedbv_typet{3}}));
}

TEST_CASE("expr to smt conversion for type casts", "[core][smt2_incremental]")
{
const symbol_exprt bool_expr{"foo", bool_typet{}};
Expand Down