Skip to content

Add bit vector extract operation support to incremental SMT2 solving #6654

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 11, 2022
Merged
1 change: 1 addition & 0 deletions src/solvers/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ SRC = $(BOOLEFORCE_SRC) \
smt2_incremental/smt_bit_vector_theory.cpp \
smt2_incremental/smt_commands.cpp \
smt2_incremental/smt_core_theory.cpp \
smt2_incremental/smt_index.cpp \
smt2_incremental/smt_logics.cpp \
smt2_incremental/smt_options.cpp \
smt2_incremental/smt_response_validation.cpp \
Expand Down
31 changes: 31 additions & 0 deletions src/solvers/smt2_incremental/smt_bit_vector_theory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,37 @@

#include <util/invariant.h>

const char *smt_bit_vector_theoryt::extractt::identifier()
{
return "extract";
}

smt_sortt
smt_bit_vector_theoryt::extractt::return_sort(const smt_termt &operand) const
{
return smt_bit_vector_sortt{i - j + 1};
}

std::vector<smt_indext> smt_bit_vector_theoryt::extractt::indices() const
{
return {smt_numeral_indext{i}, smt_numeral_indext{j}};
}

void smt_bit_vector_theoryt::extractt::validate(const smt_termt &operand) const
{
PRECONDITION(i >= j);
const auto bit_vector_sort = operand.get_sort().cast<smt_bit_vector_sortt>();
PRECONDITION(bit_vector_sort);
PRECONDITION(i < bit_vector_sort->bit_width());
}

smt_function_application_termt::factoryt<smt_bit_vector_theoryt::extractt>
smt_bit_vector_theoryt::extract(std::size_t i, std::size_t j)
{
PRECONDITION(i >= j);
return smt_function_application_termt::factoryt<extractt>(i, j);
}

static void validate_bit_vector_operator_arguments(
const smt_termt &left,
const smt_termt &right)
Expand Down
12 changes: 12 additions & 0 deletions src/solvers/smt2_incremental/smt_bit_vector_theory.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,18 @@
class smt_bit_vector_theoryt
{
public:
struct extractt final
{
std::size_t i;
std::size_t j;
static const char *identifier();
smt_sortt return_sort(const smt_termt &operand) const;
std::vector<smt_indext> indices() const;
void validate(const smt_termt &operand) const;
};
static smt_function_application_termt::factoryt<extractt>
extract(std::size_t i, std::size_t j);

// Relational operator class declarations
struct unsigned_less_thant final
{
Expand Down
65 changes: 65 additions & 0 deletions src/solvers/smt2_incremental/smt_index.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Author: Diffblue Ltd.

#include "smt_index.h"

// Define the irep_idts for kinds of index.
const irep_idt ID_smt_numeral_index{"smt_numeral_index"};
const irep_idt ID_smt_symbol_index{"smt_symbol_index"};

bool smt_indext::operator==(const smt_indext &other) const
{
return irept::operator==(other);
}

bool smt_indext::operator!=(const smt_indext &other) const
{
return !(*this == other);
}

template <>
const smt_numeral_indext *smt_indext::cast<smt_numeral_indext>() const &
{
return id() == ID_smt_numeral_index
? static_cast<const smt_numeral_indext *>(this)
: nullptr;
}

template <>
const smt_symbol_indext *smt_indext::cast<smt_symbol_indext>() const &
{
return id() == ID_smt_symbol_index
? static_cast<const smt_symbol_indext *>(this)
: nullptr;
}

void smt_indext::accept(smt_index_const_downcast_visitort &visitor) const
{
if(const auto numeral = this->cast<smt_numeral_indext>())
return visitor.visit(*numeral);
if(const auto symbol = this->cast<smt_symbol_indext>())
return visitor.visit(*symbol);
UNREACHABLE;
}

smt_numeral_indext::smt_numeral_indext(std::size_t value)
: smt_indext{ID_smt_numeral_index}
{
set_size_t(ID_value, value);
}

std::size_t smt_numeral_indext::value() const
{
return get_size_t(ID_value);
}

smt_symbol_indext::smt_symbol_indext(irep_idt identifier)
: smt_indext{ID_smt_symbol_index}
{
PRECONDITION(!identifier.empty());
set(ID_identifier, identifier);
}

irep_idt smt_symbol_indext::identifier() const
{
return get(ID_identifier);
}
90 changes: 90 additions & 0 deletions src/solvers/smt2_incremental/smt_index.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Author: Diffblue Ltd.

#ifndef CPROVER_SOLVERS_SMT2_INCREMENTAL_SMT_INDEX_H
#define CPROVER_SOLVERS_SMT2_INCREMENTAL_SMT_INDEX_H

#include <util/irep.h>

class smt_index_const_downcast_visitort;

/// \brief
/// For implementation of indexed identifiers. See SMT-LIB Standard Version 2.6
/// section 3.3.
class smt_indext : protected irept
{
public:
// smt_indext does not support the notion of an empty / null state. Use
// optionalt<smt_indext> instead if an empty index is required.
smt_indext() = delete;

using irept::pretty;

bool operator==(const smt_indext &) const;
bool operator!=(const smt_indext &) const;

template <typename sub_classt>
const sub_classt *cast() const &;

void accept(smt_index_const_downcast_visitort &) const;

/// \brief Class for adding the ability to up and down cast smt_indext to and
/// from irept. These casts are required by other irept derived classes in
/// order to store instances of smt_termt inside them.
/// \tparam derivedt The type of class which derives from this class and from
/// irept.
template <typename derivedt>
class storert
{
protected:
storert();
static irept upcast(smt_indext index);
static const smt_indext &downcast(const irept &);
};

protected:
using irept::irept;
};

template <typename derivedt>
smt_indext::storert<derivedt>::storert()
{
static_assert(
std::is_base_of<irept, derivedt>::value &&
std::is_base_of<storert<derivedt>, derivedt>::value,
"Only irept based classes need to upcast smt_sortt to store it.");
}

template <typename derivedt>
irept smt_indext::storert<derivedt>::upcast(smt_indext index)
{
return static_cast<irept &&>(std::move(index));
}

template <typename derivedt>
const smt_indext &smt_indext::storert<derivedt>::downcast(const irept &irep)
{
return static_cast<const smt_indext &>(irep);
}

class smt_numeral_indext : public smt_indext
{
public:
explicit smt_numeral_indext(std::size_t value);
std::size_t value() const;
};

class smt_symbol_indext : public smt_indext
{
public:
explicit smt_symbol_indext(irep_idt identifier);
irep_idt identifier() const;
};

class smt_index_const_downcast_visitort
{
public:
virtual void visit(const smt_numeral_indext &) = 0;
virtual void visit(const smt_symbol_indext &) = 0;
};

#endif // CPROVER_SOLVERS_SMT2_INCREMENTAL_SMT_INDEX_H
19 changes: 18 additions & 1 deletion src/solvers/smt2_incremental/smt_terms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ static bool is_valid_smt_identifier(irep_idt identifier)
return std::regex_match(id2string(identifier), valid);
}

smt_identifier_termt::smt_identifier_termt(irep_idt identifier, smt_sortt sort)
smt_identifier_termt::smt_identifier_termt(
irep_idt identifier,
smt_sortt sort,
std::vector<smt_indext> indices)
: smt_termt(ID_smt_identifier_term, std::move(sort))
{
// The below invariant exists as a sanity check that the string being used for
Expand All @@ -67,13 +70,27 @@ smt_identifier_termt::smt_identifier_termt(irep_idt identifier, smt_sortt sort)
is_valid_smt_identifier(identifier),
R"(Identifiers may not contain | characters.)");
set(ID_identifier, identifier);
for(auto &index : indices)
{
get_sub().push_back(
smt_indext::storert<smt_identifier_termt>::upcast(std::move(index)));
}
}

irep_idt smt_identifier_termt::identifier() const
{
return get(ID_identifier);
}

std::vector<std::reference_wrapper<const smt_indext>>
smt_identifier_termt::indices() const
{
return make_range(get_sub()).map([](const irept &index) {
return std::cref(
smt_indext::storert<smt_identifier_termt>::downcast(index));
});
}

smt_bit_vector_constant_termt::smt_bit_vector_constant_termt(
const mp_integer &value,
smt_bit_vector_sortt sort)
Expand Down
66 changes: 61 additions & 5 deletions src/solvers/smt2_incremental/smt_terms.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
#ifndef CPROVER_SOLVERS_SMT2_INCREMENTAL_SMT_TERMS_H
#define CPROVER_SOLVERS_SMT2_INCREMENTAL_SMT_TERMS_H

#include <solvers/smt2_incremental/smt_sorts.h>
#include <util/irep.h>

#include <solvers/smt2_incremental/smt_index.h>
#include <solvers/smt2_incremental/smt_sorts.h>
#include <solvers/smt2_incremental/type_traits.h>

#include <functional>
#include <utility>

class BigInt;
using mp_integer = BigInt;
Expand Down Expand Up @@ -77,7 +81,14 @@ class smt_bool_literal_termt : public smt_termt

/// Stores identifiers in unescaped and unquoted form. Any escaping or quoting
/// required should be performed during printing.
class smt_identifier_termt : public smt_termt
/// \details
/// The SMT-LIB standard Version 2.6 refers to "indexed" identifiers which have
/// 1 or more indices and "simple" identifiers which have no indicies. The
/// internal `smt_identifier_termt` class is used for both kinds of identifier
/// which are distinguished based on whether the collection of `indices` is
/// empty or not.
class smt_identifier_termt : public smt_termt,
private smt_indext::storert<smt_identifier_termt>
{
public:
/// \brief Constructs an identifier term with the given \p identifier and
Expand All @@ -91,8 +102,14 @@ class smt_identifier_termt : public smt_termt
/// \param sort: The sort which this term will have. All terms in our abstract
/// form must be sorted, even if those sorts are not printed in all
/// contexts.
smt_identifier_termt(irep_idt identifier, smt_sortt sort);
/// \param indices: This should be collection of indices for an indexed
/// identifier, or an empty collection for simple (non-indexed) identifiers.
smt_identifier_termt(
irep_idt identifier,
smt_sortt sort,
std::vector<smt_indext> indices = {});
irep_idt identifier() const;
std::vector<std::reference_wrapper<const smt_indext>> indices() const;
};

class smt_bit_vector_constant_termt : public smt_termt
Expand Down Expand Up @@ -121,6 +138,44 @@ class smt_function_application_termt : public smt_termt
smt_identifier_termt function_identifier,
std::vector<smt_termt> arguments);

// This is used to detect if \p functiont has an `indicies` member function.
// It will resolve to std::true_type if it does or std::false type otherwise.
template <class functiont, class = void>
struct has_indicest : std::false_type
{
};

template <class functiont>
struct has_indicest<
functiont,
void_t<decltype(std::declval<functiont>().indices())>> : std::true_type
{
};

/// Overload for when \p functiont does not have indices.
template <class functiont>
static std::vector<smt_indext>
indices(const functiont &function, const std::false_type &has_indices)
{
return {};
}

/// Overload for when \p functiont has indices member function.
template <class functiont>
static std::vector<smt_indext>
indices(const functiont &function, const std::true_type &has_indices)
{
return function.indices();
}

/// Returns `function.indices` if `functiont` has an `indices` member function
/// or returns an empty collection otherwise.
template <class functiont>
static std::vector<smt_indext> indices(const functiont &function)
{
return indices(function, has_indicest<functiont>{});
}

public:
const smt_identifier_termt &function_identifier() const;
std::vector<std::reference_wrapper<const smt_termt>> arguments() const;
Expand All @@ -133,7 +188,7 @@ class smt_function_application_termt : public smt_termt

public:
template <typename... function_type_argument_typest>
explicit factoryt(function_type_argument_typest &&... arguments)
explicit factoryt(function_type_argument_typest &&... arguments) noexcept
: function{std::forward<function_type_argument_typest>(arguments)...}
{
}
Expand All @@ -145,7 +200,8 @@ class smt_function_application_termt : public smt_termt
function.validate(arguments...);
auto return_sort = function.return_sort(arguments...);
return smt_function_application_termt{
smt_identifier_termt{function.identifier(), std::move(return_sort)},
smt_identifier_termt{
function.identifier(), std::move(return_sort), indices(function)},
{std::forward<argument_typest>(arguments)...}};
}
};
Expand Down
Loading