@@ -4,10 +4,11 @@ use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::{fs, slice, str};

- use libc::{c_char, c_int, c_void, size_t};
+ use libc::{c_char, c_int, c_uint, c_void, size_t};
use llvm::{
    LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols,
};
+ use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
use rustc_codegen_ssa::back::link::ensure_removed;
use rustc_codegen_ssa::back::versioned_llvm_target;
use rustc_codegen_ssa::back::write::{
@@ -28,7 +29,7 @@ use rustc_session::config::{
use rustc_span::InnerSpan;
use rustc_span::symbol::sym;
use rustc_target::spec::{CodeModel, RelocModel, SanitizerSet, SplitDebuginfo, TlsModel};
- use tracing::debug;
+ use tracing::{debug, trace};

use crate::back::lto::ThinBuffer;
use crate::back::owned_target_machine::OwnedTargetMachine;
@@ -41,7 +42,13 @@ use crate::errors::{
    WithLlvmError, WriteBytecode,
};
use crate::llvm::diagnostic::OptimizationDiagnosticKind::*;
- use crate::llvm::{self, DiagnosticInfo, PassManager};
+ use crate::llvm::{
+     self, AttributeKind, DiagnosticInfo, LLVMCreateStringAttribute, LLVMGetFirstFunction,
+     LLVMGetNextFunction, LLVMGetStringAttributeAtIndex, LLVMIsEnumAttribute, LLVMIsStringAttribute,
+     LLVMRemoveStringAttributeAtIndex, LLVMRustAddEnumAttributeAtIndex,
+     LLVMRustAddFunctionAttributes, LLVMRustGetEnumAttributeAtIndex,
+     LLVMRustRemoveEnumAttributeAtIndex, PassManager,
+ };
use crate::type_::Type;
use crate::{LlvmCodegenBackend, ModuleLlvm, base, common, llvm_util};

@@ -517,9 +524,34 @@ pub(crate) unsafe fn llvm_optimize(
    config: &ModuleConfig,
    opt_level: config::OptLevel,
    opt_stage: llvm::OptStage,
+     skip_size_increasing_opts: bool,
) -> Result<(), FatalError> {
-     let unroll_loops =
-         opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
+     // Enzyme:
+     // The whole point of compiler based AD is to differentiate optimized IR instead of unoptimized
+     // source code. However, benchmarks show that optimizations increasing the code size
+     // tend to reduce AD performance. Therefore deactivate them before AD, then differentiate the code
+     // and finally re-optimize the module, now with all optimizations available.
+     // TODO: In a future update we could figure out how to only optimize functions getting
+     // differentiated.
+
+     let unroll_loops;
+     let vectorize_slp;
+     let vectorize_loop;
+
+     if skip_size_increasing_opts {
+         unroll_loops = false;
+         vectorize_slp = false;
+         vectorize_loop = false;
+     } else {
+         unroll_loops =
+             opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
+         vectorize_slp = config.vectorize_slp;
+         vectorize_loop = config.vectorize_loop;
+     }
+     trace!(
+         "Enzyme: Running with unroll_loops: {}, vectorize_slp: {}, vectorize_loop: {}",
+         unroll_loops, vectorize_slp, vectorize_loop
+     );
    let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed();
    let pgo_gen_path = get_pgo_gen_path(config);
    let pgo_use_path = get_pgo_use_path(config);
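Note on the new parameter: `skip_size_increasing_opts` is what lets the autodiff pipeline hold back loop unrolling and both vectorizers before Enzyme runs and re-enable them afterwards. A minimal sketch of how a caller picks the flag, mirroring the `optimize` and `differentiate` hunks further down; the helper function is illustrative and not part of the patch:

// Sketch only: callers of `llvm_optimize` choose the new flag roughly like this.
// The pre-AD `optimize` path below uses `cfg!(llvm_enzyme)`, while the
// post-AD re-optimization inside `differentiate` always passes `false`.
fn should_skip_size_increasing_opts(before_autodiff: bool) -> bool {
    before_autodiff && cfg!(llvm_enzyme)
}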
@@ -583,8 +615,8 @@ pub(crate) unsafe fn llvm_optimize(
            using_thin_buffers,
            config.merge_functions,
            unroll_loops,
-             config.vectorize_slp,
-             config.vectorize_loop,
+             vectorize_slp,
+             vectorize_loop,
            config.no_builtins,
            config.emit_lifetime_markers,
            sanitizer_options.as_ref(),
@@ -606,6 +638,113 @@ pub(crate) unsafe fn llvm_optimize(
    result.into_result().map_err(|()| llvm_err(dcx, LlvmError::RunLlvmPasses))
}

+ pub(crate) fn differentiate(
+     module: &ModuleCodegen<ModuleLlvm>,
+     cgcx: &CodegenContext<LlvmCodegenBackend>,
+     diff_items: Vec<AutoDiffItem>,
+     config: &ModuleConfig,
+ ) -> Result<(), FatalError> {
+     for item in &diff_items {
+         trace!("{}", item);
+     }
+
+     let llmod = module.module_llvm.llmod();
+     let llcx = &module.module_llvm.llcx;
+     let diag_handler = cgcx.create_dcx();
+
+     // Before dumping the module, we want all the tt to become part of the module.
+     for item in diff_items.iter() {
+         let name = CString::new(item.source.clone()).unwrap();
+         let fn_def: Option<&llvm::Value> =
+             unsafe { llvm::LLVMGetNamedFunction(llmod, name.as_ptr()) };
+         let fn_def = match fn_def {
+             Some(x) => x,
+             None => {
+                 return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
+                     src: item.source.clone(),
+                     target: item.target.clone(),
+                     error: "could not find source function".to_owned(),
+                 }));
+             }
+         };
+         let tgt_name = CString::new(item.target.clone()).unwrap();
+         dbg!("Target name: {:?}", &tgt_name);
+         let fn_target: Option<&llvm::Value> =
+             unsafe { llvm::LLVMGetNamedFunction(llmod, tgt_name.as_ptr()) };
+         let fn_target = match fn_target {
+             Some(x) => x,
+             None => {
+                 return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
+                     src: item.source.clone(),
+                     target: item.target.clone(),
+                     error: "could not find target function".to_owned(),
+                 }));
+             }
+         };
+
+         crate::builder::add_opt_dbg_helper2(llmod, llcx, fn_def, fn_target, item.attrs.clone());
+     }
+
+     // We needed the SanitizeHWAddress attribute to prevent LLVM from optimizing enums in a way
+     // which Enzyme doesn't understand.
+     unsafe {
+         let mut f = LLVMGetFirstFunction(llmod);
+         loop {
+             if let Some(lf) = f {
+                 f = LLVMGetNextFunction(lf);
+                 let myhwattr = "enzyme_hw";
+                 let attr = LLVMGetStringAttributeAtIndex(
+                     lf,
+                     c_uint::MAX,
+                     myhwattr.as_ptr() as *const c_char,
+                     myhwattr.as_bytes().len() as c_uint,
+                 );
+                 if LLVMIsStringAttribute(attr) {
+                     LLVMRemoveStringAttributeAtIndex(
+                         lf,
+                         c_uint::MAX,
+                         myhwattr.as_ptr() as *const c_char,
+                         myhwattr.as_bytes().len() as c_uint,
+                     );
+                 } else {
+                     LLVMRustRemoveEnumAttributeAtIndex(
+                         lf,
+                         c_uint::MAX,
+                         AttributeKind::SanitizeHWAddress,
+                     );
+                 }
+             } else {
+                 break;
+             }
+         }
+     }
+
+     if let Some(opt_level) = config.opt_level {
+         let opt_stage = match cgcx.lto {
+             Lto::Fat => llvm::OptStage::PreLinkFatLTO,
+             Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
+             _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
+             _ => llvm::OptStage::PreLinkNoLTO,
+         };
+         let skip_size_increasing_opts = false;
+         dbg!("Running Module Optimization after differentiation");
+         unsafe {
+             llvm_optimize(
+                 cgcx,
+                 diag_handler.handle(),
+                 module,
+                 config,
+                 opt_level,
+                 opt_stage,
+                 skip_size_increasing_opts,
+             )?
+         };
+     }
+     dbg!("Done with differentiate()");
+
+     Ok(())
+ }
+

// Unsafe due to LLVM calls.
pub(crate) unsafe fn optimize(
    cgcx: &CodegenContext<LlvmCodegenBackend>,
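Note on the attribute handling: the cleanup loop near the end of `differentiate` is one half of a marker protocol whose other half sits in the `optimize` hunk below. Before AD, a function that already carries LLVM's `SanitizeHWAddress` attribute gets an extra `enzyme_hw` string marker so the genuine attribute is remembered; a function without it gets `SanitizeHWAddress` added temporarily so LLVM keeps enums in a form Enzyme can follow. After AD, the loop above removes either the marker or the temporarily added attribute. The `c_uint::MAX` index used in these FFI calls is LLVM's function-level attribute slot (`LLVMAttributeFunctionIndex`). A toy model of the round trip over a plain set, with no LLVM types involved, only to make the branching explicit:

use std::collections::HashSet;

// Toy model of the `enzyme_hw` marker protocol from this patch. `attrs`
// stands in for a function's attribute set at the function index; the
// real code manipulates it through the LLVM C API instead.
fn protect_enums_before_ad(attrs: &mut HashSet<&'static str>) {
    if attrs.contains("SanitizeHWAddress") {
        // The attribute was already there on purpose: remember that with
        // a string marker so it is not stripped after AD.
        attrs.insert("enzyme_hw");
    } else {
        // Add the attribute temporarily so LLVM does not optimize enums
        // into a form Enzyme cannot follow.
        attrs.insert("SanitizeHWAddress");
    }
}

fn restore_attrs_after_ad(attrs: &mut HashSet<&'static str>) {
    if attrs.contains("enzyme_hw") {
        // Marker present: the sanitizer attribute was genuine, so only
        // the marker goes away.
        attrs.remove("enzyme_hw");
    } else {
        // No marker: the attribute was only added for Enzyme, drop it.
        attrs.remove("SanitizeHWAddress");
    }
}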
@@ -628,14 +767,68 @@ pub(crate) unsafe fn optimize(
        unsafe { llvm::LLVMWriteBitcodeToFile(llmod, out.as_ptr()) };
    }

+     // This code enables Enzyme to differentiate code containing Rust enums.
+     // By adding the SanitizeHWAddress attribute we prevent LLVM from optimizing
+     // away the enums and allow Enzyme to understand why a value can be of different types in
+     // different code sections. We remove this attribute after Enzyme is done, to not affect the
+     // rest of the compilation.
+     #[cfg(llvm_enzyme)]
+     unsafe {
+         let mut f = LLVMGetFirstFunction(llmod);
+         loop {
+             if let Some(lf) = f {
+                 f = LLVMGetNextFunction(lf);
+                 let myhwattr = "enzyme_hw";
+                 let myhwv = "";
+                 let prevattr = LLVMRustGetEnumAttributeAtIndex(
+                     lf,
+                     c_uint::MAX,
+                     AttributeKind::SanitizeHWAddress,
+                 );
+                 if LLVMIsEnumAttribute(prevattr) {
+                     let attr = LLVMCreateStringAttribute(
+                         llcx,
+                         myhwattr.as_ptr() as *const c_char,
+                         myhwattr.as_bytes().len() as c_uint,
+                         myhwv.as_ptr() as *const c_char,
+                         myhwv.as_bytes().len() as c_uint,
+                     );
+                     LLVMRustAddFunctionAttributes(lf, c_uint::MAX, &attr, 1);
+                 } else {
+                     LLVMRustAddEnumAttributeAtIndex(
+                         llcx,
+                         lf,
+                         c_uint::MAX,
+                         AttributeKind::SanitizeHWAddress,
+                     );
+                 }
+             } else {
+                 break;
+             }
+         }
+     }
+
    if let Some(opt_level) = config.opt_level {
        let opt_stage = match cgcx.lto {
            Lto::Fat => llvm::OptStage::PreLinkFatLTO,
            Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
            _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
            _ => llvm::OptStage::PreLinkNoLTO,
        };
-         return unsafe { llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage) };
+
+         // If we know that we will later run AD, then we disable vectorization and loop unrolling
+         let skip_size_increasing_opts = cfg!(llvm_enzyme);
+         return unsafe {
+             llvm_optimize(
+                 cgcx,
+                 dcx,
+                 module,
+                 config,
+                 opt_level,
+                 opt_stage,
+                 skip_size_increasing_opts,
+             )
+         };
    }
    Ok(())
}
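Taken together, the hunks imply this ordering for a module with autodiff items: the pre-AD `optimize` run protects enums and, when the compiler is built with `llvm_enzyme`, skips size-increasing passes; `differentiate` applies Enzyme and restores the attributes; the module is then re-optimized with `skip_size_increasing_opts = false`. The driver wiring is not part of this diff, so the sketch below only stubs the stages to show the order; the names are illustrative, not rustc internals.

// Sketch of the pipeline order implied by this patch; the entry point
// and stage functions are stubs, not the actual rustc driver.
fn compile_module_with_autodiff(has_autodiff_items: bool) {
    // Pre-AD optimization: enums protected, size-increasing passes
    // disabled when the compiler is built with `llvm_enzyme`.
    run_llvm_optimize(cfg!(llvm_enzyme));
    if has_autodiff_items {
        // Enzyme generates the requested derivative functions.
        run_differentiate();
        // Post-AD re-optimization with everything enabled again.
        run_llvm_optimize(false);
    }
}

fn run_llvm_optimize(skip_size_increasing_opts: bool) {
    println!("llvm_optimize(skip_size_increasing_opts = {skip_size_increasing_opts})");
}

fn run_differentiate() {
    println!("differentiate()");
}

fn main() {
    compile_module_with_autodiff(true);
}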