Skip to content

Autodiff flags #139700

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 24, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 22 additions & 18 deletions compiler/rustc_codegen_llvm/src/back/lto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -584,12 +584,10 @@ fn thin_lto(
}
}

fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<ModuleLlvm>) {
fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
for &val in ad {
// We intentionally don't use a wildcard, to not forget handling anything new.
match val {
config::AutoDiff::PrintModBefore => {
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
}
config::AutoDiff::PrintPerf => {
llvm::set_print_perf(true);
}
Expand All @@ -603,17 +601,23 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<
llvm::set_inline(true);
}
config::AutoDiff::LooseTypes => {
llvm::set_loose_types(false);
llvm::set_loose_types(true);
}
config::AutoDiff::PrintSteps => {
llvm::set_print(true);
}
// We handle this below
// We handle this in the PassWrapper.cpp
config::AutoDiff::PrintPasses => {}
// We handle this in the PassWrapper.cpp
config::AutoDiff::PrintModBefore => {}
// We handle this in the PassWrapper.cpp
config::AutoDiff::PrintModAfter => {}
// We handle this below
// We handle this in the PassWrapper.cpp
config::AutoDiff::PrintModFinal => {}
// This is required and already checked
config::AutoDiff::Enable => {}
// We handle this below
config::AutoDiff::NoPostopt => {}
}
}
// This helps with handling enums for now.
Expand Down Expand Up @@ -647,27 +651,27 @@ pub(crate) fn run_pass_manager(
// We then run the llvm_optimize function a second time, to optimize the code which we generated
// in the enzyme differentiation pass.
let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable);
let stage =
if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD };
let stage = if thin {
write::AutodiffStage::PreAD
} else {
if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD }
};

if enable_ad {
enable_autodiff_settings(&config.autodiff, module);
enable_autodiff_settings(&config.autodiff);
}

unsafe {
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
}

if cfg!(llvm_enzyme) && enable_ad {
// This is the post-autodiff IR, mainly used for testing and educational purposes.
if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
}

if cfg!(llvm_enzyme) && enable_ad && !thin {
let opt_stage = llvm::OptStage::FatLTO;
let stage = write::AutodiffStage::PostAD;
unsafe {
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
if !config.autodiff.contains(&config::AutoDiff::NoPostopt) {
unsafe {
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
}
}

// This is the final IR, so people should be able to inspect the optimized autodiff output,
Expand Down
11 changes: 10 additions & 1 deletion compiler/rustc_codegen_llvm/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,10 @@ pub(crate) unsafe fn llvm_optimize(

let consider_ad = cfg!(llvm_enzyme) && config.autodiff.contains(&config::AutoDiff::Enable);
let run_enzyme = autodiff_stage == AutodiffStage::DuringAD;
let print_before_enzyme = config.autodiff.contains(&config::AutoDiff::PrintModBefore);
let print_after_enzyme = config.autodiff.contains(&config::AutoDiff::PrintModAfter);
let print_passes = config.autodiff.contains(&config::AutoDiff::PrintPasses);
let merge_functions;
let unroll_loops;
let vectorize_slp;
let vectorize_loop;
Expand All @@ -573,12 +577,14 @@ pub(crate) unsafe fn llvm_optimize(
// optimizations until after differentiation. Our pipeline is thus: (opt + enzyme), (full opt).
// We therefore have two calls to llvm_optimize, if autodiff is used.
if consider_ad && autodiff_stage != AutodiffStage::PostAD {
merge_functions = false;
unroll_loops = false;
vectorize_slp = false;
vectorize_loop = false;
} else {
unroll_loops =
opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
merge_functions = config.merge_functions;
vectorize_slp = config.vectorize_slp;
vectorize_loop = config.vectorize_loop;
}
Expand Down Expand Up @@ -656,13 +662,16 @@ pub(crate) unsafe fn llvm_optimize(
thin_lto_buffer,
config.emit_thin_lto,
config.emit_thin_lto_summary,
config.merge_functions,
merge_functions,
unroll_loops,
vectorize_slp,
vectorize_loop,
config.no_builtins,
config.emit_lifetime_markers,
run_enzyme,
print_before_enzyme,
print_after_enzyme,
print_passes,
sanitizer_options.as_ref(),
pgo_gen_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()),
pgo_use_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()),
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_codegen_llvm/src/llvm/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2454,6 +2454,9 @@ unsafe extern "C" {
DisableSimplifyLibCalls: bool,
EmitLifetimeMarkers: bool,
RunEnzyme: bool,
PrintBeforeEnzyme: bool,
PrintAfterEnzyme: bool,
PrintPasses: bool,
SanitizerOptions: Option<&SanitizerOptions>,
PGOGenPath: *const c_char,
PGOUsePath: *const c_char,
Expand Down
30 changes: 28 additions & 2 deletions compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IRPrinter/IRPrintingPasses.h"
#include "llvm/LTO/LTO.h"
#include "llvm/MC/MCSubtargetInfo.h"
#include "llvm/MC/TargetRegistry.h"
Expand Down Expand Up @@ -703,7 +704,8 @@ extern "C" LLVMRustResult LLVMRustOptimize(
bool LintIR, LLVMRustThinLTOBuffer **ThinLTOBufferRef, bool EmitThinLTO,
bool EmitThinLTOSummary, bool MergeFunctions, bool UnrollLoops,
bool SLPVectorize, bool LoopVectorize, bool DisableSimplifyLibCalls,
bool EmitLifetimeMarkers, bool RunEnzyme,
bool EmitLifetimeMarkers, bool RunEnzyme, bool PrintBeforeEnzyme,
bool PrintAfterEnzyme, bool PrintPasses,
LLVMRustSanitizerOptions *SanitizerOptions, const char *PGOGenPath,
const char *PGOUsePath, bool InstrumentCoverage,
const char *InstrProfileOutput, const char *PGOSampleUsePath,
Expand Down Expand Up @@ -1048,14 +1050,38 @@ extern "C" LLVMRustResult LLVMRustOptimize(
// now load "-enzyme" pass:
#ifdef ENZYME
if (RunEnzyme) {
registerEnzymeAndPassPipeline(PB, true);

if (PrintBeforeEnzyme) {
// Handle the Rust flag `-Zautodiff=PrintModBefore`.
std::string Banner = "Module before EnzymeNewPM";
MPM.addPass(PrintModulePass(outs(), Banner, true, false));
}

registerEnzymeAndPassPipeline(PB, false);
if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) {
std::string ErrMsg = toString(std::move(Err));
LLVMRustSetLastError(ErrMsg.c_str());
return LLVMRustResult::Failure;
}

if (PrintAfterEnzyme) {
// Handle the Rust flag `-Zautodiff=PrintModAfter`.
std::string Banner = "Module after EnzymeNewPM";
MPM.addPass(PrintModulePass(outs(), Banner, true, false));
}
}
#endif
if (PrintPasses) {
// Print all passes from the PM:
std::string Pipeline;
raw_string_ostream SOS(Pipeline);
MPM.printPipeline(SOS, [&PIC](StringRef ClassName) {
auto PassName = PIC.getPassNameForClassName(ClassName);
return PassName.empty() ? ClassName : PassName;
});
outs() << Pipeline;
outs() << "\n";
}

// Upgrade all calls to old intrinsics first.
for (Module::iterator I = TheModule->begin(), E = TheModule->end(); I != E;)
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_session/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,10 @@ pub enum AutoDiff {
/// Print the module after running autodiff and optimizations.
PrintModFinal,

/// Print all passes scheduled by LLVM
PrintPasses,
/// Disable extra opt run after running autodiff
NoPostopt,
/// Enzyme's loose type debug helper (can cause incorrect gradients!!)
/// Usable in cases where Enzyme errors with `can not deduce type of X`.
LooseTypes,
Expand Down
6 changes: 5 additions & 1 deletion compiler/rustc_session/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ mod desc {
pub(crate) const parse_list: &str = "a space-separated list of strings";
pub(crate) const parse_list_with_polarity: &str =
"a comma-separated list of strings, with elements beginning with + or -";
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `LooseTypes`, `Inline`";
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`";
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
pub(crate) const parse_number: &str = "a number";
Expand Down Expand Up @@ -1360,6 +1360,8 @@ pub mod parse {
"PrintModBefore" => AutoDiff::PrintModBefore,
"PrintModAfter" => AutoDiff::PrintModAfter,
"PrintModFinal" => AutoDiff::PrintModFinal,
"NoPostopt" => AutoDiff::NoPostopt,
"PrintPasses" => AutoDiff::PrintPasses,
"LooseTypes" => AutoDiff::LooseTypes,
"Inline" => AutoDiff::Inline,
_ => {
Expand Down Expand Up @@ -2095,6 +2097,8 @@ options! {
`=PrintModBefore`
`=PrintModAfter`
`=PrintModFinal`
`=PrintPasses`,
`=NoPostopt`
`=LooseTypes`
`=Inline`
Multiple options can be combined with commas."),
Expand Down