Skip to content

Commit 78297a9

Browse files
committed
upstream rustc_codegen_llvm changes for enzyme/autodiff
1 parent 3fee0f1 commit 78297a9

File tree

13 files changed

+639
-29
lines changed

13 files changed

+639
-29
lines changed

Diff for: compiler/rustc_ast/src/expand/autodiff_attrs.rs

+3-16
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
use std::fmt::{self, Display, Formatter};
77
use std::str::FromStr;
88

9-
use crate::expand::typetree::TypeTree;
109
use crate::expand::{Decodable, Encodable, HashStable_Generic};
1110
use crate::ptr::P;
1211
use crate::{Ty, TyKind};
@@ -79,10 +78,6 @@ pub struct AutoDiffItem {
7978
/// The name of the function being generated
8079
pub target: String,
8180
pub attrs: AutoDiffAttrs,
82-
/// Describe the memory layout of input types
83-
pub inputs: Vec<TypeTree>,
84-
/// Describe the memory layout of the output type
85-
pub output: TypeTree,
8681
}
8782
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
8883
pub struct AutoDiffAttrs {
@@ -262,22 +257,14 @@ impl AutoDiffAttrs {
262257
!matches!(self.mode, DiffMode::Error | DiffMode::Source)
263258
}
264259

265-
pub fn into_item(
266-
self,
267-
source: String,
268-
target: String,
269-
inputs: Vec<TypeTree>,
270-
output: TypeTree,
271-
) -> AutoDiffItem {
272-
AutoDiffItem { source, target, inputs, output, attrs: self }
260+
pub fn into_item(self, source: String, target: String) -> AutoDiffItem {
261+
AutoDiffItem { source, target, attrs: self }
273262
}
274263
}
275264

276265
impl fmt::Display for AutoDiffItem {
277266
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
278267
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
279-
write!(f, " with attributes: {:?}", self.attrs)?;
280-
write!(f, " with inputs: {:?}", self.inputs)?;
281-
write!(f, " with output: {:?}", self.output)
268+
write!(f, " with attributes: {:?}", self.attrs)
282269
}
283270
}

Diff for: compiler/rustc_codegen_llvm/messages.ftl

+4
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ codegen_llvm_prepare_thin_lto_module_with_llvm_err = failed to prepare thin LTO
5656
codegen_llvm_run_passes = failed to run LLVM passes
5757
codegen_llvm_run_passes_with_llvm_err = failed to run LLVM passes: {$llvm_err}
5858
59+
codegen_llvm_prepare_autodiff = failed to prepare AutoDiff: src: {$src}, target: {$target}, {$error}
60+
codegen_llvm_prepare_autodiff_with_llvm_err = failed to prepare AutoDiff: {$llvm_err}, src: {$src}, target: {$target}, {$error}
61+
codegen_llvm_autodiff_without_lto = using the autodiff feature requires using fat-lto
62+
5963
codegen_llvm_sanitizer_memtag_requires_mte =
6064
`-Zsanitizer=memtag` requires `-Ctarget-feature=+mte`
6165

Diff for: compiler/rustc_codegen_llvm/src/back/lto.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,12 @@ pub(crate) fn run_pass_manager(
604604
debug!("running the pass manager");
605605
let opt_stage = if thin { llvm::OptStage::ThinLTO } else { llvm::OptStage::FatLTO };
606606
let opt_level = config.opt_level.unwrap_or(config::OptLevel::No);
607-
unsafe { write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage) }?;
607+
// We will run this again with different values in the context of automatic differentiation.
608+
let first_run = true;
609+
debug!("running llvm pm opt pipeline");
610+
unsafe {
611+
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run)?;
612+
}
608613
debug!("lto done");
609614
Ok(())
610615
}

Diff for: compiler/rustc_codegen_llvm/src/back/write.rs

+201-8
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@ use std::path::{Path, PathBuf};
44
use std::sync::Arc;
55
use std::{fs, slice, str};
66

7-
use libc::{c_char, c_int, c_void, size_t};
7+
use libc::{c_char, c_int, c_uint, c_void, size_t};
88
use llvm::{
99
LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols,
1010
};
11+
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
1112
use rustc_codegen_ssa::back::link::ensure_removed;
1213
use rustc_codegen_ssa::back::versioned_llvm_target;
1314
use rustc_codegen_ssa::back::write::{
@@ -28,7 +29,7 @@ use rustc_session::config::{
2829
use rustc_span::InnerSpan;
2930
use rustc_span::symbol::sym;
3031
use rustc_target::spec::{CodeModel, RelocModel, SanitizerSet, SplitDebuginfo, TlsModel};
31-
use tracing::debug;
32+
use tracing::{debug, trace};
3233

3334
use crate::back::lto::ThinBuffer;
3435
use crate::back::owned_target_machine::OwnedTargetMachine;
@@ -41,7 +42,13 @@ use crate::errors::{
4142
WithLlvmError, WriteBytecode,
4243
};
4344
use crate::llvm::diagnostic::OptimizationDiagnosticKind::*;
44-
use crate::llvm::{self, DiagnosticInfo, PassManager};
45+
use crate::llvm::{
46+
self, AttributeKind, DiagnosticInfo, LLVMCreateStringAttribute, LLVMGetFirstFunction,
47+
LLVMGetNextFunction, LLVMGetStringAttributeAtIndex, LLVMIsEnumAttribute, LLVMIsStringAttribute,
48+
LLVMRemoveStringAttributeAtIndex, LLVMRustAddEnumAttributeAtIndex,
49+
LLVMRustAddFunctionAttributes, LLVMRustGetEnumAttributeAtIndex,
50+
LLVMRustRemoveEnumAttributeAtIndex, PassManager,
51+
};
4552
use crate::type_::Type;
4653
use crate::{LlvmCodegenBackend, ModuleLlvm, base, common, llvm_util};
4754

@@ -517,9 +524,34 @@ pub(crate) unsafe fn llvm_optimize(
517524
config: &ModuleConfig,
518525
opt_level: config::OptLevel,
519526
opt_stage: llvm::OptStage,
527+
skip_size_increasing_opts: bool,
520528
) -> Result<(), FatalError> {
521-
let unroll_loops =
522-
opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
529+
// Enzyme:
530+
// The whole point of compiler based AD is to differentiate optimized IR instead of unoptimized
531+
// source code. However, benchmarks show that optimizations increasing the code size
532+
// tend to reduce AD performance. Therefore deactivate them before AD, then differentiate the code
533+
// and finally re-optimize the module, now with all optimizations available.
534+
// TODO: In a future update we could figure out how to only optimize functions getting
535+
// differentiated.
536+
537+
let unroll_loops;
538+
let vectorize_slp;
539+
let vectorize_loop;
540+
541+
if skip_size_increasing_opts {
542+
unroll_loops = false;
543+
vectorize_slp = false;
544+
vectorize_loop = false;
545+
} else {
546+
unroll_loops =
547+
opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
548+
vectorize_slp = config.vectorize_slp;
549+
vectorize_loop = config.vectorize_loop;
550+
}
551+
trace!(
552+
"Enzyme: Running with unroll_loops: {}, vectorize_slp: {}, vectorize_loop: {}",
553+
unroll_loops, vectorize_slp, vectorize_loop
554+
);
523555
let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed();
524556
let pgo_gen_path = get_pgo_gen_path(config);
525557
let pgo_use_path = get_pgo_use_path(config);
@@ -583,8 +615,8 @@ pub(crate) unsafe fn llvm_optimize(
583615
using_thin_buffers,
584616
config.merge_functions,
585617
unroll_loops,
586-
config.vectorize_slp,
587-
config.vectorize_loop,
618+
vectorize_slp,
619+
vectorize_loop,
588620
config.no_builtins,
589621
config.emit_lifetime_markers,
590622
sanitizer_options.as_ref(),
@@ -606,6 +638,113 @@ pub(crate) unsafe fn llvm_optimize(
606638
result.into_result().map_err(|()| llvm_err(dcx, LlvmError::RunLlvmPasses))
607639
}
608640

641+
pub(crate) fn differentiate(
642+
module: &ModuleCodegen<ModuleLlvm>,
643+
cgcx: &CodegenContext<LlvmCodegenBackend>,
644+
diff_items: Vec<AutoDiffItem>,
645+
config: &ModuleConfig,
646+
) -> Result<(), FatalError> {
647+
for item in &diff_items {
648+
trace!("{}", item);
649+
}
650+
651+
let llmod = module.module_llvm.llmod();
652+
let llcx = &module.module_llvm.llcx;
653+
let diag_handler = cgcx.create_dcx();
654+
655+
// Before dumping the module, we want all the tt to become part of the module.
656+
for item in diff_items.iter() {
657+
let name = CString::new(item.source.clone()).unwrap();
658+
let fn_def: Option<&llvm::Value> =
659+
unsafe { llvm::LLVMGetNamedFunction(llmod, name.as_ptr()) };
660+
let fn_def = match fn_def {
661+
Some(x) => x,
662+
None => {
663+
return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
664+
src: item.source.clone(),
665+
target: item.target.clone(),
666+
error: "could not find source function".to_owned(),
667+
}));
668+
}
669+
};
670+
let tgt_name = CString::new(item.target.clone()).unwrap();
671+
dbg!("Target name: {:?}", &tgt_name);
672+
let fn_target: Option<&llvm::Value> =
673+
unsafe { llvm::LLVMGetNamedFunction(llmod, tgt_name.as_ptr()) };
674+
let fn_target = match fn_target {
675+
Some(x) => x,
676+
None => {
677+
return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
678+
src: item.source.clone(),
679+
target: item.target.clone(),
680+
error: "could not find target function".to_owned(),
681+
}));
682+
}
683+
};
684+
685+
crate::builder::add_opt_dbg_helper2(llmod, llcx, fn_def, fn_target, item.attrs.clone());
686+
}
687+
688+
// We needed the SanitizeHWAddress attribute to prevent LLVM from optimizing enums in a way
689+
// which Enzyme doesn't understand.
690+
unsafe {
691+
let mut f = LLVMGetFirstFunction(llmod);
692+
loop {
693+
if let Some(lf) = f {
694+
f = LLVMGetNextFunction(lf);
695+
let myhwattr = "enzyme_hw";
696+
let attr = LLVMGetStringAttributeAtIndex(
697+
lf,
698+
c_uint::MAX,
699+
myhwattr.as_ptr() as *const c_char,
700+
myhwattr.as_bytes().len() as c_uint,
701+
);
702+
if LLVMIsStringAttribute(attr) {
703+
LLVMRemoveStringAttributeAtIndex(
704+
lf,
705+
c_uint::MAX,
706+
myhwattr.as_ptr() as *const c_char,
707+
myhwattr.as_bytes().len() as c_uint,
708+
);
709+
} else {
710+
LLVMRustRemoveEnumAttributeAtIndex(
711+
lf,
712+
c_uint::MAX,
713+
AttributeKind::SanitizeHWAddress,
714+
);
715+
}
716+
} else {
717+
break;
718+
}
719+
}
720+
}
721+
722+
if let Some(opt_level) = config.opt_level {
723+
let opt_stage = match cgcx.lto {
724+
Lto::Fat => llvm::OptStage::PreLinkFatLTO,
725+
Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
726+
_ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
727+
_ => llvm::OptStage::PreLinkNoLTO,
728+
};
729+
let skip_size_increasing_opts = false;
730+
dbg!("Running Module Optimization after differentiation");
731+
unsafe {
732+
llvm_optimize(
733+
cgcx,
734+
diag_handler.handle(),
735+
module,
736+
config,
737+
opt_level,
738+
opt_stage,
739+
skip_size_increasing_opts,
740+
)?
741+
};
742+
}
743+
dbg!("Done with differentiate()");
744+
745+
Ok(())
746+
}
747+
609748
// Unsafe due to LLVM calls.
610749
pub(crate) unsafe fn optimize(
611750
cgcx: &CodegenContext<LlvmCodegenBackend>,
@@ -628,14 +767,68 @@ pub(crate) unsafe fn optimize(
628767
unsafe { llvm::LLVMWriteBitcodeToFile(llmod, out.as_ptr()) };
629768
}
630769

770+
// This code enables Enzyme to differentiate code containing Rust enums.
771+
// By adding the SanitizeHWAddress attribute we prevent LLVM from Optimizing
772+
// away the enums and allows Enzyme to understand why a value can be of different types in
773+
// different code sections. We remove this attribute after Enzyme is done, to not affect the
774+
// rest of the compilation.
775+
#[cfg(llvm_enzyme)]
776+
unsafe {
777+
let mut f = LLVMGetFirstFunction(llmod);
778+
loop {
779+
if let Some(lf) = f {
780+
f = LLVMGetNextFunction(lf);
781+
let myhwattr = "enzyme_hw";
782+
let myhwv = "";
783+
let prevattr = LLVMRustGetEnumAttributeAtIndex(
784+
lf,
785+
c_uint::MAX,
786+
AttributeKind::SanitizeHWAddress,
787+
);
788+
if LLVMIsEnumAttribute(prevattr) {
789+
let attr = LLVMCreateStringAttribute(
790+
llcx,
791+
myhwattr.as_ptr() as *const c_char,
792+
myhwattr.as_bytes().len() as c_uint,
793+
myhwv.as_ptr() as *const c_char,
794+
myhwv.as_bytes().len() as c_uint,
795+
);
796+
LLVMRustAddFunctionAttributes(lf, c_uint::MAX, &attr, 1);
797+
} else {
798+
LLVMRustAddEnumAttributeAtIndex(
799+
llcx,
800+
lf,
801+
c_uint::MAX,
802+
AttributeKind::SanitizeHWAddress,
803+
);
804+
}
805+
} else {
806+
break;
807+
}
808+
}
809+
}
810+
631811
if let Some(opt_level) = config.opt_level {
632812
let opt_stage = match cgcx.lto {
633813
Lto::Fat => llvm::OptStage::PreLinkFatLTO,
634814
Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
635815
_ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
636816
_ => llvm::OptStage::PreLinkNoLTO,
637817
};
638-
return unsafe { llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage) };
818+
819+
// If we know that we will later run AD, then we disable vectorization and loop unrolling
820+
let skip_size_increasing_opts = cfg!(llvm_enzyme);
821+
return unsafe {
822+
llvm_optimize(
823+
cgcx,
824+
dcx,
825+
module,
826+
config,
827+
opt_level,
828+
opt_stage,
829+
skip_size_increasing_opts,
830+
)
831+
};
639832
}
640833
Ok(())
641834
}

0 commit comments

Comments
 (0)