Skip to content

Add bit vector extract operation support to incremental SMT2 solving #6654

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 11, 2022
Merged
1 change: 1 addition & 0 deletions src/solvers/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ SRC = $(BOOLEFORCE_SRC) \
smt2_incremental/smt_bit_vector_theory.cpp \
smt2_incremental/smt_commands.cpp \
smt2_incremental/smt_core_theory.cpp \
smt2_incremental/smt_index.cpp \
smt2_incremental/smt_logics.cpp \
smt2_incremental/smt_options.cpp \
smt2_incremental/smt_response_validation.cpp \
Expand Down
64 changes: 64 additions & 0 deletions src/solvers/smt2_incremental/smt_index.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Author: Diffblue Ltd.

#include "smt_index.h"

// Define the irep_idts for kinds of index.
const irep_idt ID_smt_numeral_index{"smt_numeral_index"};
const irep_idt ID_smt_symbol_index{"smt_symbol_index"};

bool smt_indext::operator==(const smt_indext &other) const
{
return irept::operator==(other);
}

bool smt_indext::operator!=(const smt_indext &other) const
{
return !(*this == other);
}

template <>
const smt_numeral_indext *smt_indext::cast<smt_numeral_indext>() const &
{
return id() == ID_smt_numeral_index
? static_cast<const smt_numeral_indext *>(this)
: nullptr;
}

template <>
const smt_symbol_indext *smt_indext::cast<smt_symbol_indext>() const &
{
return id() == ID_smt_symbol_index
? static_cast<const smt_symbol_indext *>(this)
: nullptr;
}

void smt_indext::accept(smt_index_const_downcast_visitort &visitor) const
{
if(const auto numeral = this->cast<smt_numeral_indext>())
return visitor.visit(*numeral);
if(const auto symbol = this->cast<smt_symbol_indext>())
return visitor.visit(*symbol);
UNREACHABLE;
}

smt_numeral_indext::smt_numeral_indext(std::size_t value)
: smt_indext{ID_smt_numeral_index}
{
set_size_t(ID_value, value);
}

std::size_t smt_numeral_indext::value() const
{
return get_size_t(ID_value);
}

smt_symbol_indext::smt_symbol_indext(irep_idt identifier)
: smt_indext{ID_smt_symbol_index}
{
set(ID_identifier, identifier);
}

irep_idt smt_symbol_indext::identifier() const
{
return get(ID_identifier);
}
90 changes: 90 additions & 0 deletions src/solvers/smt2_incremental/smt_index.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Author: Diffblue Ltd.

#ifndef CPROVER_SOLVERS_SMT2_INCREMENTAL_SMT_INDEX_H
#define CPROVER_SOLVERS_SMT2_INCREMENTAL_SMT_INDEX_H

#include <util/irep.h>

class smt_index_const_downcast_visitort;

/// \brief
/// For implementation of indexed identifiers. See SMT-LIB Standard Version 2.6
/// section 3.3.
class smt_indext : protected irept
{
public:
// smt_indext does not support the notion of an empty / null state. Use
// optionalt<smt_indext> instead if an empty index is required.
smt_indext() = delete;

using irept::pretty;

bool operator==(const smt_indext &) const;
bool operator!=(const smt_indext &) const;

template <typename sub_classt>
const sub_classt *cast() const &;

void accept(smt_index_const_downcast_visitort &) const;

/// \brief Class for adding the ability to up and down cast smt_indext to and
/// from irept. These casts are required by other irept derived classes in
/// order to store instances of smt_termt inside them.
/// \tparam derivedt The type of class which derives from this class and from
/// irept.
template <typename derivedt>
class storert
{
protected:
storert();
static irept upcast(smt_indext index);
static const smt_indext &downcast(const irept &);
};

protected:
using irept::irept;
};

template <typename derivedt>
smt_indext::storert<derivedt>::storert()
{
static_assert(
std::is_base_of<irept, derivedt>::value &&
std::is_base_of<storert<derivedt>, derivedt>::value,
"Only irept based classes need to upcast smt_sortt to store it.");
}

template <typename derivedt>
irept smt_indext::storert<derivedt>::upcast(smt_indext index)
{
return static_cast<irept &&>(std::move(index));
}

template <typename derivedt>
const smt_indext &smt_indext::storert<derivedt>::downcast(const irept &irep)
{
return static_cast<const smt_indext &>(irep);
}

class smt_numeral_indext : public smt_indext
{
public:
explicit smt_numeral_indext(std::size_t value);
std::size_t value() const;
};

class smt_symbol_indext : public smt_indext
{
public:
explicit smt_symbol_indext(irep_idt identifier);
irep_idt identifier() const;
};

class smt_index_const_downcast_visitort
{
public:
virtual void visit(const smt_numeral_indext &) = 0;
virtual void visit(const smt_symbol_indext &) = 0;
};

#endif // CPROVER_SOLVERS_SMT2_INCREMENTAL_SMT_INDEX_H
19 changes: 18 additions & 1 deletion src/solvers/smt2_incremental/smt_terms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ static bool is_valid_smt_identifier(irep_idt identifier)
return std::regex_match(id2string(identifier), valid);
}

smt_identifier_termt::smt_identifier_termt(irep_idt identifier, smt_sortt sort)
smt_identifier_termt::smt_identifier_termt(
irep_idt identifier,
smt_sortt sort,
std::vector<smt_indext> indices)
: smt_termt(ID_smt_identifier_term, std::move(sort))
{
// The below invariant exists as a sanity check that the string being used for
Expand All @@ -67,13 +70,27 @@ smt_identifier_termt::smt_identifier_termt(irep_idt identifier, smt_sortt sort)
is_valid_smt_identifier(identifier),
R"(Identifiers may not contain | characters.)");
set(ID_identifier, identifier);
for(auto &index : indices)
{
get_sub().push_back(
smt_indext::storert<smt_identifier_termt>::upcast(std::move(index)));
}
}

irep_idt smt_identifier_termt::identifier() const
{
return get(ID_identifier);
}

std::vector<std::reference_wrapper<const smt_indext>>
smt_identifier_termt::indices() const
{
return make_range(get_sub()).map([](const irept &index) {
return std::cref(
smt_indext::storert<smt_identifier_termt>::downcast(index));
});
}

smt_bit_vector_constant_termt::smt_bit_vector_constant_termt(
const mp_integer &value,
smt_bit_vector_sortt sort)
Expand Down
21 changes: 18 additions & 3 deletions src/solvers/smt2_incremental/smt_terms.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
#ifndef CPROVER_SOLVERS_SMT2_INCREMENTAL_SMT_TERMS_H
#define CPROVER_SOLVERS_SMT2_INCREMENTAL_SMT_TERMS_H

#include <solvers/smt2_incremental/smt_sorts.h>
#include <util/irep.h>

#include <solvers/smt2_incremental/smt_index.h>
#include <solvers/smt2_incremental/smt_sorts.h>

#include <functional>

class BigInt;
Expand Down Expand Up @@ -77,7 +79,14 @@ class smt_bool_literal_termt : public smt_termt

/// Stores identifiers in unescaped and unquoted form. Any escaping or quoting
/// required should be performed during printing.
class smt_identifier_termt : public smt_termt
/// \details
/// The SMT-LIB standard Version 2.6 refers to "indexed" identifiers which have
/// 1 or more indices and "simple" identifiers which have no indicies. The
/// internal `smt_identifier_termt` class is used for both kinds of identifier
/// which are distinguished based on whether the collection of `indices` is
/// empty or not.
class smt_identifier_termt : public smt_termt,
private smt_indext::storert<smt_identifier_termt>
{
public:
/// \brief Constructs an identifier term with the given \p identifier and
Expand All @@ -91,8 +100,14 @@ class smt_identifier_termt : public smt_termt
/// \param sort: The sort which this term will have. All terms in our abstract
/// form must be sorted, even if those sorts are not printed in all
/// contexts.
smt_identifier_termt(irep_idt identifier, smt_sortt sort);
/// \param indices: This should be collection of indices for an indexed
/// identifier, or an empty collection for simple (non-indexed) identifiers.
smt_identifier_termt(
irep_idt identifier,
smt_sortt sort,
std::vector<smt_indext> indices = {});
irep_idt identifier() const;
std::vector<std::reference_wrapper<const smt_indext>> indices() const;
};

class smt_bit_vector_constant_termt : public smt_termt
Expand Down
1 change: 1 addition & 0 deletions unit/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ SRC += analyses/ai/ai.cpp \
solvers/smt2_incremental/smt_bit_vector_theory.cpp \
solvers/smt2_incremental/smt_commands.cpp \
solvers/smt2_incremental/smt_core_theory.cpp \
solvers/smt2_incremental/smt_index.cpp \
solvers/smt2_incremental/smt_response_validation.cpp \
solvers/smt2_incremental/smt_responses.cpp \
solvers/smt2_incremental/smt_sorts.cpp \
Expand Down
69 changes: 69 additions & 0 deletions unit/solvers/smt2_incremental/smt_index.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Author: Diffblue Ltd.

#include <util/optional.h>

#include <solvers/smt2_incremental/smt_index.h>
#include <testing-utils/use_catch.h>

TEST_CASE("Test smt_indext.pretty is accessible.", "[core][smt2_incremental]")
{
const smt_indext index = smt_numeral_indext{42};
REQUIRE_FALSE(index.pretty().empty());
}

TEST_CASE("Test smt_index getters", "[core][smt2_incremental]")
{
SECTION("Numeral")
{
REQUIRE(smt_numeral_indext{42}.value() == 42);
}
SECTION("Symbol")
{
REQUIRE(smt_symbol_indext{"foo"}.identifier() == "foo");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if you create a smt_symbol_indext{""}?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then the constructed instant will contain the empty string. Are you suggesting an additional invariant?

}
}

TEST_CASE("Visiting smt_indext", "[core][smt2_incremental]")
{
class : public smt_index_const_downcast_visitort
{
public:
optionalt<std::size_t> numeral_visited{};
optionalt<irep_idt> symbol_visited{};

void visit(const smt_numeral_indext &numeral) override
{
numeral_visited = numeral.value();
}

void visit(const smt_symbol_indext &symbol) override
{
symbol_visited = symbol.identifier();
}
} visitor;
SECTION("numeral")
{
smt_numeral_indext{8}.accept(visitor);
REQUIRE(visitor.numeral_visited);
CHECK(*visitor.numeral_visited == 8);
CHECK_FALSE(visitor.symbol_visited);
}
SECTION("symbol")
{
smt_symbol_indext{"bar"}.accept(visitor);
CHECK_FALSE(visitor.numeral_visited);
REQUIRE(visitor.symbol_visited);
CHECK(*visitor.symbol_visited == "bar");
}
}

TEST_CASE("smt_index equality", "[core][smt2_incremental]")
{
const smt_symbol_indext foo_index{"foo"};
CHECK(foo_index == smt_symbol_indext{"foo"});
CHECK(foo_index != smt_symbol_indext{"bar"});
const smt_numeral_indext index_42{42};
CHECK(index_42 == smt_numeral_indext{42});
CHECK(index_42 != smt_numeral_indext{12});
CHECK(index_42 != foo_index);
}
23 changes: 20 additions & 3 deletions unit/solvers/smt2_incremental/smt_terms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,26 @@ TEST_CASE("smt_identifier_termt construction", "[core][smt2_incremental]")

TEST_CASE("smt_identifier_termt getters.", "[core][smt2_incremental]")
{
const smt_identifier_termt identifier{"foo", smt_bool_sortt{}};
CHECK(identifier.identifier() == "foo");
CHECK(identifier.get_sort() == smt_bool_sortt{});
SECTION("Simple identifier")
{
const smt_identifier_termt identifier{"foo", smt_bool_sortt{}};
CHECK(identifier.identifier() == "foo");
CHECK(identifier.get_sort() == smt_bool_sortt{});
CHECK(identifier.indices().empty());
}
SECTION("Indexed identifier")
{
const smt_symbol_indext baz{"baz"};
const smt_numeral_indext index_42{42};
const smt_identifier_termt indexed{
"bar", smt_bit_vector_sortt{8}, {baz, index_42}};
CHECK(indexed.identifier() == "bar");
CHECK(indexed.get_sort() == smt_bit_vector_sortt{8});
const auto indices = indexed.indices();
REQUIRE(indices.size() == 2);
CHECK(indices[0].get() == baz);
CHECK(indices[1].get() == index_42);
}
}

TEST_CASE("smt_bit_vector_constant_termt getters.", "[core][smt2_incremental]")
Expand Down