diff --git a/src/analyses/constant_propagator.cpp b/src/analyses/constant_propagator.cpp index 452e8355a93..57b854a52ee 100644 --- a/src/analyses/constant_propagator.cpp +++ b/src/analyses/constant_propagator.cpp @@ -31,11 +31,15 @@ void constant_propagator_domaint::assign_rec( valuest &values, const exprt &lhs, const exprt &rhs, - const namespacet &ns) + const namespacet &ns, + const constant_propagator_ait *cp) { if(lhs.id()!=ID_symbol) return; + if(cp && !cp->should_track_value(lhs, ns)) + return; + const symbol_exprt &s=to_symbol_expr(lhs); exprt tmp=rhs; @@ -89,11 +93,11 @@ void constant_propagator_domaint::transform( const code_assignt &assignment=to_code_assign(from->code); const exprt &lhs=assignment.lhs(); const exprt &rhs=assignment.rhs(); - assign_rec(values, lhs, rhs, ns); + assign_rec(values, lhs, rhs, ns, cp); } else if(from->is_assume()) { - two_way_propagate_rec(from->guard, ns); + two_way_propagate_rec(from->guard, ns, cp); } else if(from->is_goto()) { @@ -110,7 +114,7 @@ void constant_propagator_domaint::transform( values.set_to_bottom(); else { - two_way_propagate_rec(g, ns); + two_way_propagate_rec(g, ns, cp); // If two way propagate is enabled then it may be possible to detect // that the branch condition is infeasible and thus the domain should // be set to bottom. Without it the domain will only be set to bottom @@ -175,7 +179,7 @@ void constant_propagator_domaint::transform( break; const symbol_exprt parameter_expr(p_it->get_identifier(), arg.type()); - assign_rec(values, parameter_expr, arg, ns); + assign_rec(values, parameter_expr, arg, ns, cp); ++p_it; } @@ -220,7 +224,8 @@ void constant_propagator_domaint::transform( /// handles equalities and conjunctions containing equalities bool constant_propagator_domaint::two_way_propagate_rec( const exprt &expr, - const namespacet &ns) + const namespacet &ns, + const constant_propagator_ait *cp) { #ifdef DEBUG std::cout << "two_way_propagate_rec: " << format(expr) << '\n'; @@ -238,7 +243,7 @@ bool constant_propagator_domaint::two_way_propagate_rec( change = false; forall_operands(it, expr) - if(two_way_propagate_rec(*it, ns)) + if(two_way_propagate_rec(*it, ns, cp)) change=true; } while(change); diff --git a/src/analyses/constant_propagator.h b/src/analyses/constant_propagator.h index e32a14bc0f2..125942acfe7 100644 --- a/src/analyses/constant_propagator.h +++ b/src/analyses/constant_propagator.h @@ -18,6 +18,8 @@ Author: Peter Schrammel #include "ai.h" #include "dirty.h" +class constant_propagator_ait; + class constant_propagator_domaint:public ai_domain_baset { public: @@ -143,12 +145,15 @@ class constant_propagator_domaint:public ai_domain_baset protected: void assign_rec( valuest &values, - const exprt &lhs, const exprt &rhs, - const namespacet &ns); + const exprt &lhs, + const exprt &rhs, + const namespacet &ns, + const constant_propagator_ait *cp); bool two_way_propagate_rec( const exprt &expr, - const namespacet &ns); + const namespacet &ns, + const constant_propagator_ait *cp); bool partial_evaluate_with_all_rounding_modes( exprt &expr, @@ -160,13 +165,35 @@ class constant_propagator_domaint:public ai_domain_baset class constant_propagator_ait:public ait { public: - explicit constant_propagator_ait(const goto_functionst &goto_functions): - dirty(goto_functions) + typedef std::function + should_track_valuet; + + static bool track_all_values(const exprt &, const namespacet &) + { + return true; + } + + explicit constant_propagator_ait( + const goto_functionst &goto_functions, + should_track_valuet should_track_value = track_all_values): + dirty(goto_functions), + should_track_value(should_track_value) + { + } + + explicit constant_propagator_ait( + const goto_functiont &goto_function, + should_track_valuet should_track_value = track_all_values): + dirty(goto_function), + should_track_value(should_track_value) { } constant_propagator_ait( - goto_modelt &goto_model):dirty(goto_model.goto_functions) + goto_modelt &goto_model, + should_track_valuet should_track_value = track_all_values): + dirty(goto_model.goto_functions), + should_track_value(should_track_value) { const namespacet ns(goto_model.symbol_table); operator()(goto_model.goto_functions, ns); @@ -175,7 +202,10 @@ class constant_propagator_ait:public ait constant_propagator_ait( goto_functionst::goto_functiont &goto_function, - const namespacet &ns):dirty(goto_function) + const namespacet &ns, + should_track_valuet should_track_value = track_all_values): + dirty(goto_function), + should_track_value(should_track_value) { operator()(goto_function, ns); replace(goto_function, ns); @@ -197,6 +227,8 @@ class constant_propagator_ait:public ait void replace_types_rec( const replace_symbolt &replace_const, exprt &expr); + + should_track_valuet should_track_value; }; #endif // CPROVER_ANALYSES_CONSTANT_PROPAGATOR_H diff --git a/unit/Makefile b/unit/Makefile index ca8920e0c9d..496f4cb0143 100644 --- a/unit/Makefile +++ b/unit/Makefile @@ -9,6 +9,7 @@ SRC = unit_tests.cpp \ SRC += unit_tests.cpp \ analyses/ai/ai_simplify_lhs.cpp \ analyses/call_graph.cpp \ + analyses/constant_propagator.cpp \ analyses/disconnect_unreachable_nodes_in_graph.cpp \ analyses/does_remove_const/does_expr_lose_const.cpp \ analyses/does_remove_const/does_type_preserve_const_correctness.cpp \ diff --git a/unit/analyses/constant_propagator.cpp b/unit/analyses/constant_propagator.cpp new file mode 100644 index 00000000000..8c591e1a63c --- /dev/null +++ b/unit/analyses/constant_propagator.cpp @@ -0,0 +1,114 @@ +/*******************************************************************\ + +Module: Unit test for constant propagation + +Author: Diffblue Ltd + +\*******************************************************************/ + +#include + +#include + +#include + +#include + +#include + +static bool starts_with_x(const exprt &e, const namespacet &) +{ + if(e.id() != ID_symbol) + return false; + return has_prefix(id2string(to_symbol_expr(e).get_identifier()), "x"); +} + +SCENARIO("constant_propagator", "[core][analyses][constant_propagator]") +{ + GIVEN("A simple GOTO program") + { + null_message_handlert null_out; + + goto_modelt goto_model; + namespacet ns(goto_model.symbol_table); + + // Create the program: + // int x = 1; + // int y = 2; + + symbolt local_x; + symbolt local_y; + local_x.name = "x"; + local_x.type = integer_typet(); + local_x.mode = ID_C; + local_y.name = "y"; + local_y.type = integer_typet(); + local_y.mode = ID_C; + + code_blockt code; + code.copy_to_operands(code_declt(local_x.symbol_expr())); + code.copy_to_operands(code_declt(local_y.symbol_expr())); + code.copy_to_operands( + code_assignt( + local_x.symbol_expr(), constant_exprt("1", integer_typet()))); + code.copy_to_operands( + code_assignt( + local_y.symbol_expr(), constant_exprt("2", integer_typet()))); + + symbolt main_function_symbol; + main_function_symbol.name = "main"; + main_function_symbol.type = code_typet(); + main_function_symbol.value = code; + main_function_symbol.mode = ID_C; + + goto_model.symbol_table.add(local_x); + goto_model.symbol_table.add(local_y); + goto_model.symbol_table.add(main_function_symbol); + + goto_convert(goto_model, null_out); + + const goto_functiont &main_function = goto_model.get_goto_function("main"); + + // Find the instruction after "y = 2;" + goto_programt::const_targett test_instruction = + main_function.body.instructions.begin(); + while( + test_instruction != main_function.body.instructions.end() && + (!test_instruction->is_assign() || + to_code_assign(test_instruction->code).lhs() != local_y.symbol_expr())) + { + ++test_instruction; + } + + REQUIRE(test_instruction != main_function.body.instructions.end()); + ++test_instruction; + + WHEN("We apply conventional constant propagation") + { + constant_propagator_ait constant_propagator(main_function); + constant_propagator(main_function, ns); + + THEN("The propagator should discover values for both 'x' and 'y'") + { + const auto &final_domain = constant_propagator[test_instruction]; + + REQUIRE(final_domain.values.is_constant(local_x.symbol_expr())); + REQUIRE(final_domain.values.is_constant(local_y.symbol_expr())); + } + } + + WHEN("We apply constant propagation for symbols beginning with 'x'") + { + constant_propagator_ait constant_propagator(main_function, starts_with_x); + constant_propagator(main_function, ns); + + THEN("The propagator should discover a value for 'x' but not 'y'") + { + const auto &final_domain = constant_propagator[test_instruction]; + + REQUIRE(final_domain.values.is_constant(local_x.symbol_expr())); + REQUIRE(!final_domain.values.is_constant(local_y.symbol_expr())); + } + } + } +}