Skip to content

Commit 5420b97

Browse files
committed
introduce zero_extend expression
This introduces the zero_extend expression, which, given a bit-vector operand and a type, either a) pads the given operand with zeros from the left if the given type is wider than the type of the operand, or b) truncates the operand to the width of the given type if the given type is smaller than the operand, or c) reinterprets the operand as having the given type if the width of the type and the width of the operand match. This may differ from conversion if the types have different bit representations. This is easier to read and less prone to error than the current pattern, in which the operand is 1) converted to an unsigned type of the same width, and then 2) casted to an unsigned type of the wider width, and 3) finally casted to the target type.
1 parent 20a1ecf commit 5420b97

13 files changed

+137
-17
lines changed

src/solvers/flattening/boolbv.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,8 @@ bvt boolbvt::convert_bitvector(const exprt &expr)
165165
return convert_replication(to_replication_expr(expr));
166166
else if(expr.id()==ID_extractbits)
167167
return convert_extractbits(to_extractbits_expr(expr));
168+
else if(expr.id() == ID_zero_extend)
169+
return convert_bitvector(to_zero_extend_expr(expr).lower());
168170
else if(expr.id()==ID_bitnot || expr.id()==ID_bitand ||
169171
expr.id()==ID_bitor || expr.id()==ID_bitxor ||
170172
expr.id()==ID_bitxnor || expr.id()==ID_bitnor ||

src/solvers/floatbv/float_bv.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -692,8 +692,10 @@ exprt float_bvt::mul(
692692

693693
// zero-extend the fractions (unpacked fraction has the hidden bit)
694694
typet new_fraction_type=unsignedbv_typet((spec.f+1)*2);
695-
const exprt fraction1=typecast_exprt(unpacked1.fraction, new_fraction_type);
696-
const exprt fraction2=typecast_exprt(unpacked2.fraction, new_fraction_type);
695+
const exprt fraction1 =
696+
zero_extend_exprt{unpacked1.fraction, new_fraction_type};
697+
const exprt fraction2 =
698+
zero_extend_exprt{unpacked2.fraction, new_fraction_type};
697699

698700
// multiply the fractions
699701
unbiased_floatt result;
@@ -750,7 +752,7 @@ exprt float_bvt::div(
750752
unsignedbv_typet(div_width));
751753

752754
// zero-extend fraction2 to match fraction1
753-
const typecast_exprt fraction2(unpacked2.fraction, fraction1.type());
755+
const zero_extend_exprt fraction2{unpacked2.fraction, fraction1.type()};
754756

755757
// divide fractions
756758
unbiased_floatt result;

src/solvers/smt2/smt2_conv.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2456,6 +2456,10 @@ void smt2_convt::convert_expr(const exprt &expr)
24562456
{
24572457
convert_expr(simplify_expr(to_bitreverse_expr(expr).lower(), ns));
24582458
}
2459+
else if(expr.id() == ID_zero_extend)
2460+
{
2461+
convert_expr(to_zero_extend_expr(expr).lower());
2462+
}
24592463
else if(expr.id() == ID_function_application)
24602464
{
24612465
const auto &function_application_expr = to_function_application_expr(expr);

src/solvers/smt2_incremental/convert_expr_to_smt.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1469,6 +1469,15 @@ static smt_termt convert_expr_to_smt(
14691469
count_trailing_zeros.pretty());
14701470
}
14711471

1472+
static smt_termt convert_expr_to_smt(
1473+
const zero_extend_exprt &zero_extend,
1474+
const sub_expression_mapt &converted)
1475+
{
1476+
UNREACHABLE_BECAUSE(
1477+
"zero_extend expression should have been lowered by the decision "
1478+
"procedure before conversion to smt terms");
1479+
}
1480+
14721481
static smt_termt convert_expr_to_smt(
14731482
const prophecy_r_or_w_ok_exprt &prophecy_r_or_w_ok,
14741483
const sub_expression_mapt &converted)
@@ -1822,6 +1831,10 @@ static smt_termt dispatch_expr_to_smt_conversion(
18221831
{
18231832
return convert_expr_to_smt(*count_trailing_zeros, converted);
18241833
}
1834+
if(const auto zero_extend = expr_try_dynamic_cast<zero_extend_exprt>(expr))
1835+
{
1836+
return convert_expr_to_smt(*zero_extend, converted);
1837+
}
18251838
if(
18261839
const auto prophecy_r_or_w_ok =
18271840
expr_try_dynamic_cast<prophecy_r_or_w_ok_exprt>(expr))

src/solvers/smt2_incremental/smt2_incremental_decision_procedure.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "smt2_incremental_decision_procedure.h"
44

55
#include <util/arith_tools.h>
6+
#include <util/bitvector_expr.h>
67
#include <util/byte_operators.h>
78
#include <util/c_types.h>
89
#include <util/range.h>
@@ -296,6 +297,17 @@ static exprt lower_rw_ok_pointer_in_range(exprt expr, const namespacet &ns)
296297
return expr;
297298
}
298299

300+
static exprt lower_zero_extend(exprt expr, const namespacet &ns)
301+
{
302+
expr.visit_pre([](exprt &expr) {
303+
if(auto zero_extend = expr_try_dynamic_cast<zero_extend_exprt>(expr))
304+
{
305+
expr = zero_extend->lower();
306+
}
307+
});
308+
return expr;
309+
}
310+
299311
void smt2_incremental_decision_proceduret::ensure_handle_for_expr_defined(
300312
const exprt &in_expr)
301313
{
@@ -677,8 +689,10 @@ void smt2_incremental_decision_proceduret::define_object_properties()
677689

678690
exprt smt2_incremental_decision_proceduret::lower(exprt expression) const
679691
{
680-
const exprt lowered = struct_encoding.encode(lower_enum(
681-
lower_byte_operators(lower_rw_ok_pointer_in_range(expression, ns), ns),
692+
const exprt lowered = struct_encoding.encode(lower_zero_extend(
693+
lower_enum(
694+
lower_byte_operators(lower_rw_ok_pointer_in_range(expression, ns), ns),
695+
ns),
682696
ns));
683697
log.conditional_output(log.debug(), [&](messaget::mstreamt &debug) {
684698
if(lowered != expression)

src/util/bitvector_expr.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ exprt update_bit_exprt::lower() const
5454
typecast_exprt(src(), src_bv_type), bitnot_exprt(mask_shifted));
5555

5656
// zero-extend the replacement bit to match src
57-
auto new_value_casted = typecast_exprt(
58-
typecast_exprt(new_value(), unsignedbv_typet(width)), src_bv_type);
57+
auto new_value_casted = zero_extend_exprt{new_value(), src_bv_type};
5958

6059
// shift the replacement bits
6160
auto new_value_shifted = shl_exprt(new_value_casted, index());
@@ -85,7 +84,7 @@ exprt update_bits_exprt::lower() const
8584
bitand_exprt(typecast_exprt(src(), src_bv_type), mask_shifted);
8685

8786
// zero-extend or shrink the replacement bits to match src
88-
auto new_value_casted = typecast_exprt(new_value(), src_bv_type);
87+
auto new_value_casted = zero_extend_exprt{new_value(), src_bv_type};
8988

9089
// shift the replacement bits
9190
auto new_value_shifted = shl_exprt(new_value_casted, index());
@@ -279,3 +278,19 @@ exprt find_first_set_exprt::lower() const
279278

280279
return typecast_exprt::conditional_cast(result, type());
281280
}
281+
282+
exprt zero_extend_exprt::lower() const
283+
{
284+
const auto old_width = to_bitvector_type(op().type()).get_width();
285+
const auto new_width = to_bitvector_type(type()).get_width();
286+
287+
if(new_width > old_width)
288+
{
289+
return concatenation_exprt{
290+
bv_typet{new_width - old_width}.all_zeros_expr(), op(), type()};
291+
}
292+
else // new_width <= old_width
293+
{
294+
return extractbits_exprt{op(), 0, type()};
295+
}
296+
}

src/util/bitvector_expr.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1663,4 +1663,48 @@ inline find_first_set_exprt &to_find_first_set_expr(exprt &expr)
16631663
return ret;
16641664
}
16651665

1666+
/// \brief zero extension
1667+
/// The operand is converted to the given type by either
1668+
/// a) truncating if the new type is shorter, or
1669+
/// b) padding with most-significant zero bits if the new type is larger, or
1670+
/// c) reinterprets the operand as the given type if their widths match.
1671+
class zero_extend_exprt : public unary_exprt
1672+
{
1673+
public:
1674+
zero_extend_exprt(exprt _op, typet _type)
1675+
: unary_exprt(ID_zero_extend, std::move(_op), std::move(_type))
1676+
{
1677+
}
1678+
1679+
// a lowering to extraction or concatenation
1680+
exprt lower() const;
1681+
};
1682+
1683+
template <>
1684+
inline bool can_cast_expr<zero_extend_exprt>(const exprt &base)
1685+
{
1686+
return base.id() == ID_zero_extend;
1687+
}
1688+
1689+
/// \brief Cast an exprt to a \ref zero_extend_exprt
1690+
///
1691+
/// \a expr must be known to be \ref zero_extend_exprt.
1692+
///
1693+
/// \param expr: Source expression
1694+
/// \return Object of type \ref zero_extend_exprt
1695+
inline const zero_extend_exprt &to_zero_extend_expr(const exprt &expr)
1696+
{
1697+
PRECONDITION(expr.id() == ID_zero_extend);
1698+
zero_extend_exprt::check(expr);
1699+
return static_cast<const zero_extend_exprt &>(expr);
1700+
}
1701+
1702+
/// \copydoc to_zero_extend_expr(const exprt &)
1703+
inline zero_extend_exprt &to_zero_extend_expr(exprt &expr)
1704+
{
1705+
PRECONDITION(expr.id() == ID_zero_extend);
1706+
zero_extend_exprt::check(expr);
1707+
return static_cast<zero_extend_exprt &>(expr);
1708+
}
1709+
16661710
#endif // CPROVER_UTIL_BITVECTOR_EXPR_H

src/util/format_expr.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,12 @@ void format_expr_configt::setup()
376376
<< format(expr.type()) << ')';
377377
};
378378

379+
expr_map[ID_zero_extend] =
380+
[](std::ostream &os, const exprt &expr) -> std::ostream & {
381+
return os << "zero_extend(" << format(to_zero_extend_expr(expr).op())
382+
<< ", " << format(expr.type()) << ')';
383+
};
384+
379385
expr_map[ID_floatbv_typecast] =
380386
[](std::ostream &os, const exprt &expr) -> std::ostream & {
381387
const auto &floatbv_typecast_expr = to_floatbv_typecast_expr(expr);

src/util/irep_ids.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ IREP_ID_ONE(extractbit)
188188
IREP_ID_ONE(extractbits)
189189
IREP_ID_ONE(update_bit)
190190
IREP_ID_ONE(update_bits)
191+
IREP_ID_ONE(zero_extend)
191192
IREP_ID_TWO(C_reference, #reference)
192193
IREP_ID_TWO(C_rvalue_reference, #rvalue_reference)
193194
IREP_ID_ONE(true)

src/util/lower_byte_operators.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2491,15 +2491,16 @@ static exprt lower_byte_update(
24912491
exprt zero_extended;
24922492
if(bit_width > update_size_bits)
24932493
{
2494-
zero_extended = concatenation_exprt{
2495-
bv_typet{bit_width - update_size_bits}.all_zeros_expr(),
2496-
value,
2497-
bv_typet{bit_width}};
2498-
2499-
if(!is_little_endian)
2500-
to_concatenation_expr(zero_extended)
2501-
.op0()
2502-
.swap(to_concatenation_expr(zero_extended).op1());
2494+
if(is_little_endian)
2495+
zero_extended = zero_extend_exprt{value, bv_typet{bit_width}};
2496+
else
2497+
{
2498+
// Big endian -- the zero is added as LSB.
2499+
zero_extended = concatenation_exprt{
2500+
value,
2501+
bv_typet{bit_width - update_size_bits}.all_zeros_expr(),
2502+
bv_typet{bit_width}};
2503+
}
25032504
}
25042505
else
25052506
zero_extended = value;

src/util/simplify_expr.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3028,6 +3028,10 @@ simplify_exprt::resultt<> simplify_exprt::simplify_node(const exprt &node)
30283028
{
30293029
r = simplify_extractbits(to_extractbits_expr(expr));
30303030
}
3031+
else if(expr.id() == ID_zero_extend)
3032+
{
3033+
r = simplify_zero_extend(to_zero_extend_expr(expr));
3034+
}
30313035
else if(expr.id()==ID_ieee_float_equal ||
30323036
expr.id()==ID_ieee_float_notequal)
30333037
{

src/util/simplify_expr_class.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ class unary_overflow_exprt;
7676
class unary_plus_exprt;
7777
class update_exprt;
7878
class with_exprt;
79+
class zero_extend_exprt;
7980

8081
class simplify_exprt
8182
{
@@ -152,6 +153,7 @@ class simplify_exprt
152153
[[nodiscard]] resultt<> simplify_extractbit(const extractbit_exprt &);
153154
[[nodiscard]] resultt<> simplify_extractbits(const extractbits_exprt &);
154155
[[nodiscard]] resultt<> simplify_concatenation(const concatenation_exprt &);
156+
[[nodiscard]] resultt<> simplify_zero_extend(const zero_extend_exprt &);
155157
[[nodiscard]] resultt<> simplify_mult(const mult_exprt &);
156158
[[nodiscard]] resultt<> simplify_div(const div_exprt &);
157159
[[nodiscard]] resultt<> simplify_mod(const mod_exprt &);

src/util/simplify_expr_int.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -997,6 +997,18 @@ simplify_exprt::simplify_concatenation(const concatenation_exprt &expr)
997997
return std::move(new_expr);
998998
}
999999

1000+
simplify_exprt::resultt<>
1001+
simplify_exprt::simplify_zero_extend(const zero_extend_exprt &expr)
1002+
{
1003+
if(!can_cast_type<bitvector_typet>(expr.type()))
1004+
return unchanged(expr);
1005+
1006+
if(!can_cast_type<bitvector_typet>(expr.op().type()))
1007+
return unchanged(expr);
1008+
1009+
return changed(simplify_node(expr.lower()));
1010+
}
1011+
10001012
simplify_exprt::resultt<>
10011013
simplify_exprt::simplify_shifts(const shift_exprt &expr)
10021014
{

0 commit comments

Comments
 (0)