Skip to content

Add get value response validation for smt bv constant descriptors #6879

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 63 additions & 14 deletions src/solvers/smt2_incremental/smt_response_validation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
/// `response_or_errort` in the case where the parse tree is of that type or
/// an empty optional otherwise.

#include <solvers/smt2_incremental/smt_response_validation.h>
#include "smt_response_validation.h"

#include <util/arith_tools.h>
#include <util/mp_arith.h>
#include <util/range.h>

Expand Down Expand Up @@ -190,15 +191,33 @@ static bool all_subs_are_pairs(const irept &parse_tree)
[](const irept &sub) { return sub.get_sub().size() == 2; });
}

static response_or_errort<irep_idt>
validate_smt_identifier(const irept &parse_tree)
/// Checks for valid bit vector constants of the form `(_ bv(value) (width))`
/// for example - `(_ bv4 64)`.
static optionalt<smt_termt>
valid_smt_indexed_bit_vector(const irept &parse_tree)
{
if(!parse_tree.get_sub().empty() || parse_tree.id().empty())
{
return response_or_errort<irep_idt>(
"Expected identifier, found - \"" + print_parse_tree(parse_tree) + "\".");
}
return response_or_errort<irep_idt>(parse_tree.id());
if(parse_tree.get_sub().size() != 3)
return {};
if(parse_tree.get_sub().at(0).id() != "_")
return {};
const auto value_string = id2string(parse_tree.get_sub().at(1).id());
std::smatch match_results;
static const std::regex bv_value_regex{R"(^bv(\d+)$)", std::regex::optimize};
if(!std::regex_search(value_string, match_results, bv_value_regex))
return {};
INVARIANT(
match_results.size() == 2,
"Match results should include digits sub-expression if regex is matched.");
const std::string value_digits = match_results[1];
const auto value = string2integer(value_digits);
const auto bit_width_string = id2string(parse_tree.get_sub().at(2).id());
const auto bit_width =
numeric_cast_v<std::size_t>(string2integer(bit_width_string));
if(bit_width == 0)
return {};
if(value >= power(mp_integer{2}, bit_width))
return {};
return smt_bit_vector_constant_termt{value, bit_width};
}

static optionalt<smt_termt> valid_smt_bool(const irept &parse_tree)
Expand Down Expand Up @@ -229,7 +248,7 @@ static optionalt<smt_termt> valid_smt_hex(const std::string &text)
if(!std::regex_match(text, hex_format))
return {};
const std::string hex{text.begin() + 2, text.end()};
// SMT-LIB 2 allows hex characters to be upper of lower case, but they should
// SMT-LIB 2 allows hex characters to be upper or lower case, but they should
// be upper case for mp_integer.
const mp_integer value =
string2integer(make_range(hex).map<std::function<int(int)>>(toupper), 16);
Expand All @@ -240,6 +259,8 @@ static optionalt<smt_termt> valid_smt_hex(const std::string &text)
static optionalt<smt_termt>
valid_smt_bit_vector_constant(const irept &parse_tree)
{
if(const auto indexed = valid_smt_indexed_bit_vector(parse_tree))
return *indexed;
if(!parse_tree.get_sub().empty() || parse_tree.id().empty())
return {};
const auto value_string = id2string(parse_tree.id());
Expand All @@ -250,24 +271,52 @@ valid_smt_bit_vector_constant(const irept &parse_tree)
return {};
}

static response_or_errort<smt_termt> validate_term(const irept &parse_tree)
static optionalt<smt_termt> valid_term(const irept &parse_tree)
{
if(const auto smt_bool = valid_smt_bool(parse_tree))
return response_or_errort<smt_termt>{*smt_bool};
return {*smt_bool};
if(const auto bit_vector_constant = valid_smt_bit_vector_constant(parse_tree))
return response_or_errort<smt_termt>{*bit_vector_constant};
return {*bit_vector_constant};
return {};
}

static response_or_errort<smt_termt> validate_term(const irept &parse_tree)
{
if(const auto term = valid_term(parse_tree))
return response_or_errort<smt_termt>{*term};
return response_or_errort<smt_termt>{"Unrecognised SMT term - \"" +
print_parse_tree(parse_tree) + "\"."};
}

static response_or_errort<smt_termt>
validate_smt_descriptor(const irept &parse_tree, const smt_sortt &sort)
{
if(const auto term = valid_term(parse_tree))
return response_or_errort<smt_termt>{*term};
const auto id = parse_tree.id();
if(!id.empty())
return response_or_errort<smt_termt>{smt_identifier_termt{id, sort}};
return response_or_errort<smt_termt>{
"Expected descriptor SMT term, found - \"" + print_parse_tree(parse_tree) +
"\"."};
}

static response_or_errort<smt_get_value_responset::valuation_pairt>
validate_valuation_pair(const irept &pair_parse_tree)
{
PRECONDITION(pair_parse_tree.get_sub().size() == 2);
const auto &descriptor = pair_parse_tree.get_sub()[0];
const auto &value = pair_parse_tree.get_sub()[1];
const response_or_errort<smt_termt> value_validation = validate_term(value);
if(const auto value_errors = value_validation.get_if_error())
{
return response_or_errort<smt_get_value_responset::valuation_pairt>{
*value_errors};
}
const smt_termt value_term = *value_validation.get_if_valid();
return validation_propagating<smt_get_value_responset::valuation_pairt>(
validate_smt_identifier(descriptor), validate_term(value));
validate_smt_descriptor(descriptor, value_term.get_sort()),
validate_term(value));
}

/// \returns: A response or error in the case where the parse tree appears to be
Expand Down
103 changes: 86 additions & 17 deletions unit/solvers/smt2_incremental/smt_response_validation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,20 +117,90 @@ TEST_CASE("smt get-value response validation", "[core][smt2_incremental]")
}
SECTION("Bit vector sorted values.")
{
const response_or_errort<smt_responset> response_255 =
validate_smt_response(*smt2irep("((a #xff))").parsed_output);
CHECK(
*response_255.get_if_valid() ==
smt_get_value_responset{{smt_get_value_responset::valuation_pairt{
smt_identifier_termt{"a", smt_bit_vector_sortt{8}},
smt_bit_vector_constant_termt{255, 8}}}});
const response_or_errort<smt_responset> response_42 =
validate_smt_response(*smt2irep("((a #b00101010))").parsed_output);
CHECK(
*response_42.get_if_valid() ==
smt_get_value_responset{{smt_get_value_responset::valuation_pairt{
smt_identifier_termt{"a", smt_bit_vector_sortt{8}},
smt_bit_vector_constant_termt{42, 8}}}});
SECTION("Hex value")
{
const response_or_errort<smt_responset> response_255 =
validate_smt_response(*smt2irep("((a #xff))").parsed_output);
CHECK(
*response_255.get_if_valid() ==
smt_get_value_responset{{smt_get_value_responset::valuation_pairt{
smt_identifier_termt{"a", smt_bit_vector_sortt{8}},
smt_bit_vector_constant_termt{255, 8}}}});
}
SECTION("Binary value")
{
const response_or_errort<smt_responset> response_42 =
validate_smt_response(*smt2irep("((a #b00101010))").parsed_output);
CHECK(
*response_42.get_if_valid() ==
smt_get_value_responset{{smt_get_value_responset::valuation_pairt{
smt_identifier_termt{"a", smt_bit_vector_sortt{8}},
smt_bit_vector_constant_termt{42, 8}}}});
}
SECTION("Descriptors which are bit vector constants")
{
const response_or_errort<smt_responset> response_descriptor =
validate_smt_response(*smt2irep("(((_ bv255 8) #x2A))").parsed_output);
CHECK(
*response_descriptor.get_if_valid() ==
smt_get_value_responset{{smt_get_value_responset::valuation_pairt{
smt_bit_vector_constant_termt{255, 8},
smt_bit_vector_constant_termt{42, 8}}}});
SECTION("Invalid bit vector constants")
{
SECTION("Value too large for width")
{
const response_or_errort<smt_responset> pair_value_response =
validate_smt_response(
*smt2irep("(((_ bv256 8) #xff))").parsed_output);
CHECK(
*pair_value_response.get_if_error() ==
std::vector<std::string>{
"Expected descriptor SMT term, found - \"\n"
"0: _\n"
"1: bv256\n"
"2: 8\"."});
}
SECTION("Value missing bv prefix.")
{
const response_or_errort<smt_responset> pair_value_response =
validate_smt_response(*smt2irep("(((_ 42 8) #xff))").parsed_output);
CHECK(
*pair_value_response.get_if_error() ==
std::vector<std::string>{
"Expected descriptor SMT term, found - \"\n"
"0: _\n"
"1: 42\n"
"2: 8\"."});
}
SECTION("Hex value.")
{
const response_or_errort<smt_responset> pair_value_response =
validate_smt_response(
*smt2irep("(((_ bv2A 8) #xff))").parsed_output);
CHECK(
*pair_value_response.get_if_error() ==
std::vector<std::string>{
"Expected descriptor SMT term, found - \"\n"
"0: _\n"
"1: bv2A\n"
"2: 8\"."});
}
SECTION("Zero width.")
{
const response_or_errort<smt_responset> pair_value_response =
validate_smt_response(
*smt2irep("(((_ bv0 0) #xff))").parsed_output);
CHECK(
*pair_value_response.get_if_error() ==
std::vector<std::string>{
"Expected descriptor SMT term, found - \"\n"
"0: _\n"
"1: bv0\n"
"2: 0\"."});
}
}
}
}
SECTION("Multiple valuation pairs.")
{
Expand Down Expand Up @@ -174,12 +244,11 @@ TEST_CASE("smt get-value response validation", "[core][smt2_incremental]")
validate_smt_response(*smt2irep("((() true))").parsed_output);
CHECK(
*empty_descriptor_response.get_if_error() ==
std::vector<std::string>{"Expected identifier, found - \"\"."});
std::vector<std::string>{"Expected descriptor SMT term, found - \"\"."});
const response_or_errort<smt_responset> empty_pair =
validate_smt_response(*smt2irep("((() ())))").parsed_output);
CHECK(
*empty_pair.get_if_error() ==
std::vector<std::string>{"Expected identifier, found - \"\".",
"Unrecognised SMT term - \"\"."});
std::vector<std::string>{"Unrecognised SMT term - \"\"."});
}
}