Skip to content

Commit fbffc09

Browse files
committed
Add smt_termt visitor
Because this will be used to implement term to string/output stream conversion. Use of a visitor will allow this printing code to be separated from the data structure.
1 parent 1b5e426 commit fbffc09

File tree

4 files changed

+154
-0
lines changed

4 files changed

+154
-0
lines changed

src/solvers/smt2_incremental/smt_terms.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,26 @@ smt_function_application_termt::arguments() const
141141
return std::cref(static_cast<const smt_termt &>(argument));
142142
});
143143
}
144+
145+
template <typename visitort>
146+
void accept(const smt_termt &term, const irep_idt &id, visitort &&visitor)
147+
{
148+
#define TERM_ID(the_id) \
149+
if(id == ID_smt_##the_id##_term) \
150+
return visitor.visit(static_cast<const smt_##the_id##_termt &>(term));
151+
// The include below is marked as nolint because including the same file
152+
// multiple times is required as part of the x macro pattern.
153+
#include <solvers/smt2_incremental/smt_terms.def> // NOLINT(build/include)
154+
#undef TERM_ID
155+
UNREACHABLE;
156+
}
157+
158+
void smt_termt::accept(smt_term_const_downcast_visitort &visitor) const
159+
{
160+
::accept(*this, id(), visitor);
161+
}
162+
163+
void smt_termt::accept(smt_term_const_downcast_visitort &&visitor) const
164+
{
165+
::accept(*this, id(), std::move(visitor));
166+
}

src/solvers/smt2_incremental/smt_terms.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
/// set of terms which are implemented and it is used to automate repetitive
55
/// parts of the implementation. These include -
66
/// * The constant `irep_idt`s used to identify each of the term classes.
7+
/// * The member functions of the `smt_term_const_downcast_visitort` class.
8+
/// * The type of term checks required to implement `smt_termt::accept`.
79
TERM_ID(bool_literal)
810
TERM_ID(not)
911
TERM_ID(identifier)

src/solvers/smt2_incremental/smt_terms.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
class BigInt;
1212
using mp_integer = BigInt;
1313

14+
class smt_term_const_downcast_visitort;
1415
class smt_termt : protected irept, private smt_sortt::storert<smt_termt>
1516
{
1617
public:
@@ -25,6 +26,9 @@ class smt_termt : protected irept, private smt_sortt::storert<smt_termt>
2526

2627
const smt_sortt &get_sort() const;
2728

29+
void accept(smt_term_const_downcast_visitort &) const;
30+
void accept(smt_term_const_downcast_visitort &&) const;
31+
2832
protected:
2933
smt_termt(irep_idt id, smt_sortt sort);
3034
};
@@ -89,4 +93,12 @@ class smt_function_application_termt : public smt_termt
8993
std::vector<std::reference_wrapper<const smt_termt>> arguments() const;
9094
};
9195

96+
class smt_term_const_downcast_visitort
97+
{
98+
public:
99+
#define TERM_ID(the_id) virtual void visit(const smt_##the_id##_termt &) = 0;
100+
#include "smt_terms.def"
101+
#undef TERM_ID
102+
};
103+
92104
#endif // CPROVER_SOLVERS_SMT2_INCREMENTAL_SMT_TERMS_H

unit/solvers/smt2_incremental/smt_terms.cpp

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,120 @@ TEST_CASE("smt_termt equality.", "[core][smt2_incremental]")
8484
CHECK_FALSE(not_false == not_true);
8585
CHECK(not_false == smt_not_termt{smt_bool_literal_termt{false}});
8686
}
87+
88+
template <typename expected_termt>
89+
class term_visit_type_checkert final : public smt_term_const_downcast_visitort
90+
{
91+
public:
92+
bool expected_term_visited = false;
93+
bool unexpected_term_visited = false;
94+
95+
void visit(const smt_bool_literal_termt &) override
96+
{
97+
if(std::is_same<expected_termt, smt_bool_literal_termt>::value)
98+
{
99+
expected_term_visited = true;
100+
}
101+
else
102+
{
103+
unexpected_term_visited = true;
104+
}
105+
}
106+
107+
void visit(const smt_not_termt &) override
108+
{
109+
if(std::is_same<expected_termt, smt_not_termt>::value)
110+
{
111+
expected_term_visited = true;
112+
}
113+
else
114+
{
115+
unexpected_term_visited = true;
116+
}
117+
}
118+
119+
void visit(const smt_identifier_termt &) override
120+
{
121+
if(std::is_same<expected_termt, smt_identifier_termt>::value)
122+
{
123+
expected_term_visited = true;
124+
}
125+
else
126+
{
127+
unexpected_term_visited = true;
128+
}
129+
}
130+
131+
void visit(const smt_bit_vector_constant_termt &) override
132+
{
133+
if(std::is_same<expected_termt, smt_bit_vector_constant_termt>::value)
134+
{
135+
expected_term_visited = true;
136+
}
137+
else
138+
{
139+
unexpected_term_visited = true;
140+
}
141+
}
142+
143+
void visit(const smt_function_application_termt &) override
144+
{
145+
if(std::is_same<expected_termt, smt_function_application_termt>::value)
146+
{
147+
expected_term_visited = true;
148+
}
149+
else
150+
{
151+
unexpected_term_visited = true;
152+
}
153+
}
154+
};
155+
156+
template <typename term_typet>
157+
term_typet make_test_term();
158+
159+
template <>
160+
smt_bool_literal_termt make_test_term<smt_bool_literal_termt>()
161+
{
162+
return smt_bool_literal_termt{false};
163+
}
164+
165+
template <>
166+
smt_not_termt make_test_term<smt_not_termt>()
167+
{
168+
return smt_not_termt{smt_bool_literal_termt{false}};
169+
}
170+
171+
template <>
172+
smt_identifier_termt make_test_term<smt_identifier_termt>()
173+
{
174+
return smt_identifier_termt{"foo", smt_bool_sortt{}};
175+
}
176+
177+
template <>
178+
smt_bit_vector_constant_termt make_test_term<smt_bit_vector_constant_termt>()
179+
{
180+
return smt_bit_vector_constant_termt{0, 32};
181+
}
182+
183+
template <>
184+
smt_function_application_termt make_test_term<smt_function_application_termt>()
185+
{
186+
return smt_function_application_termt{
187+
smt_identifier_termt{"bar", smt_bool_sortt{}}, {}};
188+
}
189+
190+
TEMPLATE_TEST_CASE(
191+
"smt_termt::accept(visitor)",
192+
"[core][smt2_incremental]",
193+
smt_bool_literal_termt,
194+
smt_not_termt,
195+
smt_identifier_termt,
196+
smt_bit_vector_constant_termt,
197+
smt_function_application_termt)
198+
{
199+
term_visit_type_checkert<TestType> checker;
200+
make_test_term<TestType>().accept(checker);
201+
CHECK(checker.expected_term_visited);
202+
CHECK_FALSE(checker.unexpected_term_visited);
203+
}

0 commit comments

Comments
 (0)