Skip to content

Commit b0744ea

Browse files
authored
Merge pull request #7080 from qinheping/loop_invariant_synthesis
Introduce expression enumerators
2 parents 991371d + f8e6230 commit b0744ea

File tree

6 files changed

+1090
-0
lines changed

6 files changed

+1090
-0
lines changed

src/goto-instrument/Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ SRC = accelerate/accelerate.cpp \
7070
splice_call.cpp \
7171
stack_depth.cpp \
7272
synthesizer/enumerative_loop_invariant_synthesizer.cpp \
73+
synthesizer/expr_enumerator.cpp \
7374
synthesizer/synthesizer_utils.cpp \
7475
thread_instrumentation.cpp \
7576
undefined_functions.cpp \
Lines changed: 375 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,375 @@
1+
/*******************************************************************\
2+
Module: Enumerator Interface
3+
Author: Qinheping Hu
4+
\*******************************************************************/
5+
6+
#include "expr_enumerator.h"
7+
8+
#include <util/format_expr.h>
9+
#include <util/simplify_expr.h>
10+
11+
expr_sett leaf_enumeratort::enumerate(const std::size_t size) const
{
  // A leaf expression consists of exactly one node, so the stored leaves are
  // the answer for size 1 and every other size yields the empty set.
  return size == 1 ? leaf_exprs : expr_sett{};
}
19+
20+
// Enumerate the expressions of exactly `size` nodes that are built by this
// enumerator: one node is the operator itself and the remaining `size - 1`
// nodes are distributed over the `arity` sub-enumerators.
expr_sett non_leaf_enumeratort::enumerate(const std::size_t size) const
{
  expr_sett result;

  // Enumerate nothing when `size` is too small to be partitioned: every one
  // of the `arity` operands needs at least one node in addition to the node
  // of the operator. `size <= arity` is equivalent to `size - 1 < arity` for
  // all size >= 1, but does not wrap around for `size == 0`.
  if(size <= arity)
    return result;

  // For every possible partition, set `size` of
  // each sub-enumerator to be the corresponding component in the partition.
  for(const auto &partition : get_partitions(size - 1, arity))
  {
    if(!is_good_partition(partition))
      continue;

    // Compute the Cartesian product as result.
    for(const auto &product_tuple : cartesian_product_of_enumerators(
          sub_enumerators,
          sub_enumerators.begin(),
          partition,
          partition.begin()))
    {
      // Optimization: rule out equivalent expressions
      // using certain equivalence class.
      // Keep only the representative tuple of each equivalence class.
      if(is_equivalence_class_representation(product_tuple))
        result.insert(simplify_expr(instantiate(product_tuple), ns));
    }
  }

  return result;
}
52+
53+
std::set<expr_listt> non_leaf_enumeratort::cartesian_product_of_enumerators(
54+
const enumeratorst &enumerators,
55+
const enumeratorst::const_iterator &it_enumerators,
56+
const partitiont &partition,
57+
const partitiont::const_iterator &it_partition) const
58+
{
59+
INVARIANT(
60+
std::distance(it_enumerators, enumerators.end()) ==
61+
std::distance(it_partition, partition.end()),
62+
"Partition should have the same size as enumerators.");
63+
64+
std::set<expr_listt> result;
65+
66+
if(std::next(it_enumerators) == enumerators.end())
67+
{
68+
/// Current enumerator is the last enumerator.
69+
/// Add all expressions enumerated by `it_enumerators` to `result`.
70+
for(const auto &e : enumerators.back()->enumerate(*it_partition))
71+
{
72+
result.insert({e});
73+
}
74+
}
75+
else
76+
{
77+
/// First compute the Cartesian product of enumerators after
78+
/// `it_enumerators`. And then append the expressions enumerated by the
79+
/// `it_enumerators` to every list in the Cartesian product.
80+
for(const auto &sub_tuple : cartesian_product_of_enumerators(
81+
enumerators,
82+
std::next(it_enumerators),
83+
partition,
84+
std::next(it_partition)))
85+
{
86+
for(const auto &elem : (*it_enumerators)->enumerate(*it_partition))
87+
{
88+
expr_listt new_tuple(sub_tuple);
89+
new_tuple.emplace_front(elem);
90+
result.insert(new_tuple);
91+
}
92+
}
93+
}
94+
return result;
95+
}
96+
97+
std::list<partitiont>
98+
get_partitions_long(const std::size_t n, const std::size_t k)
99+
{
100+
std::list<partitiont> result;
101+
// Cuts are an increasing vector of distinct indexes between 0 and n.
102+
// Note that cuts[0] is always 0 and cuts[k+1] is always n.
103+
// There is a bijection between partitions and cuts, i.e., for a given cuts,
104+
// (cuts[1]-cuts[0], cuts[2]-cuts[1], ..., cuts[k+1]-cuts[k])
105+
// is a partition of n into k components.
106+
std::vector<std::size_t> cuts;
107+
108+
// Initialize cuts as (0, n-k+1, n-k+2, ..., n).
109+
// O: elements
110+
// |: cuts
111+
// Initial cuts
112+
// 000...0111...1
113+
// So the first partition is (n-k+1, 1, 1, ..., 1).
114+
cuts.emplace_back(0);
115+
cuts.emplace_back(n - k + 1);
116+
for(std::size_t i = 0; i < k - 1; i++)
117+
{
118+
cuts.emplace_back(n - k + 2 + i);
119+
}
120+
121+
// Done when all cuts were enumerated.
122+
bool done = false;
123+
124+
while(!done)
125+
{
126+
// Construct a partition from cuts using the bijection described above.
127+
partitiont new_partition = partitiont();
128+
for(std::size_t i = 1; i < k + 1; i++)
129+
{
130+
new_partition.emplace_back(cuts[i] - cuts[i - 1]);
131+
}
132+
133+
// We move to the next cuts. The idea is that
134+
// 1. we first find the largest index i such that there are space before
135+
// cuts[i] where cuts[i] can be moved to;
136+
// The index i is the rightmost index we move in this iteration.
137+
// 2. we then move cuts[i] to its left by 1;
138+
// 3. move all cuts next to cuts[rightmost_to_move].
139+
//
140+
// O: filler
141+
// |: cuts
142+
//
143+
// Example:
144+
// Before moving:
145+
// 00000010010111110
146+
// ^
147+
// rightmost_to_move
148+
std::size_t rightmost_to_move = 0;
149+
for(std::size_t i = 1; i < k; i++)
150+
{
151+
if(cuts[i] - cuts[i - 1] > 1)
152+
{
153+
rightmost_to_move = i;
154+
}
155+
}
156+
157+
// Move cuts[rightmost_to_move] to its left:
158+
// 00000010011011110
159+
// ^
160+
// rightmost_to_move
161+
cuts[rightmost_to_move] = cuts[rightmost_to_move] - 1;
162+
163+
// No cut can be moved---we have enumerated all cuts.
164+
if(rightmost_to_move == 0)
165+
done = true;
166+
else
167+
{
168+
// Move all cuts (except for cuts[0]) after rightmost_to_move to their
169+
// rightmost.
170+
// 00000010011001111
171+
// ^
172+
// rightmost_to_move
173+
std::size_t accum = 1;
174+
for(std::size_t i = k - 1; i > rightmost_to_move; i--)
175+
{
176+
cuts[i] = n - accum;
177+
accum++;
178+
}
179+
}
180+
result.emplace_back(new_partition);
181+
}
182+
return result;
183+
}
184+
185+
/// Collect the positions of all set bits of `v` in ascending order.
/// Positions are 1-indexed starting at the most significant bit, i.e. the
/// least significant bit has position `8 * sizeof(std::size_t)`.
std::vector<std::size_t> get_ones_pos(std::size_t v)
{
  const std::size_t bit_count = 8 * sizeof(std::size_t);

  // Walk from the least significant bit upwards. The position shrinks by one
  // per step, so the set bits are found in descending position order.
  std::vector<std::size_t> descending;
  for(std::size_t pos = bit_count; v != 0; v >>= 1, --pos)
  {
    if((v & 1) != 0)
      descending.push_back(pos);
  }

  // Hand the positions back smallest first.
  return std::vector<std::size_t>(descending.rbegin(), descending.rend());
}
207+
208+
/// Construct parition of `n` elements from a bit vector `v`.
209+
/// For a bit vector with ones at positions (computed by `get_ones_pos`)
210+
/// (ones[0], ones[1], ..., ones[k-2]),
211+
/// the corresponding partition is
212+
/// (ones[0], ones[1]-ones[0], ..., ones[k-2]-ones[k-3], n-ones[k-2]).
213+
partitiont from_bits_to_partition(std::size_t v, std::size_t n)
214+
{
215+
const std::vector<std::size_t> ones_pos = get_ones_pos(v);
216+
217+
INVARIANT(ones_pos.size() >= 1, "There should be at least one bit set in v");
218+
219+
partitiont result = {ones_pos[0]};
220+
221+
for(std::size_t i = 1; i < ones_pos.size(); i++)
222+
{
223+
result.emplace_back(ones_pos[i] - ones_pos[i - 1]);
224+
}
225+
result.emplace_back(n - ones_pos[ones_pos.size() - 1]);
226+
227+
return result;
228+
}
229+
230+
std::list<partitiont> non_leaf_enumeratort::get_partitions(
231+
const std::size_t n,
232+
const std::size_t k) const
233+
{
234+
// Every component should contain at least one element.
235+
if(n < k)
236+
return {};
237+
238+
// Number of bits at all.
239+
const std::size_t length = sizeof(std::size_t) * 8;
240+
241+
// This bithack-based implementation works only for `n` no larger than
242+
// `length`. Use the vector-based implementation `n` is too large.
243+
if(n > length)
244+
return get_partitions_long(n, k);
245+
246+
// We enumerate all bit vectors `v` with k-1 one's such that each component
247+
// corresponds to one unique partition.
248+
// For a bit vector with ones at positions (computed by `get_ones_pos`)
249+
// (ones[0], ones[1], ..., ones[k-2]),
250+
// the corresponding partition is
251+
// (ones[0], ones[1]-ones[0], ..., ones[k-2]-ones[k-3], n-ones[k-2]).
252+
253+
// Initial `v` is with ones at positions (n-k+1, n-k+2, ..., n-2, n-1).
254+
std::size_t v = 0;
255+
// Initial `end` (the last bit vectorr we enumerate) is with ones at
256+
// positions (1, 2, 3, ..., k-1).
257+
std::size_t end = 0;
258+
for(size_t i = 0; i < k - 1; i++)
259+
{
260+
v++;
261+
v = v << 1;
262+
end++;
263+
end = end << 1;
264+
}
265+
v = v << (length - n);
266+
end = end << (length - k);
267+
268+
std::list<partitiont> result;
269+
while(v != end)
270+
{
271+
// Construct the partition for current bit vector and add it to `result`
272+
result.emplace_back(from_bits_to_partition(v, n));
273+
274+
// https://graphics.stanford.edu/~seander/bithacks.html#NextBitPermutation
275+
// Compute the lexicographically next bit permutation.
276+
std::size_t t = (v | (v - 1)) + 1;
277+
v = t | ((((t & -t) / (v & -v)) >> 1) - 1);
278+
}
279+
result.emplace_back(from_bits_to_partition(v, n));
280+
281+
return result;
282+
}
283+
284+
bool binary_functional_enumeratort::is_commutative(const irep_idt &op) const
285+
{
286+
return op_id == ID_equal || op_id == ID_plus || op_id == ID_notequal ||
287+
op_id == ID_or || op_id == ID_and || op_id == ID_xor ||
288+
op_id == ID_bitand || op_id == ID_bitor || op_id == ID_bitxor ||
289+
op_id == ID_mult;
290+
}
291+
292+
bool binary_functional_enumeratort::is_equivalence_class_representation(
293+
const expr_listt &exprs) const
294+
{
295+
std::stringstream left, right;
296+
left << format(exprs.front());
297+
right << format(exprs.back());
298+
// When the two sub-enumerators are exchangeable---they enumerate the same
299+
// set of expressions---, and the operator is commutative, `exprs` is a
300+
// representation if its sub-expressions are sorted.
301+
if(is_exchangeable && is_commutative(op_id) && left.str() > right.str())
302+
{
303+
return false;
304+
}
305+
306+
/// We are not sure if `exprs` is represented by some other tuple.
307+
return true;
308+
}
309+
310+
// Build the binary expression `op_id` applied to the two operands in `exprs`.
// Operators with a dedicated expression class are built through that class;
// every other operator is built as a plain `binary_exprt`.
exprt binary_functional_enumeratort::instantiate(const expr_listt &exprs) const
{
  INVARIANT(
    exprs.size() == 2,
    "number of arguments should be 2: " + integer2string(exprs.size()));

  const auto &lhs = exprs.front();
  const auto &rhs = exprs.back();

  // Equality and disequality.
  if(op_id == ID_equal)
  {
    return equal_exprt(lhs, rhs);
  }
  if(op_id == ID_notequal)
  {
    return notequal_exprt(lhs, rhs);
  }

  // Order relations.
  if(op_id == ID_le)
  {
    return less_than_or_equal_exprt(lhs, rhs);
  }
  if(op_id == ID_lt)
  {
    return less_than_exprt(lhs, rhs);
  }
  if(op_id == ID_gt)
  {
    return greater_than_exprt(lhs, rhs);
  }
  if(op_id == ID_ge)
  {
    return greater_than_or_equal_exprt(lhs, rhs);
  }

  // Boolean connectives.
  if(op_id == ID_and)
  {
    return and_exprt(lhs, rhs);
  }
  if(op_id == ID_or)
  {
    return or_exprt(lhs, rhs);
  }

  // Arithmetic.
  if(op_id == ID_plus)
  {
    return plus_exprt(lhs, rhs);
  }
  if(op_id == ID_minus)
  {
    return minus_exprt(lhs, rhs);
  }

  // Fallback for operators without a dedicated case above.
  return binary_exprt(lhs, op_id, rhs);
}
337+
338+
// The expressions of size `size` of an alternative are the union of the
// expressions of that size enumerated by each sub-enumerator.
expr_sett alternatives_enumeratort::enumerate(const std::size_t size) const
{
  expr_sett result;
  for(const auto &enumerator : sub_enumerators)
  {
    const auto alternatives = enumerator->enumerate(size);
    result.insert(alternatives.begin(), alternatives.end());
  }
  return result;
}
350+
351+
// Enumerate through the productions that the factory holds for the
// nonterminal `identifier` of this placeholder.
expr_sett
recursive_enumerator_placeholdert::enumerate(const std::size_t size) const
{
  // The productions are looked up at enumeration time.
  const auto it = factory.productions_map.find(identifier);
  INVARIANT(it != factory.productions_map.end(), "No nonterminal found.");

  // The nonterminal stands for the alternatives among its productions.
  return alternatives_enumeratort(it->second, ns).enumerate(size);
}
359+
360+
void enumerator_factoryt::add_placeholder(
361+
const recursive_enumerator_placeholdert &placeholder)
362+
{
363+
// The new placeholder (nonterminal) belongs to this factory (grammar).
364+
const auto &ret = nonterminal_set.insert(placeholder.identifier);
365+
INVARIANT(ret.second, "Duplicated non-terminals");
366+
}
367+
368+
void enumerator_factoryt::attach_productions(
369+
const std::string &id,
370+
const enumeratorst &enumerators)
371+
{
372+
const auto &ret = productions_map.insert({id, enumerators});
373+
INVARIANT(
374+
ret.second, "Cannnot attach enumerators to a non-existing nonterminal.");
375+
}

0 commit comments

Comments
 (0)