Skip to content

Commit 2382904

Browse files
committed
fix LooseTypes flag and PrintMod behaviour, add debug helper
1 parent a7c39b6 commit 2382904

File tree

6 files changed

+68
-21
lines changed

6 files changed

+68
-21
lines changed

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -584,12 +584,10 @@ fn thin_lto(
584584
}
585585
}
586586

587-
fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<ModuleLlvm>) {
587+
fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
588588
for &val in ad {
589+
// We intentionally don't use a wildcard, to not forget handling anything new.
589590
match val {
590-
config::AutoDiff::PrintModBefore => {
591-
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
592-
}
593591
config::AutoDiff::PrintPerf => {
594592
llvm::set_print_perf(true);
595593
}
@@ -603,17 +601,23 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<
603601
llvm::set_inline(true);
604602
}
605603
config::AutoDiff::LooseTypes => {
606-
llvm::set_loose_types(false);
604+
llvm::set_loose_types(true);
607605
}
608606
config::AutoDiff::PrintSteps => {
609607
llvm::set_print(true);
610608
}
611-
// We handle this below
609+
// We handle this in the PassWrapper.cpp
610+
config::AutoDiff::PrintPasses => {}
611+
// We handle this in the PassWrapper.cpp
612+
config::AutoDiff::PrintModBefore => {}
613+
// We handle this in the PassWrapper.cpp
612614
config::AutoDiff::PrintModAfter => {}
613-
// We handle this below
615+
// We handle this in the PassWrapper.cpp
614616
config::AutoDiff::PrintModFinal => {}
615617
// This is required and already checked
616618
config::AutoDiff::Enable => {}
619+
// We handle this below
620+
config::AutoDiff::NoPostopt => {}
617621
}
618622
}
619623
// This helps with handling enums for now.
@@ -647,27 +651,27 @@ pub(crate) fn run_pass_manager(
647651
// We then run the llvm_optimize function a second time, to optimize the code which we generated
648652
// in the enzyme differentiation pass.
649653
let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable);
650-
let stage =
651-
if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD };
654+
let stage = if thin {
655+
write::AutodiffStage::PreAD
656+
} else {
657+
if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD }
658+
};
652659

653660
if enable_ad {
654-
enable_autodiff_settings(&config.autodiff, module);
661+
enable_autodiff_settings(&config.autodiff);
655662
}
656663

657664
unsafe {
658665
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
659666
}
660667

661-
if cfg!(llvm_enzyme) && enable_ad {
662-
// This is the post-autodiff IR, mainly used for testing and educational purposes.
663-
if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
664-
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
665-
}
666-
668+
if cfg!(llvm_enzyme) && enable_ad && !thin {
667669
let opt_stage = llvm::OptStage::FatLTO;
668670
let stage = write::AutodiffStage::PostAD;
669-
unsafe {
670-
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
671+
if !config.autodiff.contains(&config::AutoDiff::NoPostopt) {
672+
unsafe {
673+
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
674+
}
671675
}
672676

673677
// This is the final IR, so people should be able to inspect the optimized autodiff output,

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,9 @@ pub(crate) unsafe fn llvm_optimize(
572572

573573
let consider_ad = cfg!(llvm_enzyme) && config.autodiff.contains(&config::AutoDiff::Enable);
574574
let run_enzyme = autodiff_stage == AutodiffStage::DuringAD;
575+
let print_before_enzyme = config.autodiff.contains(&config::AutoDiff::PrintModBefore);
576+
let print_after_enzyme = config.autodiff.contains(&config::AutoDiff::PrintModAfter);
577+
let print_passes = config.autodiff.contains(&config::AutoDiff::PrintPasses);
575578
let unroll_loops;
576579
let vectorize_slp;
577580
let vectorize_loop;
@@ -670,6 +673,9 @@ pub(crate) unsafe fn llvm_optimize(
670673
config.no_builtins,
671674
config.emit_lifetime_markers,
672675
run_enzyme,
676+
print_before_enzyme,
677+
print_after_enzyme,
678+
print_passes,
673679
sanitizer_options.as_ref(),
674680
pgo_gen_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()),
675681
pgo_use_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()),

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2454,6 +2454,9 @@ unsafe extern "C" {
24542454
DisableSimplifyLibCalls: bool,
24552455
EmitLifetimeMarkers: bool,
24562456
RunEnzyme: bool,
2457+
PrintBeforeEnzyme: bool,
2458+
PrintAfterEnzyme: bool,
2459+
PrintPasses: bool,
24572460
SanitizerOptions: Option<&SanitizerOptions>,
24582461
PGOGenPath: *const c_char,
24592462
PGOUsePath: *const c_char,

compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "llvm/IR/LegacyPassManager.h"
1515
#include "llvm/IR/PassManager.h"
1616
#include "llvm/IR/Verifier.h"
17+
#include "llvm/IRPrinter/IRPrintingPasses.h"
1718
#include "llvm/LTO/LTO.h"
1819
#include "llvm/MC/MCSubtargetInfo.h"
1920
#include "llvm/MC/TargetRegistry.h"
@@ -703,7 +704,8 @@ extern "C" LLVMRustResult LLVMRustOptimize(
703704
bool LintIR, LLVMRustThinLTOBuffer **ThinLTOBufferRef, bool EmitThinLTO,
704705
bool EmitThinLTOSummary, bool MergeFunctions, bool UnrollLoops,
705706
bool SLPVectorize, bool LoopVectorize, bool DisableSimplifyLibCalls,
706-
bool EmitLifetimeMarkers, bool RunEnzyme,
707+
bool EmitLifetimeMarkers, bool RunEnzyme, bool PrintBeforeEnzyme,
708+
bool PrintAfterEnzyme, bool PrintPasses,
707709
LLVMRustSanitizerOptions *SanitizerOptions, const char *PGOGenPath,
708710
const char *PGOUsePath, bool InstrumentCoverage,
709711
const char *InstrProfileOutput, const char *PGOSampleUsePath,
@@ -1048,14 +1050,38 @@ extern "C" LLVMRustResult LLVMRustOptimize(
10481050
// now load "-enzyme" pass:
10491051
#ifdef ENZYME
10501052
if (RunEnzyme) {
1051-
registerEnzymeAndPassPipeline(PB, true);
1053+
1054+
if (PrintBeforeEnzyme) {
1055+
// Handle the Rust flag `-Zautodiff=PrintModBefore`.
1056+
std::string Banner = "Module before EnzymeNewPM";
1057+
MPM.addPass(PrintModulePass(outs(), Banner, true, false));
1058+
}
1059+
1060+
registerEnzymeAndPassPipeline(PB, false);
10521061
if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) {
10531062
std::string ErrMsg = toString(std::move(Err));
10541063
LLVMRustSetLastError(ErrMsg.c_str());
10551064
return LLVMRustResult::Failure;
10561065
}
1066+
1067+
if (PrintAfterEnzyme) {
1068+
// Handle the Rust flag `-Zautodiff=PrintModAfter`.
1069+
std::string Banner = "Module after EnzymeNewPM";
1070+
MPM.addPass(PrintModulePass(outs(), Banner, true, false));
1071+
}
10571072
}
10581073
#endif
1074+
if (PrintPasses) {
1075+
// Print all passes from the PM:
1076+
std::string Pipeline;
1077+
raw_string_ostream SOS(Pipeline);
1078+
MPM.printPipeline(SOS, [&PIC](StringRef ClassName) {
1079+
auto PassName = PIC.getPassNameForClassName(ClassName);
1080+
return PassName.empty() ? ClassName : PassName;
1081+
});
1082+
outs() << Pipeline;
1083+
outs() << "\n";
1084+
}
10591085

10601086
// Upgrade all calls to old intrinsics first.
10611087
for (Module::iterator I = TheModule->begin(), E = TheModule->end(); I != E;)

compiler/rustc_session/src/config.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,10 @@ pub enum AutoDiff {
246246
/// Print the module after running autodiff and optimizations.
247247
PrintModFinal,
248248

249+
/// Print all passes scheduled by LLVM
250+
PrintPasses,
251+
/// Disable extra opt run after running autodiff
252+
NoPostopt,
249253
/// Enzyme's loose type debug helper (can cause incorrect gradients!!)
250254
/// Usable in cases where Enzyme errors with `can not deduce type of X`.
251255
LooseTypes,

compiler/rustc_session/src/options.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,7 @@ mod desc {
711711
pub(crate) const parse_list: &str = "a space-separated list of strings";
712712
pub(crate) const parse_list_with_polarity: &str =
713713
"a comma-separated list of strings, with elements beginning with + or -";
714-
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `LooseTypes`, `Inline`";
714+
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`";
715715
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
716716
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
717717
pub(crate) const parse_number: &str = "a number";
@@ -1360,6 +1360,8 @@ pub mod parse {
13601360
"PrintModBefore" => AutoDiff::PrintModBefore,
13611361
"PrintModAfter" => AutoDiff::PrintModAfter,
13621362
"PrintModFinal" => AutoDiff::PrintModFinal,
1363+
"NoPostopt" => AutoDiff::NoPostopt,
1364+
"PrintPasses" => AutoDiff::PrintPasses,
13631365
"LooseTypes" => AutoDiff::LooseTypes,
13641366
"Inline" => AutoDiff::Inline,
13651367
_ => {
@@ -2098,6 +2100,8 @@ options! {
20982100
`=PrintModBefore`
20992101
`=PrintModAfter`
21002102
`=PrintModFinal`
2103+
`=PrintPasses`,
2104+
`=NoPostopt`
21012105
`=LooseTypes`
21022106
`=Inline`
21032107
Multiple options can be combined with commas."),

0 commit comments

Comments
 (0)