Skip to content

Commit e2d250c

Browse files
committed
update autodiff flags
1 parent 161a4bf commit e2d250c

File tree

11 files changed

+203
-75
lines changed

11 files changed

+203
-75
lines changed

Diff for: compiler/rustc_codegen_llvm/messages.ftl

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
codegen_llvm_autodiff_without_enable = using the autodiff feature requires -Z autodiff=Enable
12
codegen_llvm_autodiff_without_lto = using the autodiff feature requires using fat-lto
23
34
codegen_llvm_copy_bitcode = failed to copy bitcode to object file: {$err}

Diff for: compiler/rustc_codegen_llvm/src/back/lto.rs

+62-23
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,42 @@ fn thin_lto(
586586
}
587587
}
588588

589+
fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<ModuleLlvm>) {
590+
for &val in ad {
591+
match val {
592+
config::AutoDiff::PrintModBefore => {
593+
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
594+
}
595+
config::AutoDiff::PrintPerf => {
596+
llvm::set_print_perf(true);
597+
}
598+
config::AutoDiff::PrintAA => {
599+
llvm::set_print_activity(true);
600+
}
601+
config::AutoDiff::PrintTA => {
602+
llvm::set_print_type(true);
603+
}
604+
config::AutoDiff::Inline => {
605+
llvm::set_inline(true);
606+
}
607+
config::AutoDiff::LooseTypes => {
608+
llvm::set_loose_types(false);
609+
}
610+
config::AutoDiff::PrintSteps => {
611+
llvm::set_print(true);
612+
}
613+
// We handle this below
614+
config::AutoDiff::PrintModAfter => {}
615+
// This is required and already checked
616+
config::AutoDiff::Enable => {}
617+
}
618+
}
619+
// This helps with handling enums for now.
620+
llvm::set_strict_aliasing(false);
621+
// FIXME(ZuseZ4): Test this, since it was added a long time ago.
622+
llvm::set_rust_rules(true);
623+
}
624+
589625
pub(crate) fn run_pass_manager(
590626
cgcx: &CodegenContext<LlvmCodegenBackend>,
591627
dcx: DiagCtxtHandle<'_>,
@@ -604,34 +640,37 @@ pub(crate) fn run_pass_manager(
604640
let opt_stage = if thin { llvm::OptStage::ThinLTO } else { llvm::OptStage::FatLTO };
605641
let opt_level = config.opt_level.unwrap_or(config::OptLevel::No);
606642

607-
// If this rustc version was built with enzyme/autodiff enabled, and if users applied the
608-
// `#[autodiff]` macro at least once, then we will later call llvm_optimize a second time.
609-
debug!("running llvm pm opt pipeline");
643+
// The PostAD behavior is the same as we would have if no autodiff was used.
644+
// It will run the default optimization pipeline. If AD is enabled we select
645+
// the DuringAD stage, which will disable vectorization and loop unrolling, and
646+
// schedule two autodiff optimization + differentiation passes.
647+
// We then run the llvm_optimize function a second time, to optimize the code which we generated
648+
// in the enzyme differentiation pass.
649+
let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable);
650+
let stage =
651+
if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD };
652+
653+
if enable_ad {
654+
enable_autodiff_settings(&config.autodiff, module);
655+
}
656+
610657
unsafe {
611-
write::llvm_optimize(
612-
cgcx,
613-
dcx,
614-
module,
615-
config,
616-
opt_level,
617-
opt_stage,
618-
write::AutodiffStage::DuringAD,
619-
)?;
658+
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, stage)?;
620659
}
621-
// FIXME(ZuseZ4): Make this more granular
622-
if cfg!(llvm_enzyme) && !thin {
660+
661+
if cfg!(llvm_enzyme) && enable_ad {
662+
let opt_stage = llvm::OptStage::FatLTO;
663+
let stage = write::AutodiffStage::PostAD;
623664
unsafe {
624-
write::llvm_optimize(
625-
cgcx,
626-
dcx,
627-
module,
628-
config,
629-
opt_level,
630-
llvm::OptStage::FatLTO,
631-
write::AutodiffStage::PostAD,
632-
)?;
665+
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, stage)?;
666+
}
667+
668+
// This is the final IR, so people should be able to inspect the optimized autodiff output.
669+
if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
670+
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
633671
}
634672
}
673+
635674
debug!("lto done");
636675
Ok(())
637676
}

Diff for: compiler/rustc_codegen_llvm/src/builder/autodiff.rs

+9-4
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::back::write::llvm_err;
1010
use crate::builder::SBuilder;
1111
use crate::context::SimpleCx;
1212
use crate::declare::declare_simple_fn;
13-
use crate::errors::LlvmError;
13+
use crate::errors::{AutoDiffWithoutEnable, LlvmError};
1414
use crate::llvm::AttributePlace::Function;
1515
use crate::llvm::{Metadata, True};
1616
use crate::value::Value;
@@ -46,9 +46,6 @@ fn generate_enzyme_call<'ll>(
4646
let output = attrs.ret_activity;
4747

4848
// We have to pick the name depending on whether we want forward or reverse mode autodiff.
49-
// FIXME(ZuseZ4): The new pass based approach should not need the {Forward/Reverse}First method anymore, since
50-
// it will handle higher-order derivatives correctly automatically (in theory). Currently
51-
// higher-order derivatives fail, so we should debug that before adjusting this code.
5249
let mut ad_name: String = match attrs.mode {
5350
DiffMode::Forward => "__enzyme_fwddiff",
5451
DiffMode::Reverse => "__enzyme_autodiff",
@@ -291,6 +288,14 @@ pub(crate) fn differentiate<'ll>(
291288
let diag_handler = cgcx.create_dcx();
292289
let cx = SimpleCx { llmod: module.module_llvm.llmod(), llcx: module.module_llvm.llcx };
293290

291+
// First of all, did the user try to use autodiff without using the -Zautodiff=Enable flag?
292+
if !diff_items.is_empty()
293+
&& !cgcx.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable)
294+
{
295+
let dcx = cgcx.create_dcx();
296+
return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutEnable));
297+
}
298+
294299
// Before dumping the module, we want all the TypeTrees to become part of the module.
295300
for item in diff_items.iter() {
296301
let name = item.source.clone();

Diff for: compiler/rustc_codegen_llvm/src/errors.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,12 @@ impl<G: EmissionGuarantee> Diagnostic<'_, G> for ParseTargetMachineConfig<'_> {
9292

9393
#[derive(Diagnostic)]
9494
#[diag(codegen_llvm_autodiff_without_lto)]
95-
#[note]
9695
pub(crate) struct AutoDiffWithoutLTO;
9796

97+
#[derive(Diagnostic)]
98+
#[diag(codegen_llvm_autodiff_without_enable)]
99+
pub(crate) struct AutoDiffWithoutEnable;
100+
98101
#[derive(Diagnostic)]
99102
#[diag(codegen_llvm_lto_disallowed)]
100103
pub(crate) struct LtoDisallowed;

Diff for: compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

+94
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,97 @@ pub enum LLVMRustVerifierFailureAction {
3535
LLVMPrintMessageAction = 1,
3636
LLVMReturnStatusAction = 2,
3737
}
38+
39+
#[cfg(llvm_enzyme)]
40+
pub use self::Enzyme_AD::*;
41+
42+
#[cfg(llvm_enzyme)]
43+
pub mod Enzyme_AD {
44+
use libc::c_void;
45+
extern "C" {
46+
pub fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
47+
}
48+
extern "C" {
49+
static mut EnzymePrintPerf: c_void;
50+
static mut EnzymePrintActivity: c_void;
51+
static mut EnzymePrintType: c_void;
52+
static mut EnzymePrint: c_void;
53+
static mut EnzymeStrictAliasing: c_void;
54+
static mut looseTypeAnalysis: c_void;
55+
static mut EnzymeInline: c_void;
56+
static mut RustTypeRules: c_void;
57+
}
58+
pub fn set_print_perf(print: bool) {
59+
unsafe {
60+
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintPerf), print as u8);
61+
}
62+
}
63+
pub fn set_print_activity(print: bool) {
64+
unsafe {
65+
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintActivity), print as u8);
66+
}
67+
}
68+
pub fn set_print_type(print: bool) {
69+
unsafe {
70+
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8);
71+
}
72+
}
73+
pub fn set_print(print: bool) {
74+
unsafe {
75+
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8);
76+
}
77+
}
78+
pub fn set_strict_aliasing(strict: bool) {
79+
unsafe {
80+
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeStrictAliasing), strict as u8);
81+
}
82+
}
83+
pub fn set_loose_types(loose: bool) {
84+
unsafe {
85+
EnzymeSetCLBool(std::ptr::addr_of_mut!(looseTypeAnalysis), loose as u8);
86+
}
87+
}
88+
pub fn set_inline(val: bool) {
89+
unsafe {
90+
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeInline), val as u8);
91+
}
92+
}
93+
pub fn set_rust_rules(val: bool) {
94+
unsafe {
95+
EnzymeSetCLBool(std::ptr::addr_of_mut!(RustTypeRules), val as u8);
96+
}
97+
}
98+
}
99+
100+
#[cfg(not(llvm_enzyme))]
101+
pub use self::Fallback_AD::*;
102+
103+
#[cfg(not(llvm_enzyme))]
104+
pub mod Fallback_AD {
105+
#![allow(unused_variables)]
106+
107+
pub fn set_inline(val: bool) {
108+
unimplemented!()
109+
}
110+
pub fn set_print_perf(print: bool) {
111+
unimplemented!()
112+
}
113+
pub fn set_print_activity(print: bool) {
114+
unimplemented!()
115+
}
116+
pub fn set_print_type(print: bool) {
117+
unimplemented!()
118+
}
119+
pub fn set_print(print: bool) {
120+
unimplemented!()
121+
}
122+
pub fn set_strict_aliasing(strict: bool) {
123+
unimplemented!()
124+
}
125+
pub fn set_loose_types(loose: bool) {
126+
unimplemented!()
127+
}
128+
pub fn set_rust_rules(val: bool) {
129+
unimplemented!()
130+
}
131+
}

Diff for: compiler/rustc_codegen_ssa/src/back/write.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,8 @@ fn generate_lto_work<B: ExtraBackendMethods>(
405405
B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise());
406406
if cgcx.lto == Lto::Fat && !autodiff.is_empty() {
407407
let config = cgcx.config(ModuleKind::Regular);
408-
module = unsafe { module.autodiff(cgcx, autodiff, config).unwrap() };
408+
module =
409+
unsafe { module.autodiff(cgcx, autodiff, config).unwrap_or_else(|e| e.raise()) };
409410
}
410411
// We are adding a single work item, so the cost doesn't matter.
411412
vec![(WorkItem::LTO(module), 0)]

Diff for: compiler/rustc_interface/src/tests.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -759,7 +759,7 @@ fn test_unstable_options_tracking_hash() {
759759
tracked!(allow_features, Some(vec![String::from("lang_items")]));
760760
tracked!(always_encode_mir, true);
761761
tracked!(assume_incomplete_release, true);
762-
tracked!(autodiff, vec![AutoDiff::Print]);
762+
tracked!(autodiff, vec![AutoDiff::Enable]);
763763
tracked!(binary_dep_depinfo, true);
764764
tracked!(box_noalias, false);
765765
tracked!(

Diff for: compiler/rustc_session/src/config.rs

+9-16
Original file line numberDiff line numberDiff line change
@@ -198,33 +198,26 @@ pub enum CoverageLevel {
198198
/// The different settings that the `-Z autodiff` flag can have.
199199
#[derive(Clone, Copy, PartialEq, Hash, Debug)]
200200
pub enum AutoDiff {
201+
/// Enable the autodiff opt pipeline
202+
Enable,
203+
201204
/// Print TypeAnalysis information
202205
PrintTA,
203206
/// Print ActivityAnalysis Information
204207
PrintAA,
205208
/// Print Performance Warnings from Enzyme
206209
PrintPerf,
207-
/// Combines the three print flags above.
208-
Print,
210+
/// Print intermediate IR generation steps
211+
PrintSteps,
209212
/// Print the whole module, before running opts.
210213
PrintModBefore,
211-
/// Print the whole module just before we pass it to Enzyme.
212-
/// For Debug purpose, prefer the OPT flag below
213-
PrintModAfterOpts,
214214
/// Print the module after Enzyme differentiated everything.
215-
PrintModAfterEnzyme,
215+
PrintModAfter,
216216

217-
/// Enzyme's loose type debug helper (can cause incorrect gradients)
217+
/// Enzyme's loose type debug helper (can cause incorrect gradients!!)
218+
/// Usable in cases where Enzyme errors with `can not deduce type of X`.
218219
LooseTypes,
219-
220-
/// More flags
221-
NoModOptAfter,
222-
/// Tell Enzyme to run LLVM Opts on each function it generated. By default off,
223-
/// since we already optimize the whole module after Enzyme is done.
224-
EnableFncOpt,
225-
NoVecUnroll,
226-
RuntimeActivity,
227-
/// Runs Enzyme specific Inlining
220+
/// Runs Enzyme's aggressive inlining
228221
Inline,
229222
}
230223

Diff for: compiler/rustc_session/src/options.rs

+17-22
Original file line numberDiff line numberDiff line change
@@ -707,7 +707,7 @@ mod desc {
707707
pub(crate) const parse_list: &str = "a space-separated list of strings";
708708
pub(crate) const parse_list_with_polarity: &str =
709709
"a comma-separated list of strings, with elements beginning with + or -";
710-
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Print`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfterOpts`, `PrintModAfterEnzyme`, `LooseTypes`, `NoModOptAfter`, `EnableFncOpt`, `NoVecUnroll`, `Inline`";
710+
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `LooseTypes`, `Inline`";
711711
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
712712
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
713713
pub(crate) const parse_number: &str = "a number";
@@ -1348,17 +1348,14 @@ pub mod parse {
13481348
v.sort_unstable();
13491349
for &val in v.iter() {
13501350
let variant = match val {
1351+
"Enable" => AutoDiff::Enable,
13511352
"PrintTA" => AutoDiff::PrintTA,
13521353
"PrintAA" => AutoDiff::PrintAA,
13531354
"PrintPerf" => AutoDiff::PrintPerf,
1354-
"Print" => AutoDiff::Print,
1355+
"PrintSteps" => AutoDiff::PrintSteps,
13551356
"PrintModBefore" => AutoDiff::PrintModBefore,
1356-
"PrintModAfterOpts" => AutoDiff::PrintModAfterOpts,
1357-
"PrintModAfterEnzyme" => AutoDiff::PrintModAfterEnzyme,
1357+
"PrintModAfter" => AutoDiff::PrintModAfter,
13581358
"LooseTypes" => AutoDiff::LooseTypes,
1359-
"NoModOptAfter" => AutoDiff::NoModOptAfter,
1360-
"EnableFncOpt" => AutoDiff::EnableFncOpt,
1361-
"NoVecUnroll" => AutoDiff::NoVecUnroll,
13621359
"Inline" => AutoDiff::Inline,
13631360
_ => {
13641361
// FIXME(ZuseZ4): print an error saying which value is not recognized
@@ -2081,21 +2078,19 @@ options! {
20812078
assume_incomplete_release: bool = (false, parse_bool, [TRACKED],
20822079
"make cfg(version) treat the current version as incomplete (default: no)"),
20832080
autodiff: Vec<crate::config::AutoDiff> = (Vec::new(), parse_autodiff, [TRACKED],
2084-
"a list of optional autodiff flags to enable
2085-
Optional extra settings:
2086-
`=PrintTA`
2087-
`=PrintAA`
2088-
`=PrintPerf`
2089-
`=Print`
2090-
`=PrintModBefore`
2091-
`=PrintModAfterOpts`
2092-
`=PrintModAfterEnzyme`
2093-
`=LooseTypes`
2094-
`=NoModOptAfter`
2095-
`=EnableFncOpt`
2096-
`=NoVecUnroll`
2097-
`=Inline`
2098-
Multiple options can be combined with commas."),
2081+
"a list of autodiff flags to enable
2082+
Mandatory setting:
2083+
`=Enable`
2084+
Optional extra settings:
2085+
`=PrintTA`
2086+
`=PrintAA`
2087+
`=PrintPerf`
2088+
`=PrintSteps`
2089+
`=PrintModBefore`
2090+
`=PrintModAfter`
2091+
`=LooseTypes`
2092+
`=Inline`
2093+
Multiple options can be combined with commas."),
20992094
#[rustc_lint_opt_deny_field_access("use `Session::binary_dep_depinfo` instead of this field")]
21002095
binary_dep_depinfo: bool = (false, parse_bool, [TRACKED],
21012096
"include artifacts (sysroot, crate dependencies) used during compilation in dep-info \

0 commit comments

Comments
 (0)