Skip to content

Commit 7157df5

Browse files
Merge pull request #6689 from thomasspriggs/tas/smt_casting
Add conversion of cast expressions to SMT terms for incremental SMT solving
2 parents 37f122c + 307a04d commit 7157df5

File tree

4 files changed

+252
-5
lines changed

4 files changed

+252
-5
lines changed

src/solvers/smt2_incremental/convert_expr_to_smt.cpp

Lines changed: 114 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
// Author: Diffblue Ltd.
22

3-
#include <solvers/smt2_incremental/convert_expr_to_smt.h>
4-
5-
#include <solvers/prop/literal_expr.h>
6-
#include <solvers/smt2_incremental/smt_bit_vector_theory.h>
7-
#include <solvers/smt2_incremental/smt_core_theory.h>
83
#include <util/arith_tools.h>
94
#include <util/bitvector_expr.h>
105
#include <util/byte_operators.h>
6+
#include <util/c_types.h>
117
#include <util/expr.h>
128
#include <util/expr_cast.h>
139
#include <util/floatbv_expr.h>
@@ -18,6 +14,11 @@
1814
#include <util/std_expr.h>
1915
#include <util/string_constant.h>
2016

17+
#include <solvers/prop/literal_expr.h>
18+
#include <solvers/smt2_incremental/convert_expr_to_smt.h>
19+
#include <solvers/smt2_incremental/smt_bit_vector_theory.h>
20+
#include <solvers/smt2_incremental/smt_core_theory.h>
21+
2122
#include <functional>
2223
#include <numeric>
2324

@@ -98,8 +99,116 @@ static smt_termt convert_expr_to_smt(const nondet_symbol_exprt &nondet_symbol)
9899
nondet_symbol.pretty());
99100
}
100101

102+
/// \brief Makes a term which is true if \p input is not 0 / false.
103+
static smt_termt make_not_zero(const smt_termt &input, const typet &source_type)
104+
{
105+
if(input.get_sort().cast<smt_bool_sortt>())
106+
return input;
107+
return smt_core_theoryt::distinct(
108+
input, convert_expr_to_smt(from_integer(0, source_type)));
109+
}
110+
111+
/// \brief Returns a cast to C bool expressed in smt terms.
112+
static smt_termt convert_c_bool_cast(
113+
const smt_termt &from_term,
114+
const typet &from_type,
115+
const bitvector_typet &to_type)
116+
{
117+
const std::size_t c_bool_width = to_type.get_width();
118+
return smt_core_theoryt::if_then_else(
119+
make_not_zero(from_term, from_type),
120+
smt_bit_vector_constant_termt{1, c_bool_width},
121+
smt_bit_vector_constant_termt{0, c_bool_width});
122+
}
123+
124+
static smt_termt make_bitvector_resize_cast(
125+
const smt_termt &from_term,
126+
const bitvector_typet &from_type,
127+
const bitvector_typet &to_type)
128+
{
129+
if(const auto to_fixedbv_type = type_try_dynamic_cast<fixedbv_typet>(to_type))
130+
{
131+
UNIMPLEMENTED_FEATURE(
132+
"Generation of SMT formula for type cast to fixed-point bitvector "
133+
"type: " +
134+
to_type.pretty());
135+
}
136+
if(const auto to_floatbv_type = type_try_dynamic_cast<floatbv_typet>(to_type))
137+
{
138+
UNIMPLEMENTED_FEATURE(
139+
"Generation of SMT formula for type cast to floating-point bitvector "
140+
"type: " +
141+
to_type.pretty());
142+
}
143+
const std::size_t from_width = from_type.get_width();
144+
const std::size_t to_width = to_type.get_width();
145+
if(to_width == from_width)
146+
return from_term;
147+
if(to_width < from_width)
148+
return smt_bit_vector_theoryt::extract(to_width - 1, 0)(from_term);
149+
const std::size_t extension_size = to_width - from_width;
150+
if(can_cast_type<signedbv_typet>(from_type))
151+
return smt_bit_vector_theoryt::sign_extend(extension_size)(from_term);
152+
else
153+
return smt_bit_vector_theoryt::zero_extend(extension_size)(from_term);
154+
}
155+
156+
struct sort_based_cast_to_bit_vector_convertert final
157+
: public smt_sort_const_downcast_visitort
158+
{
159+
const smt_termt &from_term;
160+
const typet &from_type;
161+
const bitvector_typet &to_type;
162+
optionalt<smt_termt> result;
163+
164+
sort_based_cast_to_bit_vector_convertert(
165+
const smt_termt &from_term,
166+
const typet &from_type,
167+
const bitvector_typet &to_type)
168+
: from_term{from_term}, from_type{from_type}, to_type{to_type}
169+
{
170+
}
171+
172+
void visit(const smt_bool_sortt &) override
173+
{
174+
result = convert_c_bool_cast(
175+
from_term, from_type, c_bool_typet{to_type.get_width()});
176+
}
177+
178+
void visit(const smt_bit_vector_sortt &) override
179+
{
180+
if(const auto bitvector = type_try_dynamic_cast<bitvector_typet>(from_type))
181+
result = make_bitvector_resize_cast(from_term, *bitvector, to_type);
182+
else
183+
UNIMPLEMENTED_FEATURE(
184+
"Generation of SMT formula for type cast to bit vector from type: " +
185+
from_type.pretty());
186+
}
187+
};
188+
189+
static smt_termt convert_bit_vector_cast(
190+
const smt_termt &from_term,
191+
const typet &from_type,
192+
const bitvector_typet &to_type)
193+
{
194+
sort_based_cast_to_bit_vector_convertert converter{
195+
from_term, from_type, to_type};
196+
from_term.get_sort().accept(converter);
197+
POSTCONDITION(converter.result);
198+
return *converter.result;
199+
}
200+
101201
static smt_termt convert_expr_to_smt(const typecast_exprt &cast)
102202
{
203+
const auto from_term = convert_expr_to_smt(cast.op());
204+
const typet &from_type = cast.op().type();
205+
const typet &to_type = cast.type();
206+
if(const auto bool_type = type_try_dynamic_cast<bool_typet>(to_type))
207+
return make_not_zero(from_term, cast.op().type());
208+
if(const auto c_bool_type = type_try_dynamic_cast<c_bool_typet>(to_type))
209+
return convert_c_bool_cast(from_term, from_type, *c_bool_type);
210+
if(const auto bit_vector = type_try_dynamic_cast<bitvector_typet>(to_type))
211+
return convert_bit_vector_cast(from_term, from_type, *bit_vector);
103212
UNIMPLEMENTED_FEATURE(
104213
"Generation of SMT formula for type cast expression: " + cast.pretty());
105214
}

src/solvers/smt2_incremental/smt_bit_vector_theory.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@ class smt_bit_vector_theoryt
2525
std::vector<smt_indext> indices() const;
2626
void validate(const smt_termt &operand) const;
2727
};
28+
/// \brief
29+
/// Makes a factory for extract function applications.
30+
/// \param i
31+
/// Index of the highest bit to be included in the resulting bit vector.
32+
/// \param j
33+
/// Index of the lowest bit to be included in the resulting bit vector.
34+
/// \note
35+
/// Bit vectors are zero indexed. So the lowest bit index is zero and the
36+
/// largest index is the size of the bit vector minus one.
2837
static smt_function_application_termt::factoryt<extractt>
2938
extract(std::size_t i, std::size_t j);
3039

src/solvers/smt2_incremental/smt_core_theory.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ class smt_core_theoryt
6262
static smt_sortt return_sort(const smt_termt &lhs, const smt_termt &rhs);
6363
static void validate(const smt_termt &lhs, const smt_termt &rhs);
6464
};
65+
/// Makes applications of the function which returns true iff its two
66+
/// arguments are not identical.
6567
static const smt_function_application_termt::factoryt<distinctt> distinct;
6668

6769
struct if_then_elset final

unit/solvers/smt2_incremental/convert_expr_to_smt.cpp

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,17 @@
33
#include <util/arith_tools.h>
44
#include <util/bitvector_expr.h>
55
#include <util/bitvector_types.h>
6+
#include <util/c_types.h>
7+
#include <util/config.h>
8+
#include <util/constructor_of.h>
69
#include <util/format.h>
710
#include <util/std_expr.h>
811

912
#include <solvers/smt2_incremental/convert_expr_to_smt.h>
1013
#include <solvers/smt2_incremental/smt_bit_vector_theory.h>
1114
#include <solvers/smt2_incremental/smt_core_theory.h>
1215
#include <solvers/smt2_incremental/smt_terms.h>
16+
#include <solvers/smt2_incremental/smt_to_smt2_string.h>
1317
#include <testing-utils/use_catch.h>
1418

1519
TEST_CASE("\"typet\" to smt sort conversion", "[core][smt2_incremental]")
@@ -814,3 +818,126 @@ SCENARIO(
814818
}
815819
}
816820
}
821+
822+
TEST_CASE("expr to smt conversion for type casts", "[core][smt2_incremental]")
823+
{
824+
const symbol_exprt bool_expr{"foo", bool_typet{}};
825+
const smt_termt bool_term = smt_identifier_termt{"foo", smt_bool_sortt{}};
826+
const symbol_exprt bv_expr{"bar", signedbv_typet(12)};
827+
const smt_termt bv_term =
828+
smt_identifier_termt{"bar", smt_bit_vector_sortt{12}};
829+
SECTION("Casts to bool")
830+
{
831+
CHECK(
832+
convert_expr_to_smt(typecast_exprt{bool_expr, bool_typet{}}) ==
833+
bool_term);
834+
CHECK(
835+
convert_expr_to_smt(typecast_exprt{bv_expr, bool_typet{}}) ==
836+
smt_core_theoryt::distinct(
837+
bv_term, smt_bit_vector_constant_termt{0, 12}));
838+
}
839+
SECTION("Casts to C bool")
840+
{
841+
// The config lines are necessary because when we do casting to C bool the
842+
// bit width depends on the configuration.
843+
config.ansi_c.mode = configt::ansi_ct::flavourt::GCC;
844+
config.ansi_c.set_arch_spec_i386();
845+
const std::size_t c_bool_width = config.ansi_c.bool_width;
846+
const smt_bit_vector_constant_termt c_true{1, c_bool_width};
847+
const smt_bit_vector_constant_termt c_false{0, c_bool_width};
848+
SECTION("from bool")
849+
{
850+
const auto cast_bool =
851+
convert_expr_to_smt(typecast_exprt{bool_expr, c_bool_type()});
852+
const auto expected_bool_conversion =
853+
smt_core_theoryt::if_then_else(bool_term, c_true, c_false);
854+
CHECK(cast_bool == expected_bool_conversion);
855+
}
856+
SECTION("from bit vector")
857+
{
858+
const auto cast_bit_vector =
859+
convert_expr_to_smt(typecast_exprt{bv_expr, c_bool_type()});
860+
const auto expected_bit_vector_conversion =
861+
smt_core_theoryt::if_then_else(
862+
smt_core_theoryt::distinct(
863+
bv_term, smt_bit_vector_constant_termt{0, 12}),
864+
c_true,
865+
c_false);
866+
CHECK(cast_bit_vector == expected_bit_vector_conversion);
867+
}
868+
}
869+
SECTION("Casts to bit vector")
870+
{
871+
SECTION("Matched width casts")
872+
{
873+
typet from_type, to_type;
874+
using rowt = std::pair<typet, typet>;
875+
std::tie(from_type, to_type) = GENERATE(
876+
rowt{unsignedbv_typet{8}, unsignedbv_typet{8}},
877+
rowt{unsignedbv_typet{8}, signedbv_typet{8}},
878+
rowt{signedbv_typet{8}, unsignedbv_typet{8}});
879+
CHECK(
880+
convert_expr_to_smt(
881+
typecast_exprt{from_integer(1, from_type), to_type}) ==
882+
smt_bit_vector_constant_termt{1, 8});
883+
}
884+
SECTION("Narrowing casts")
885+
{
886+
CHECK(
887+
convert_expr_to_smt(typecast_exprt{bv_expr, signedbv_typet{8}}) ==
888+
smt_bit_vector_theoryt::extract(7, 0)(bv_term));
889+
CHECK(
890+
convert_expr_to_smt(typecast_exprt{
891+
from_integer(42, unsignedbv_typet{32}), unsignedbv_typet{16}}) ==
892+
smt_bit_vector_theoryt::extract(15, 0)(
893+
smt_bit_vector_constant_termt{42, 32}));
894+
}
895+
SECTION("Widening casts")
896+
{
897+
std::size_t from_width, to_width, extension_width;
898+
using size_rowt = std::tuple<std::size_t, std::size_t, std::size_t>;
899+
std::tie(from_width, to_width, extension_width) = GENERATE(
900+
size_rowt{8, 64, 56}, size_rowt{16, 32, 16}, size_rowt{16, 128, 112});
901+
PRECONDITION(from_width < to_width);
902+
PRECONDITION(to_width - from_width == extension_width);
903+
using make_typet = std::function<typet(std::size_t)>;
904+
const make_typet make_unsigned = constructor_oft<unsignedbv_typet>{};
905+
const make_typet make_signed = constructor_oft<signedbv_typet>{};
906+
using make_extensiont =
907+
std::function<std::function<smt_termt(smt_termt)>(std::size_t)>;
908+
const make_extensiont zero_extend = smt_bit_vector_theoryt::zero_extend;
909+
const make_extensiont sign_extend = smt_bit_vector_theoryt::sign_extend;
910+
make_typet make_source_type, make_destination_type;
911+
make_extensiont make_extension;
912+
using types_rowt = std::tuple<make_typet, make_typet, make_extensiont>;
913+
std::tie(make_source_type, make_destination_type, make_extension) =
914+
GENERATE_REF(
915+
types_rowt{make_unsigned, make_unsigned, zero_extend},
916+
types_rowt{make_signed, make_signed, sign_extend},
917+
types_rowt{make_signed, make_unsigned, sign_extend},
918+
types_rowt{make_unsigned, make_signed, zero_extend});
919+
const typecast_exprt cast{
920+
from_integer(42, make_source_type(from_width)),
921+
make_destination_type(to_width)};
922+
const smt_termt expected_term = make_extension(extension_width)(
923+
smt_bit_vector_constant_termt{42, from_width});
924+
CHECK(convert_expr_to_smt(cast) == expected_term);
925+
}
926+
SECTION("from bool")
927+
{
928+
const exprt from_expr = GENERATE(
929+
exprt{symbol_exprt{"baz", bool_typet{}}},
930+
exprt{true_exprt{}},
931+
exprt{false_exprt{}});
932+
const smt_termt from_term = convert_expr_to_smt(from_expr);
933+
const std::size_t width = GENERATE(1, 8, 16, 32, 64);
934+
const typecast_exprt cast{from_expr, bitvector_typet{ID_bv, width}};
935+
CHECK(
936+
convert_expr_to_smt(cast) ==
937+
smt_core_theoryt::if_then_else(
938+
from_term,
939+
smt_bit_vector_constant_termt{1, width},
940+
smt_bit_vector_constant_termt{0, width}));
941+
}
942+
}
943+
}

0 commit comments

Comments
 (0)