diff --git a/src/solvers/smt2_incremental/smt_response_validation.cpp b/src/solvers/smt2_incremental/smt_response_validation.cpp index 143e0f03099..9634eb79339 100644 --- a/src/solvers/smt2_incremental/smt_response_validation.cpp +++ b/src/solvers/smt2_incremental/smt_response_validation.cpp @@ -15,8 +15,9 @@ /// `response_or_errort` in the case where the parse tree is of that type or /// an empty optional otherwise. -#include +#include "smt_response_validation.h" +#include #include #include @@ -190,15 +191,33 @@ static bool all_subs_are_pairs(const irept &parse_tree) [](const irept &sub) { return sub.get_sub().size() == 2; }); } -static response_or_errort -validate_smt_identifier(const irept &parse_tree) +/// Checks for valid bit vector constants of the form `(_ bv(value) (width))` +/// for example - `(_ bv4 64)`. +static optionalt +valid_smt_indexed_bit_vector(const irept &parse_tree) { - if(!parse_tree.get_sub().empty() || parse_tree.id().empty()) - { - return response_or_errort( - "Expected identifier, found - \"" + print_parse_tree(parse_tree) + "\"."); - } - return response_or_errort(parse_tree.id()); + if(parse_tree.get_sub().size() != 3) + return {}; + if(parse_tree.get_sub().at(0).id() != "_") + return {}; + const auto value_string = id2string(parse_tree.get_sub().at(1).id()); + std::smatch match_results; + static const std::regex bv_value_regex{R"(^bv(\d+)$)", std::regex::optimize}; + if(!std::regex_search(value_string, match_results, bv_value_regex)) + return {}; + INVARIANT( + match_results.size() == 2, + "Match results should include digits sub-expression if regex is matched."); + const std::string value_digits = match_results[1]; + const auto value = string2integer(value_digits); + const auto bit_width_string = id2string(parse_tree.get_sub().at(2).id()); + const auto bit_width = + numeric_cast_v(string2integer(bit_width_string)); + if(bit_width == 0) + return {}; + if(value >= power(mp_integer{2}, bit_width)) + return {}; + return smt_bit_vector_constant_termt{value, bit_width}; } static optionalt valid_smt_bool(const irept &parse_tree) @@ -229,7 +248,7 @@ static optionalt valid_smt_hex(const std::string &text) if(!std::regex_match(text, hex_format)) return {}; const std::string hex{text.begin() + 2, text.end()}; - // SMT-LIB 2 allows hex characters to be upper of lower case, but they should + // SMT-LIB 2 allows hex characters to be upper or lower case, but they should // be upper case for mp_integer. const mp_integer value = string2integer(make_range(hex).map>(toupper), 16); @@ -240,6 +259,8 @@ static optionalt valid_smt_hex(const std::string &text) static optionalt valid_smt_bit_vector_constant(const irept &parse_tree) { + if(const auto indexed = valid_smt_indexed_bit_vector(parse_tree)) + return *indexed; if(!parse_tree.get_sub().empty() || parse_tree.id().empty()) return {}; const auto value_string = id2string(parse_tree.id()); @@ -250,24 +271,52 @@ valid_smt_bit_vector_constant(const irept &parse_tree) return {}; } -static response_or_errort validate_term(const irept &parse_tree) +static optionalt valid_term(const irept &parse_tree) { if(const auto smt_bool = valid_smt_bool(parse_tree)) - return response_or_errort{*smt_bool}; + return {*smt_bool}; if(const auto bit_vector_constant = valid_smt_bit_vector_constant(parse_tree)) - return response_or_errort{*bit_vector_constant}; + return {*bit_vector_constant}; + return {}; +} + +static response_or_errort validate_term(const irept &parse_tree) +{ + if(const auto term = valid_term(parse_tree)) + return response_or_errort{*term}; return response_or_errort{"Unrecognised SMT term - \"" + print_parse_tree(parse_tree) + "\"."}; } +static response_or_errort +validate_smt_descriptor(const irept &parse_tree, const smt_sortt &sort) +{ + if(const auto term = valid_term(parse_tree)) + return response_or_errort{*term}; + const auto id = parse_tree.id(); + if(!id.empty()) + return response_or_errort{smt_identifier_termt{id, sort}}; + return response_or_errort{ + "Expected descriptor SMT term, found - \"" + print_parse_tree(parse_tree) + + "\"."}; +} + static response_or_errort validate_valuation_pair(const irept &pair_parse_tree) { PRECONDITION(pair_parse_tree.get_sub().size() == 2); const auto &descriptor = pair_parse_tree.get_sub()[0]; const auto &value = pair_parse_tree.get_sub()[1]; + const response_or_errort value_validation = validate_term(value); + if(const auto value_errors = value_validation.get_if_error()) + { + return response_or_errort{ + *value_errors}; + } + const smt_termt value_term = *value_validation.get_if_valid(); return validation_propagating( - validate_smt_identifier(descriptor), validate_term(value)); + validate_smt_descriptor(descriptor, value_term.get_sort()), + validate_term(value)); } /// \returns: A response or error in the case where the parse tree appears to be diff --git a/unit/solvers/smt2_incremental/smt_response_validation.cpp b/unit/solvers/smt2_incremental/smt_response_validation.cpp index c249abb4c87..0892b31ed07 100644 --- a/unit/solvers/smt2_incremental/smt_response_validation.cpp +++ b/unit/solvers/smt2_incremental/smt_response_validation.cpp @@ -117,20 +117,90 @@ TEST_CASE("smt get-value response validation", "[core][smt2_incremental]") } SECTION("Bit vector sorted values.") { - const response_or_errort response_255 = - validate_smt_response(*smt2irep("((a #xff))").parsed_output); - CHECK( - *response_255.get_if_valid() == - smt_get_value_responset{{smt_get_value_responset::valuation_pairt{ - smt_identifier_termt{"a", smt_bit_vector_sortt{8}}, - smt_bit_vector_constant_termt{255, 8}}}}); - const response_or_errort response_42 = - validate_smt_response(*smt2irep("((a #b00101010))").parsed_output); - CHECK( - *response_42.get_if_valid() == - smt_get_value_responset{{smt_get_value_responset::valuation_pairt{ - smt_identifier_termt{"a", smt_bit_vector_sortt{8}}, - smt_bit_vector_constant_termt{42, 8}}}}); + SECTION("Hex value") + { + const response_or_errort response_255 = + validate_smt_response(*smt2irep("((a #xff))").parsed_output); + CHECK( + *response_255.get_if_valid() == + smt_get_value_responset{{smt_get_value_responset::valuation_pairt{ + smt_identifier_termt{"a", smt_bit_vector_sortt{8}}, + smt_bit_vector_constant_termt{255, 8}}}}); + } + SECTION("Binary value") + { + const response_or_errort response_42 = + validate_smt_response(*smt2irep("((a #b00101010))").parsed_output); + CHECK( + *response_42.get_if_valid() == + smt_get_value_responset{{smt_get_value_responset::valuation_pairt{ + smt_identifier_termt{"a", smt_bit_vector_sortt{8}}, + smt_bit_vector_constant_termt{42, 8}}}}); + } + SECTION("Descriptors which are bit vector constants") + { + const response_or_errort response_descriptor = + validate_smt_response(*smt2irep("(((_ bv255 8) #x2A))").parsed_output); + CHECK( + *response_descriptor.get_if_valid() == + smt_get_value_responset{{smt_get_value_responset::valuation_pairt{ + smt_bit_vector_constant_termt{255, 8}, + smt_bit_vector_constant_termt{42, 8}}}}); + SECTION("Invalid bit vector constants") + { + SECTION("Value too large for width") + { + const response_or_errort pair_value_response = + validate_smt_response( + *smt2irep("(((_ bv256 8) #xff))").parsed_output); + CHECK( + *pair_value_response.get_if_error() == + std::vector{ + "Expected descriptor SMT term, found - \"\n" + "0: _\n" + "1: bv256\n" + "2: 8\"."}); + } + SECTION("Value missing bv prefix.") + { + const response_or_errort pair_value_response = + validate_smt_response(*smt2irep("(((_ 42 8) #xff))").parsed_output); + CHECK( + *pair_value_response.get_if_error() == + std::vector{ + "Expected descriptor SMT term, found - \"\n" + "0: _\n" + "1: 42\n" + "2: 8\"."}); + } + SECTION("Hex value.") + { + const response_or_errort pair_value_response = + validate_smt_response( + *smt2irep("(((_ bv2A 8) #xff))").parsed_output); + CHECK( + *pair_value_response.get_if_error() == + std::vector{ + "Expected descriptor SMT term, found - \"\n" + "0: _\n" + "1: bv2A\n" + "2: 8\"."}); + } + SECTION("Zero width.") + { + const response_or_errort pair_value_response = + validate_smt_response( + *smt2irep("(((_ bv0 0) #xff))").parsed_output); + CHECK( + *pair_value_response.get_if_error() == + std::vector{ + "Expected descriptor SMT term, found - \"\n" + "0: _\n" + "1: bv0\n" + "2: 0\"."}); + } + } + } } SECTION("Multiple valuation pairs.") { @@ -174,12 +244,11 @@ TEST_CASE("smt get-value response validation", "[core][smt2_incremental]") validate_smt_response(*smt2irep("((() true))").parsed_output); CHECK( *empty_descriptor_response.get_if_error() == - std::vector{"Expected identifier, found - \"\"."}); + std::vector{"Expected descriptor SMT term, found - \"\"."}); const response_or_errort empty_pair = validate_smt_response(*smt2irep("((() ())))").parsed_output); CHECK( *empty_pair.get_if_error() == - std::vector{"Expected identifier, found - \"\".", - "Unrecognised SMT term - \"\"."}); + std::vector{"Unrecognised SMT term - \"\"."}); } }