Skip to content

Commit 7127c4d

Browse files
Merge pull request #6722 from thomasspriggs/tas/more_smt_bitwise
Implement conversion to SMT for more bit wise expressions
2 parents 5c3722d + 5ecedf8 commit 7127c4d

File tree

2 files changed

+161
-23
lines changed

2 files changed

+161
-23
lines changed

src/solvers/smt2_incremental/convert_expr_to_smt.cpp

Lines changed: 62 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,16 @@ static smt_termt convert_c_bool_cast(
121121
smt_bit_vector_constant_termt{0, c_bool_width});
122122
}
123123

124+
static std::function<std::function<smt_termt(smt_termt)>(std::size_t)>
125+
extension_for_type(const typet &type)
126+
{
127+
if(can_cast_type<signedbv_typet>(type))
128+
return smt_bit_vector_theoryt::sign_extend;
129+
if(can_cast_type<unsignedbv_typet>(type))
130+
return smt_bit_vector_theoryt::zero_extend;
131+
UNREACHABLE;
132+
}
133+
124134
static smt_termt make_bitvector_resize_cast(
125135
const smt_termt &from_term,
126136
const bitvector_typet &from_type,
@@ -147,10 +157,7 @@ static smt_termt make_bitvector_resize_cast(
147157
if(to_width < from_width)
148158
return smt_bit_vector_theoryt::extract(to_width - 1, 0)(from_term);
149159
const std::size_t extension_size = to_width - from_width;
150-
if(can_cast_type<signedbv_typet>(from_type))
151-
return smt_bit_vector_theoryt::sign_extend(extension_size)(from_term);
152-
else
153-
return smt_bit_vector_theoryt::zero_extend(extension_size)(from_term);
160+
return extension_for_type(from_type)(extension_size)(from_term);
154161
}
155162

156163
struct sort_based_cast_to_bit_vector_convertert final
@@ -652,34 +659,61 @@ static smt_termt convert_expr_to_smt(const index_exprt &index)
652659
"Generation of SMT formula for index expression: " + index.pretty());
653660
}
654661

662+
template <typename factoryt, typename shiftt>
663+
static smt_termt
664+
convert_to_smt_shift(const factoryt &factory, const shiftt &shift)
665+
{
666+
const smt_termt first_operand = convert_expr_to_smt(shift.op0());
667+
const smt_termt second_operand = convert_expr_to_smt(shift.op1());
668+
const auto first_bit_vector_sort =
669+
first_operand.get_sort().cast<smt_bit_vector_sortt>();
670+
const auto second_bit_vector_sort =
671+
second_operand.get_sort().cast<smt_bit_vector_sortt>();
672+
INVARIANT(
673+
first_bit_vector_sort && second_bit_vector_sort,
674+
"Shift expressions are expected to have bit vector operands.");
675+
const std::size_t first_width = first_bit_vector_sort->bit_width();
676+
const std::size_t second_width = second_bit_vector_sort->bit_width();
677+
if(first_width > second_width)
678+
{
679+
return factory(
680+
first_operand,
681+
extension_for_type(shift.op1().type())(first_width - second_width)(
682+
second_operand));
683+
}
684+
else if(first_width < second_width)
685+
{
686+
return factory(
687+
extension_for_type(shift.op0().type())(second_width - first_width)(
688+
first_operand),
689+
second_operand);
690+
}
691+
else
692+
{
693+
return factory(first_operand, second_operand);
694+
}
695+
}
696+
655697
static smt_termt convert_expr_to_smt(const shift_exprt &shift)
656698
{
657-
// TODO: Dispatch into different types of shifting
658-
const auto &first_operand = shift.op0();
659-
const auto &second_operand = shift.op1();
660-
699+
// TODO: Dispatch for rotation expressions. A `shift_exprt` can be a rotation.
661700
if(const auto left_shift = expr_try_dynamic_cast<shl_exprt>(shift))
662701
{
663-
return smt_bit_vector_theoryt::shift_left(
664-
convert_expr_to_smt(first_operand), convert_expr_to_smt(second_operand));
702+
return convert_to_smt_shift(
703+
smt_bit_vector_theoryt::shift_left, *left_shift);
665704
}
666-
else if(
667-
const auto right_logical_shift = expr_try_dynamic_cast<lshr_exprt>(shift))
705+
if(const auto right_logical_shift = expr_try_dynamic_cast<lshr_exprt>(shift))
668706
{
669-
return smt_bit_vector_theoryt::logical_shift_right(
670-
convert_expr_to_smt(first_operand), convert_expr_to_smt(second_operand));
707+
return convert_to_smt_shift(
708+
smt_bit_vector_theoryt::logical_shift_right, *right_logical_shift);
671709
}
672-
else if(
673-
const auto right_arith_shift = expr_try_dynamic_cast<ashr_exprt>(shift))
710+
if(const auto right_arith_shift = expr_try_dynamic_cast<ashr_exprt>(shift))
674711
{
675-
return smt_bit_vector_theoryt::arithmetic_shift_right(
676-
convert_expr_to_smt(first_operand), convert_expr_to_smt(second_operand));
677-
}
678-
else
679-
{
680-
UNIMPLEMENTED_FEATURE(
681-
"Generation of SMT formula for shift expression: " + shift.pretty());
712+
return convert_to_smt_shift(
713+
smt_bit_vector_theoryt::arithmetic_shift_right, *right_arith_shift);
682714
}
715+
UNIMPLEMENTED_FEATURE(
716+
"Generation of SMT formula for shift expression: " + shift.pretty());
683717
}
684718

685719
static smt_termt convert_expr_to_smt(const with_exprt &with)
@@ -733,6 +767,11 @@ static smt_termt convert_expr_to_smt(const extractbit_exprt &extract_bit)
733767

734768
static smt_termt convert_expr_to_smt(const extractbits_exprt &extract_bits)
735769
{
770+
const smt_termt from = convert_expr_to_smt(extract_bits.src());
771+
const auto upper_value = numeric_cast<std::size_t>(extract_bits.upper());
772+
const auto lower_value = numeric_cast<std::size_t>(extract_bits.lower());
773+
if(upper_value && lower_value)
774+
return smt_bit_vector_theoryt::extract(*upper_value, *lower_value)(from);
736775
UNIMPLEMENTED_FEATURE(
737776
"Generation of SMT formula for extract bits expression: " +
738777
extract_bits.pretty());

unit/solvers/smt2_incremental/convert_expr_to_smt.cpp

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,105 @@ SCENARIO(
819819
}
820820
}
821821

822+
TEST_CASE(
823+
"expr to smt conversion for shifts of mismatched operands.",
824+
"[core][smt2_incremental]")
825+
{
826+
using make_typet = std::function<typet(std::size_t)>;
827+
const make_typet make_unsigned = constructor_oft<unsignedbv_typet>{};
828+
const make_typet make_signed = constructor_oft<signedbv_typet>{};
829+
using make_extensiont =
830+
std::function<std::function<smt_termt(smt_termt)>(std::size_t)>;
831+
const make_extensiont zero_extend = smt_bit_vector_theoryt::zero_extend;
832+
const make_extensiont sign_extend = smt_bit_vector_theoryt::sign_extend;
833+
std::string type_description;
834+
make_typet make_type;
835+
make_extensiont make_extension;
836+
using type_rowt = std::tuple<std::string, make_typet, make_extensiont>;
837+
std::tie(type_description, make_type, make_extension) = GENERATE_REF(
838+
type_rowt{"Unsigned operands.", make_unsigned, zero_extend},
839+
type_rowt{"Signed operands.", make_signed, sign_extend});
840+
SECTION(type_description)
841+
{
842+
using make_shift_exprt = std::function<exprt(exprt, exprt)>;
843+
const make_shift_exprt left_shift_expr = constructor_of<shl_exprt>();
844+
const make_shift_exprt arithmetic_right_shift_expr =
845+
constructor_of<ashr_exprt>();
846+
const make_shift_exprt logical_right_shift_expr =
847+
constructor_of<lshr_exprt>();
848+
using make_shift_termt = std::function<smt_termt(smt_termt, smt_termt)>;
849+
const make_shift_termt left_shift_term = smt_bit_vector_theoryt::shift_left;
850+
const make_shift_termt arithmetic_right_shift_term =
851+
smt_bit_vector_theoryt::arithmetic_shift_right;
852+
const make_shift_termt logical_right_shift_term =
853+
smt_bit_vector_theoryt::logical_shift_right;
854+
std::string shift_description;
855+
make_shift_exprt make_shift_expr;
856+
make_shift_termt make_shift_term;
857+
using shift_rowt =
858+
std::tuple<std::string, make_shift_exprt, make_shift_termt>;
859+
std::tie(shift_description, make_shift_expr, make_shift_term) =
860+
GENERATE_REF(
861+
shift_rowt{"Left shift.", left_shift_expr, left_shift_term},
862+
shift_rowt{
863+
"Arithmetic right shift.",
864+
arithmetic_right_shift_expr,
865+
arithmetic_right_shift_term},
866+
shift_rowt{
867+
"Logical right shift.",
868+
logical_right_shift_expr,
869+
logical_right_shift_term});
870+
SECTION(shift_description)
871+
{
872+
SECTION("Wider left hand side")
873+
{
874+
const exprt input = make_shift_expr(
875+
symbol_exprt{"foo", make_type(32)},
876+
symbol_exprt{"bar", make_type(8)});
877+
INFO("Input expr: " + input.pretty(2, 0));
878+
const smt_termt expected_result = make_shift_term(
879+
smt_identifier_termt{"foo", smt_bit_vector_sortt{32}},
880+
make_extension(24)(
881+
smt_identifier_termt{"bar", smt_bit_vector_sortt{8}}));
882+
CHECK(convert_expr_to_smt(input) == expected_result);
883+
}
884+
SECTION("Wider right hand side")
885+
{
886+
const exprt input = make_shift_expr(
887+
symbol_exprt{"foo", make_type(8)},
888+
symbol_exprt{"bar", make_type(32)});
889+
INFO("Input expr: " + input.pretty(2, 0));
890+
const smt_termt expected_result = make_shift_term(
891+
make_extension(24)(
892+
smt_identifier_termt{"foo", smt_bit_vector_sortt{8}}),
893+
smt_identifier_termt{"bar", smt_bit_vector_sortt{32}});
894+
CHECK(convert_expr_to_smt(input) == expected_result);
895+
}
896+
}
897+
}
898+
}
899+
900+
TEST_CASE(
901+
"expr to smt conversion for extract bits expressions",
902+
"[core][smt2_incremental]")
903+
{
904+
const typet operand_type = unsignedbv_typet{8};
905+
const exprt input = extractbits_exprt{
906+
symbol_exprt{"foo", operand_type},
907+
from_integer(4, operand_type),
908+
from_integer(2, operand_type),
909+
unsignedbv_typet{3}};
910+
const smt_termt expected_result = smt_bit_vector_theoryt::extract(4, 2)(
911+
smt_identifier_termt{"foo", smt_bit_vector_sortt{8}});
912+
CHECK(convert_expr_to_smt(input) == expected_result);
913+
const cbmc_invariants_should_throwt invariants_throw;
914+
CHECK_THROWS(convert_expr_to_smt(extractbits_exprt{
915+
symbol_exprt{"foo", operand_type},
916+
symbol_exprt{"bar", operand_type},
917+
symbol_exprt{"bar", operand_type},
918+
unsignedbv_typet{3}}));
919+
}
920+
822921
TEST_CASE("expr to smt conversion for type casts", "[core][smt2_incremental]")
823922
{
824923
const symbol_exprt bool_expr{"foo", bool_typet{}};

0 commit comments

Comments
 (0)