Skip to content

Commit 1221cff

Browse files
committed
move second opt run to lto phase and cleanup code
1 parent 21d0961 commit 1221cff

File tree

7 files changed

+75
-54
lines changed

7 files changed

+75
-54
lines changed

Diff for: compiler/rustc_codegen_llvm/src/back/lto.rs

+23-2
Original file line numberDiff line numberDiff line change
@@ -606,10 +606,31 @@ pub(crate) fn run_pass_manager(
606606

607607
// If this rustc version was built with enzyme/autodiff enabled, and if users applied the
608608
// `#[autodiff]` macro at least once, then we will later call llvm_optimize a second time.
609-
let first_run = true;
610609
debug!("running llvm pm opt pipeline");
611610
unsafe {
612-
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run)?;
611+
write::llvm_optimize(
612+
cgcx,
613+
dcx,
614+
module,
615+
config,
616+
opt_level,
617+
opt_stage,
618+
write::AutodiffStage::DuringAD,
619+
)?;
620+
}
621+
// FIXME(ZuseZ4): Make this more granular
622+
if cfg!(llvm_enzyme) && !thin {
623+
unsafe {
624+
write::llvm_optimize(
625+
cgcx,
626+
dcx,
627+
module,
628+
config,
629+
opt_level,
630+
llvm::OptStage::FatLTO,
631+
write::AutodiffStage::PostAD,
632+
)?;
633+
}
613634
}
614635
debug!("lto done");
615636
Ok(())

Diff for: compiler/rustc_codegen_llvm/src/back/write.rs

+25-16
Original file line numberDiff line numberDiff line change
@@ -530,14 +530,24 @@ fn get_instr_profile_output_path(config: &ModuleConfig) -> Option<CString> {
530530
config.instrument_coverage.then(|| c"default_%m_%p.profraw".to_owned())
531531
}
532532

533+
// PreAD will run llvm opts but disable size increasing opts (vectorization, loop unrolling)
534+
// DuringAD is the same as above, but also runs the enzyme opt and autodiff passes.
535+
// PostAD will run all opts, including size increasing opts.
536+
#[derive(Debug, Eq, PartialEq)]
537+
pub(crate) enum AutodiffStage {
538+
PreAD,
539+
DuringAD,
540+
PostAD,
541+
}
542+
533543
pub(crate) unsafe fn llvm_optimize(
534544
cgcx: &CodegenContext<LlvmCodegenBackend>,
535545
dcx: DiagCtxtHandle<'_>,
536546
module: &ModuleCodegen<ModuleLlvm>,
537547
config: &ModuleConfig,
538548
opt_level: config::OptLevel,
539549
opt_stage: llvm::OptStage,
540-
skip_size_increasing_opts: bool,
550+
autodiff_stage: AutodiffStage,
541551
) -> Result<(), FatalError> {
542552
// Enzyme:
543553
// The whole point of compiler based AD is to differentiate optimized IR instead of unoptimized
@@ -550,13 +560,16 @@ pub(crate) unsafe fn llvm_optimize(
550560
let unroll_loops;
551561
let vectorize_slp;
552562
let vectorize_loop;
563+
let run_enzyme = cfg!(llvm_enzyme) && autodiff_stage == AutodiffStage::DuringAD;
553564

554-
let run_enzyme = cfg!(llvm_enzyme);
555565
// When we build rustc with enzyme/autodiff support, we want to postpone size-increasing
556-
// optimizations until after differentiation. FIXME(ZuseZ4): Before shipping on nightly,
566+
// optimizations until after differentiation. Our pipeline is thus: (opt + enzyme), (full opt).
567+
// We therefore have two calls to llvm_optimize, if autodiff is used.
568+
//
569+
// FIXME(ZuseZ4): Before shipping on nightly,
557570
// we should make this more granular, or at least check that the user has at least one autodiff
558571
// call in their code, to justify altering the compilation pipeline.
559-
if skip_size_increasing_opts && run_enzyme {
572+
if cfg!(llvm_enzyme) && autodiff_stage != AutodiffStage::PostAD {
560573
unroll_loops = false;
561574
vectorize_slp = false;
562575
vectorize_loop = false;
@@ -566,7 +579,7 @@ pub(crate) unsafe fn llvm_optimize(
566579
vectorize_slp = config.vectorize_slp;
567580
vectorize_loop = config.vectorize_loop;
568581
}
569-
trace!(?unroll_loops, ?vectorize_slp, ?vectorize_loop);
582+
trace!(?unroll_loops, ?vectorize_slp, ?vectorize_loop, ?run_enzyme);
570583
let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed();
571584
let pgo_gen_path = get_pgo_gen_path(config);
572585
let pgo_use_path = get_pgo_use_path(config);
@@ -686,18 +699,14 @@ pub(crate) unsafe fn optimize(
686699
_ => llvm::OptStage::PreLinkNoLTO,
687700
};
688701

689-
// If we know that we will later run AD, then we disable vectorization and loop unrolling
690-
let skip_size_increasing_opts = cfg!(llvm_enzyme);
702+
// If we know that we will later run AD, then we disable vectorization and loop unrolling.
703+
// Otherwise we pretend AD is already done and run the normal opt pipeline (=PostAD).
704+
// FIXME(ZuseZ4): Make this more granular, only set PreAD if we actually have autodiff
705+
// usages, not just if we build rustc with autodiff support.
706+
let autodiff_stage =
707+
if cfg!(llvm_enzyme) { AutodiffStage::PreAD } else { AutodiffStage::PostAD };
691708
return unsafe {
692-
llvm_optimize(
693-
cgcx,
694-
dcx,
695-
module,
696-
config,
697-
opt_level,
698-
opt_stage,
699-
skip_size_increasing_opts,
700-
)
709+
llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, autodiff_stage)
701710
};
702711
}
703712
Ok(())

Diff for: compiler/rustc_codegen_llvm/src/builder/autodiff.rs

+7-28
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivit
44
use rustc_codegen_ssa::ModuleCodegen;
55
use rustc_codegen_ssa::back::write::ModuleConfig;
66
use rustc_errors::FatalError;
7-
use rustc_session::config::Lto;
87
use tracing::{debug, trace};
98

10-
use crate::back::write::{llvm_err, llvm_optimize};
9+
use crate::back::write::llvm_err;
1110
use crate::builder::SBuilder;
1211
use crate::context::SimpleCx;
1312
use crate::declare::declare_simple_fn;
@@ -153,7 +152,7 @@ fn generate_enzyme_call<'ll>(
153152
_ => {}
154153
}
155154

156-
trace!("matching autodiff arguments");
155+
debug!("matching autodiff arguments");
157156
// We now handle the issue that Rust level arguments do not always match the llvm-ir level
158157
// arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
159158
// llvm-ir level. The number of activities matches the number of Rust level arguments, so we
@@ -222,7 +221,10 @@ fn generate_enzyme_call<'ll>(
222221
// A duplicated pointer will have the following two outer_fn arguments:
223222
// (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
224223
// (..., metadata! enzyme_dup, ptr, ptr, ...).
225-
if matches!(diff_activity, DiffActivity::Duplicated | DiffActivity::DuplicatedOnly) {
224+
if matches!(
225+
diff_activity,
226+
DiffActivity::Duplicated | DiffActivity::DuplicatedOnly
227+
) {
226228
assert!(
227229
llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer
228230
);
@@ -282,7 +284,7 @@ pub(crate) fn differentiate<'ll>(
282284
module: &'ll ModuleCodegen<ModuleLlvm>,
283285
cgcx: &CodegenContext<LlvmCodegenBackend>,
284286
diff_items: Vec<AutoDiffItem>,
285-
config: &ModuleConfig,
287+
_config: &ModuleConfig,
286288
) -> Result<(), FatalError> {
287289
for item in &diff_items {
288290
trace!("{}", item);
@@ -317,29 +319,6 @@ pub(crate) fn differentiate<'ll>(
317319

318320
// FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
319321

320-
if let Some(opt_level) = config.opt_level {
321-
let opt_stage = match cgcx.lto {
322-
Lto::Fat => llvm::OptStage::PreLinkFatLTO,
323-
Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
324-
_ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
325-
_ => llvm::OptStage::PreLinkNoLTO,
326-
};
327-
// This is our second opt call, so now we run all opts,
328-
// to make sure we get the best performance.
329-
let skip_size_increasing_opts = false;
330-
trace!("running Module Optimization after differentiation");
331-
unsafe {
332-
llvm_optimize(
333-
cgcx,
334-
diag_handler.handle(),
335-
module,
336-
config,
337-
opt_level,
338-
opt_stage,
339-
skip_size_increasing_opts,
340-
)?
341-
};
342-
}
343322
trace!("done with differentiate()");
344323

345324
Ok(())

Diff for: compiler/rustc_llvm/build.rs

+4
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,10 @@ fn main() {
193193
cfg.define(&flag, None);
194194
}
195195

196+
if tracked_env_var_os("LLVM_ENZYME").is_some() {
197+
cfg.define("ENZYME", None);
198+
}
199+
196200
if tracked_env_var_os("LLVM_RUSTLLVM").is_some() {
197201
cfg.define("LLVM_RUSTLLVM", None);
198202
}

Diff for: compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp

+8-3
Original file line numberDiff line numberDiff line change
@@ -689,16 +689,19 @@ struct LLVMRustSanitizerOptions {
689689
};
690690

691691
// This symbol won't be available or used when Enzyme is not enabled
692-
extern "C" void registerEnzyme(llvm::PassBuilder &PB) __attribute__((weak));
692+
#ifdef ENZYME
693+
extern "C" void registerEnzyme(llvm::PassBuilder &PB);
694+
#endif
693695

694696
extern "C" LLVMRustResult LLVMRustOptimize(
695697
LLVMModuleRef ModuleRef, LLVMTargetMachineRef TMRef,
696698
LLVMRustPassBuilderOptLevel OptLevelRust, LLVMRustOptStage OptStage,
697699
bool IsLinkerPluginLTO, bool NoPrepopulatePasses, bool VerifyIR,
698700
bool LintIR, bool UseThinLTOBuffers, bool MergeFunctions, bool UnrollLoops,
699701
bool SLPVectorize, bool LoopVectorize, bool DisableSimplifyLibCalls,
700-
bool EmitLifetimeMarkers, bool RunEnzyme, LLVMRustSanitizerOptions *SanitizerOptions,
701-
const char *PGOGenPath, const char *PGOUsePath, bool InstrumentCoverage,
702+
bool EmitLifetimeMarkers, bool RunEnzyme,
703+
LLVMRustSanitizerOptions *SanitizerOptions, const char *PGOGenPath,
704+
const char *PGOUsePath, bool InstrumentCoverage,
702705
const char *InstrProfileOutput, const char *PGOSampleUsePath,
703706
bool DebugInfoForProfiling, void *LlvmSelfProfiler,
704707
LLVMRustSelfProfileBeforePassCallback BeforePassCallback,
@@ -1014,6 +1017,7 @@ extern "C" LLVMRustResult LLVMRustOptimize(
10141017
}
10151018

10161019
// now load "-enzyme" pass:
1020+
#ifdef ENZYME
10171021
if (RunEnzyme) {
10181022
registerEnzyme(PB);
10191023
if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) {
@@ -1022,6 +1026,7 @@ extern "C" LLVMRustResult LLVMRustOptimize(
10221026
return LLVMRustResult::Failure;
10231027
}
10241028
}
1029+
#endif
10251030

10261031
// Upgrade all calls to old intrinsics first.
10271032
for (Module::iterator I = TheModule->begin(), E = TheModule->end(); I != E;)

Diff for: src/bootstrap/src/core/build_steps/compile.rs

+6-3
Original file line numberDiff line numberDiff line change
@@ -1049,9 +1049,9 @@ pub fn rustc_cargo(
10491049
// <https://rust-lang.zulipchat.com/#narrow/stream/131828-t-compiler/topic/Internal.20lint.20for.20raw.20.60print!.60.20and.20.60println!.60.3F>.
10501050
cargo.rustflag("-Zon-broken-pipe=kill");
10511051

1052-
// We temporarily disable linking here as part of some refactoring.
1053-
// This way, people can manually use -Z llvm-plugins and -C passes=enzyme for now.
1054-
// In a follow-up PR, we will re-enable linking here and load the pass for them.
1052+
// We want to link against registerEnzyme and in the future we want to use additional
1053+
// functionality from Enzyme core. For that we need to link against Enzyme.
1054+
// FIXME(ZuseZ4): Get the LLVM version number automatically instead of hardcoding it.
10551055
if builder.config.llvm_enzyme {
10561056
cargo.rustflag("-l").rustflag("Enzyme-19");
10571057
}
@@ -1234,6 +1234,9 @@ fn rustc_llvm_env(builder: &Builder<'_>, cargo: &mut Cargo, target: TargetSelect
12341234
if builder.is_rust_llvm(target) {
12351235
cargo.env("LLVM_RUSTLLVM", "1");
12361236
}
1237+
if builder.config.llvm_enzyme {
1238+
cargo.env("LLVM_ENZYME", "1");
1239+
}
12371240
let llvm::LlvmResult { llvm_config, .. } = builder.ensure(llvm::Llvm { target });
12381241
cargo.env("LLVM_CONFIG", &llvm_config);
12391242

Diff for: tests/codegen/autodiff.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ fn square(x: &f64) -> f64 {
1515
// CHECK-NEXT:invertstart:
1616
// CHECK-NEXT: %_0 = fmul double %x.0.val, %x.0.val
1717
// CHECK-NEXT: %0 = fadd fast double %x.0.val, %x.0.val
18-
// CHECK-NEXT: %1 = load double, ptr %"x'", align 8, !alias.scope !17816, !noalias !17819
18+
// CHECK-NEXT: %1 = load double, ptr %"x'", align 8
1919
// CHECK-NEXT: %2 = fadd fast double %1, %0
20-
// CHECK-NEXT: store double %2, ptr %"x'", align 8, !alias.scope !17816, !noalias !17819
20+
// CHECK-NEXT: store double %2, ptr %"x'", align 8
2121
// CHECK-NEXT: ret double %_0
2222
// CHECK-NEXT:}
2323

0 commit comments

Comments
 (0)