@@ -359,30 +359,27 @@ mod llvm_enzyme {
359
359
ty
360
360
}
361
361
362
- /// We only want this function to type-check, since we will replace the body
363
- /// later on llvm level. Using `loop {}` does not cover all return types anymore,
364
- /// so instead we build something that should pass. We also add a inline_asm
365
- /// line, as one more barrier for rustc to prevent inlining of this function.
366
- /// FIXME(ZuseZ4): We still have cases of incorrect inlining across modules, see
367
- /// <https://github.com/EnzymeAD/rust/issues/173>, so this isn't sufficient.
368
- /// It also triggers an Enzyme crash if we due to a bug ever try to differentiate
369
- /// this function (which should never happen, since it is only a placeholder).
370
- /// Finally, we also add back_box usages of all input arguments, to prevent rustc
371
- /// from optimizing any arguments away.
372
- fn gen_enzyme_body (
362
+ // Will generate a body of the type:
363
+ // ```
364
+ // {
365
+ // unsafe {
366
+ // asm!("NOP");
367
+ // }
368
+ // ::core::hint::black_box(primal(args));
369
+ // ::core::hint::black_box((args, ret));
370
+ // <This part remains to be done by following function>
371
+ // }
372
+ // ```
373
+ fn init_body_helper (
373
374
ecx : & ExtCtxt < ' _ > ,
374
- x : & AutoDiffAttrs ,
375
- n_active : u32 ,
376
- sig : & ast:: FnSig ,
377
- d_sig : & ast:: FnSig ,
375
+ span : Span ,
378
376
primal : Ident ,
379
377
new_names : & [ String ] ,
380
- span : Span ,
381
378
sig_span : Span ,
382
379
new_decl_span : Span ,
383
- idents : Vec < Ident > ,
380
+ idents : & [ Ident ] ,
384
381
errored : bool ,
385
- ) -> P < ast:: Block > {
382
+ ) -> ( P < ast:: Block > , P < ast :: Expr > , P < ast :: Expr > , P < ast :: Expr > ) {
386
383
let blackbox_path = ecx. std_path ( & [ sym:: hint, sym:: black_box] ) ;
387
384
let noop = ast:: InlineAsm {
388
385
asm_macro : ast:: AsmMacro :: Asm ,
@@ -431,6 +428,54 @@ mod llvm_enzyme {
431
428
}
432
429
body. stmts . push ( ecx. stmt_semi ( black_box_remaining_args) ) ;
433
430
431
+ ( body, primal_call, black_box_primal_call, blackbox_call_expr)
432
+ }
433
+
434
+ /// We only want this function to type-check, since we will replace the body
435
+ /// later on llvm level. Using `loop {}` does not cover all return types anymore,
436
+ /// so instead we build something that should pass. We also add a inline_asm
437
+ /// line, as one more barrier for rustc to prevent inlining of this function.
438
+ /// FIXME(ZuseZ4): We still have cases of incorrect inlining across modules, see
439
+ /// <https://github.com/EnzymeAD/rust/issues/173>, so this isn't sufficient.
440
+ /// It also triggers an Enzyme crash if we due to a bug ever try to differentiate
441
+ /// this function (which should never happen, since it is only a placeholder).
442
+ /// Finally, we also add back_box usages of all input arguments, to prevent rustc
443
+ /// from optimizing any arguments away.
444
+ fn gen_enzyme_body (
445
+ ecx : & ExtCtxt < ' _ > ,
446
+ x : & AutoDiffAttrs ,
447
+ n_active : u32 ,
448
+ sig : & ast:: FnSig ,
449
+ d_sig : & ast:: FnSig ,
450
+ primal : Ident ,
451
+ new_names : & [ String ] ,
452
+ span : Span ,
453
+ sig_span : Span ,
454
+ _new_decl_span : Span ,
455
+ idents : Vec < Ident > ,
456
+ errored : bool ,
457
+ ) -> P < ast:: Block > {
458
+ let new_decl_span = d_sig. span ;
459
+
460
+ // Just adding some default inline-asm and black_box usages to prevent early inlining
461
+ // and optimizations which alter the function signature.
462
+ //
463
+ // The bb_primal_call is the black_box call of the primal function. We keep it around,
464
+ // since it has the convenient property of returning the type of the primal function,
465
+ // Remember, we only care to match types here.
466
+ // No matter which return we pick, we always wrap it into a std::hint::black_box call,
467
+ // to prevent rustc from propagating it into the caller.
468
+ let ( mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper (
469
+ ecx,
470
+ span,
471
+ primal,
472
+ new_names,
473
+ sig_span,
474
+ new_decl_span,
475
+ & idents,
476
+ errored,
477
+ ) ;
478
+
434
479
if !has_ret ( & d_sig. decl . output ) {
435
480
// there is no return type that we have to match, () works fine.
436
481
return body;
@@ -442,7 +487,7 @@ mod llvm_enzyme {
442
487
443
488
if primal_ret && n_active == 0 && x. mode . is_rev ( ) {
444
489
// We only have the primal ret.
445
- body. stmts . push ( ecx. stmt_expr ( black_box_primal_call ) ) ;
490
+ body. stmts . push ( ecx. stmt_expr ( bb_primal_call ) ) ;
446
491
return body;
447
492
}
448
493
@@ -534,11 +579,11 @@ mod llvm_enzyme {
534
579
return body;
535
580
}
536
581
[ arg] => {
537
- ret = ecx. expr_call ( new_decl_span, blackbox_call_expr , thin_vec ! [ arg. clone( ) ] ) ;
582
+ ret = ecx. expr_call ( new_decl_span, bb_call_expr , thin_vec ! [ arg. clone( ) ] ) ;
538
583
}
539
584
args => {
540
585
let ret_tuple: P < ast:: Expr > = ecx. expr_tuple ( span, args. into ( ) ) ;
541
- ret = ecx. expr_call ( new_decl_span, blackbox_call_expr , thin_vec ! [ ret_tuple] ) ;
586
+ ret = ecx. expr_call ( new_decl_span, bb_call_expr , thin_vec ! [ ret_tuple] ) ;
542
587
}
543
588
}
544
589
assert ! ( has_ret( & d_sig. decl. output) ) ;
@@ -551,7 +596,7 @@ mod llvm_enzyme {
551
596
ecx : & ExtCtxt < ' _ > ,
552
597
span : Span ,
553
598
primal : Ident ,
554
- idents : Vec < Ident > ,
599
+ idents : & [ Ident ] ,
555
600
) -> P < ast:: Expr > {
556
601
let has_self = idents. len ( ) > 0 && idents[ 0 ] . name == kw:: SelfLower ;
557
602
if has_self {
0 commit comments