@@ -57,7 +57,7 @@ class remove_virtual_functionst
57
57
const symbol_exprt &,
58
58
const irep_idt &,
59
59
dispatch_table_entriest &,
60
- std::set<irep_idt> &visited ,
60
+ dispatch_table_entries_mapt & ,
61
61
const function_call_resolvert &) const ;
62
62
exprt
63
63
get_method (const irep_idt &class_id, const irep_idt &component_name) const ;
@@ -163,11 +163,18 @@ void remove_virtual_functionst::remove_virtual_function(
163
163
newinst->source_location =vcall_source_loc;
164
164
}
165
165
166
+ // get initial identifier for grouping
167
+ INVARIANT (!functions.empty (), " Function dispatch table cannot be empty." );
168
+ auto last_id = functions.back ().symbol_expr .get_identifier ();
169
+ // record class_ids for disjunction
170
+ std::set<irep_idt> class_ids;
171
+
166
172
std::map<irep_idt, goto_programt::targett> calls;
167
173
// Note backwards iteration, to get the fallback candidate first.
168
174
for (auto it=functions.crbegin (), itend=functions.crend (); it!=itend; ++it)
169
175
{
170
176
const auto &fun=*it;
177
+ class_ids.insert (fun.class_id );
171
178
auto insertit=calls.insert (
172
179
{fun.symbol_expr .get_identifier (), goto_programt::targett ()});
173
180
@@ -209,15 +216,50 @@ void remove_virtual_functionst::remove_virtual_function(
209
216
t3->make_goto (t_final, true_exprt ());
210
217
}
211
218
219
+ // Emit target if end of dispatch table is reached or if the next element is
220
+ // dispatched to another function call. Assumes entries in the functions
221
+ // variable to be sorted for the identifier of the function to be called.
222
+ auto l_it = std::next (it);
223
+ bool next_emit_target =
224
+ (l_it == functions.crend ()) ||
225
+ l_it->symbol_expr .get_identifier () != fun.symbol_expr .get_identifier ();
226
+
227
+ // The root function call is done via fall-through, so nothing to emit
228
+ // explicitly for this.
229
+ if (next_emit_target && fun.symbol_expr == last_function_symbol)
230
+ {
231
+ class_ids.clear ();
232
+ }
233
+
212
234
// If this calls the fallback function we just fall through.
213
235
// Otherwise branch to the right call:
214
236
if (fallback_action!=virtual_dispatch_fallback_actiont::CALL_LAST_FUNCTION ||
215
237
fun.symbol_expr !=last_function_symbol)
216
238
{
217
- exprt c_id1=constant_exprt (fun.class_id , string_typet ());
218
- goto_programt::targett t4=new_code_gotos.add_instruction ();
219
- t4->source_location =vcall_source_loc;
220
- t4->make_goto (insertit.first ->second , equal_exprt (c_id1, c_id2));
239
+ // create a disjunction of class_ids to test
240
+ if (next_emit_target && fun.symbol_expr != last_function_symbol)
241
+ {
242
+ exprt::operandst or_ops;
243
+ for (const auto &id : class_ids)
244
+ {
245
+ const constant_exprt c_id1 (id, string_typet ());
246
+ const equal_exprt class_id_test (c_id1, c_id2);
247
+ or_ops.push_back (class_id_test);
248
+ }
249
+
250
+ goto_programt::targett t4 = new_code_gotos.add_instruction ();
251
+ t4->source_location = vcall_source_loc;
252
+ t4->make_goto (insertit.first ->second , disjunction (or_ops));
253
+
254
+ last_id = fun.symbol_expr .get_identifier ();
255
+ class_ids.clear ();
256
+ }
257
+ // record class_id
258
+ else if (next_emit_target)
259
+ {
260
+ last_id = fun.symbol_expr .get_identifier ();
261
+ class_ids.clear ();
262
+ }
221
263
}
222
264
}
223
265
@@ -252,11 +294,12 @@ void remove_virtual_functionst::remove_virtual_function(
252
294
253
295
// / Used by get_functions to track the most-derived parent that provides an
254
296
// / override of a given function.
255
- // / \par parameters: `this_id`: class name
256
- // / `last_method_defn`: the most-derived parent of `this_id` to define the
257
- // / requested function
258
- // / `component_name`: name of the function searched for
259
- // / `resolve_function_call`: function to resolve abstract method call
297
+ // / \param parameters: `this_id`: class name
298
+ // / \param `last_method_defn`: the most-derived parent of `this_id` to define
299
+ // / the requested function
300
+ // / \param `component_name`: name of the function searched for
301
+ // / \param `entry_map`: map of class identifiers to dispatch table entries
302
+ // / \param `resolve_function_call`: function to resolve abstract method call
260
303
// / \return `functions` is assigned a list of {class name, function symbol}
261
304
// / pairs indicating that if `this` is of the given class, then the call will
262
305
// / target the given function. Thus if A <: B <: C and A and C provide
@@ -267,7 +310,7 @@ void remove_virtual_functionst::get_child_functions_rec(
267
310
const symbol_exprt &last_method_defn,
268
311
const irep_idt &component_name,
269
312
dispatch_table_entriest &functions,
270
- std::set<irep_idt> &visited ,
313
+ dispatch_table_entries_mapt &entry_map ,
271
314
const function_call_resolvert &resolve_function_call) const
272
315
{
273
316
auto findit=class_hierarchy.class_map .find (this_id);
@@ -276,9 +319,18 @@ void remove_virtual_functionst::get_child_functions_rec(
276
319
277
320
for (const auto &child : findit->second .children )
278
321
{
279
- if (!visited.insert (child).second )
322
+ // Skip if we have already visited this and we found a function call that
323
+ // did not resolve to non java.lang.Object.
324
+ auto it = entry_map.find (child);
325
+ if (
326
+ it != entry_map.end () &&
327
+ !has_prefix (
328
+ id2string (it->second .symbol_expr .get_identifier ()),
329
+ " java::java.lang.Object" ))
330
+ {
280
331
continue ;
281
- exprt method=get_method (child, component_name);
332
+ }
333
+ exprt method = get_method (child, component_name);
282
334
dispatch_table_entryt function (child);
283
335
if (method.is_not_nil ())
284
336
{
@@ -305,37 +357,43 @@ void remove_virtual_functionst::get_child_functions_rec(
305
357
}
306
358
}
307
359
functions.push_back (function);
360
+ entry_map.insert ({child, function});
308
361
309
362
get_child_functions_rec (
310
363
child,
311
364
function.symbol_expr ,
312
365
component_name,
313
366
functions,
314
- visited ,
367
+ entry_map ,
315
368
resolve_function_call);
316
369
}
317
370
}
318
371
372
+ // / Used to get dispatch entries to call for the given function
373
+ // / \param function: function that should be called
374
+ // / \param[out] functions: is assigned a list of dispatch entries, i.e., pairs
375
+ // / of class names and function symbol to call when encountering the class.
319
376
void remove_virtual_functionst::get_functions (
320
377
const exprt &function,
321
378
dispatch_table_entriest &functions)
322
379
{
380
+ // class part of function to call
323
381
const irep_idt class_id=function.get (ID_C_class);
324
382
const std::string class_id_string (id2string (class_id));
325
- const irep_idt component_name= function.get (ID_component_name);
326
- const std::string component_name_string (id2string (component_name ));
383
+ const irep_idt function_name = function.get (ID_component_name);
384
+ const std::string function_name_string (id2string (function_name ));
327
385
INVARIANT (!class_id.empty (), " All virtual functions must have a class" );
328
386
329
387
resolve_concrete_function_callt get_virtual_call_target (
330
388
symbol_table, class_hierarchy);
331
389
const function_call_resolvert resolve_function_call =
332
390
[&get_virtual_call_target](
333
- const irep_idt &class_id, const irep_idt &component_name ) {
334
- return get_virtual_call_target (class_id, component_name );
391
+ const irep_idt &class_id, const irep_idt &function_name ) {
392
+ return get_virtual_call_target (class_id, function_name );
335
393
};
336
394
337
395
const resolve_concrete_function_callt::concrete_function_callt
338
- &resolved_call = get_virtual_call_target (class_id, component_name );
396
+ &resolved_call = get_virtual_call_target (class_id, function_name );
339
397
340
398
dispatch_table_entryt root_function;
341
399
@@ -357,17 +415,37 @@ void remove_virtual_functionst::get_functions(
357
415
}
358
416
359
417
// iterate over all children, transitively
360
- std::set<irep_idt> visited ;
418
+ dispatch_table_entries_mapt entry_map ;
361
419
get_child_functions_rec (
362
420
class_id,
363
421
root_function.symbol_expr ,
364
- component_name ,
422
+ function_name ,
365
423
functions,
366
- visited ,
424
+ entry_map ,
367
425
resolve_function_call);
368
426
369
427
if (root_function.symbol_expr !=symbol_exprt ())
370
428
functions.push_back (root_function);
429
+
430
+ // Sort for the identifier of the function call symbol expression, grouping
431
+ // together calls to the same function. Keep java.lang.Object entries at the
432
+ // end for fall through. The reasoning is that this is the case with most
433
+ // entries in realistic cases.
434
+ std::sort (
435
+ functions.begin (),
436
+ functions.end (),
437
+ [&root_function](const dispatch_table_entryt &a, dispatch_table_entryt &b) {
438
+ if (
439
+ has_prefix (
440
+ id2string (a.symbol_expr .get_identifier ()), " java::java.lang.Object" ))
441
+ return false ;
442
+ else if (
443
+ has_prefix (
444
+ id2string (b.symbol_expr .get_identifier ()), " java::java.lang.Object" ))
445
+ return true ;
446
+ else
447
+ return a.symbol_expr .get_identifier () < b.symbol_expr .get_identifier ();
448
+ });
371
449
}
372
450
373
451
exprt remove_virtual_functionst::get_method (
0 commit comments