From 1f30517d40a9a8fe3b89479891c7a167adb75cbd Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Wed, 29 Jan 2025 21:31:13 -0500 Subject: [PATCH] upstream rustc_codegen_ssa/rustc_middle changes for enzyme/autodiff --- Cargo.lock | 2 + .../rustc_ast/src/expand/autodiff_attrs.rs | 3 +- .../src/builder/autodiff.rs | 14 +- .../src/coverageinfo/mapgen.rs | 2 +- .../rustc_codegen_llvm/src/llvm/enzyme_ffi.rs | 4 +- compiler/rustc_codegen_ssa/messages.ftl | 2 + compiler/rustc_codegen_ssa/src/back/write.rs | 38 +++++- compiler/rustc_codegen_ssa/src/base.rs | 10 +- .../rustc_codegen_ssa/src/codegen_attrs.rs | 118 ++++++++++++++++- compiler/rustc_codegen_ssa/src/errors.rs | 4 + compiler/rustc_interface/src/tests.rs | 13 +- .../rustc_llvm/llvm-wrapper/RustWrapper.cpp | 5 +- compiler/rustc_middle/messages.ftl | 4 + compiler/rustc_middle/src/arena.rs | 1 + compiler/rustc_middle/src/error.rs | 14 ++ .../src/middle/codegen_fn_attrs.rs | 4 + compiler/rustc_middle/src/mir/mono.rs | 2 + compiler/rustc_monomorphize/Cargo.toml | 2 + compiler/rustc_monomorphize/src/collector.rs | 2 +- .../rustc_monomorphize/src/partitioning.rs | 32 ++++- .../src/partitioning/autodiff.rs | 121 ++++++++++++++++++ compiler/rustc_session/src/config.rs | 36 +++++- compiler/rustc_session/src/options.rs | 49 +++++++ compiler/rustc_span/src/symbol.rs | 2 - src/bootstrap/src/core/build_steps/compile.rs | 9 +- .../src/compiler-flags/autodiff.md | 23 ++++ src/tools/enzyme | 2 +- 27 files changed, 481 insertions(+), 37 deletions(-) create mode 100644 compiler/rustc_monomorphize/src/partitioning/autodiff.rs create mode 100644 src/doc/unstable-book/src/compiler-flags/autodiff.md diff --git a/Cargo.lock b/Cargo.lock index 979198cece80f..7526cca9c4b29 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4234,6 +4234,7 @@ name = "rustc_monomorphize" version = "0.0.0" dependencies = [ "rustc_abi", + "rustc_ast", "rustc_attr_parsing", "rustc_data_structures", "rustc_errors", @@ -4243,6 +4244,7 @@ dependencies = [ "rustc_middle", "rustc_session", "rustc_span", + "rustc_symbol_mangling", "rustc_target", "serde", "serde_json", diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 7ef8bc1797384..ecc522ec39d12 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -79,6 +79,7 @@ pub struct AutoDiffItem { pub target: String, pub attrs: AutoDiffAttrs, } + #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] pub struct AutoDiffAttrs { /// Conceptually either forward or reverse mode AD, as described in various autodiff papers and @@ -231,7 +232,7 @@ impl AutoDiffAttrs { self.ret_activity == DiffActivity::ActiveOnly } - pub fn error() -> Self { + pub const fn error() -> Self { AutoDiffAttrs { mode: DiffMode::Error, ret_activity: DiffActivity::None, diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 6b17b5f6989be..9e8e4e1c56774 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -62,8 +62,8 @@ fn generate_enzyme_call<'ll>( // add outer_fn name to ad_name to make it unique, in case users apply autodiff to multiple // functions. Unwrap will only panic, if LLVM gave us an invalid string. let name = llvm::get_value_name(outer_fn); - let outer_fn_name = std::ffi::CStr::from_bytes_with_nul(name).unwrap().to_str().unwrap(); - ad_name.push_str(outer_fn_name.to_string().as_str()); + let outer_fn_name = std::str::from_utf8(name).unwrap(); + ad_name.push_str(outer_fn_name); // Let us assume the user wrote the following function square: // @@ -255,14 +255,14 @@ fn generate_enzyme_call<'ll>( // have no debug info to copy, which would then be ok. trace!("no dbg info"); } + // Now that we copied the metadata, get rid of dummy code. - llvm::LLVMRustEraseInstBefore(entry, last_inst); - llvm::LLVMRustEraseInstFromParent(last_inst); + llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst); - if cx.val_ty(outer_fn) != cx.type_void() { - builder.ret(call); - } else { + if cx.val_ty(call) == cx.type_void() { builder.ret_void(); + } else { + builder.ret(call); } // Let's crash in case that we messed something up above and generated invalid IR. diff --git a/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs b/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs index fd22421c7fcff..9a2473d6cf23d 100644 --- a/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs +++ b/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs @@ -298,7 +298,7 @@ struct UsageSets<'tcx> { /// Prepare sets of definitions that are relevant to deciding whether something /// is an "unused function" for coverage purposes. fn prepare_usage_sets<'tcx>(tcx: TyCtxt<'tcx>) -> UsageSets<'tcx> { - let MonoItemPartitions { all_mono_items, codegen_units } = + let MonoItemPartitions { all_mono_items, codegen_units, .. } = tcx.collect_and_partition_mono_items(()); // Obtain a MIR body for each function participating in codegen, via an diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index 729d6f62e2431..ae813fe5ebf89 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -7,11 +7,13 @@ use crate::llvm::Bool; extern "C" { // Enzyme pub fn LLVMRustHasMetadata(I: &Value, KindID: c_uint) -> bool; - pub fn LLVMRustEraseInstBefore(BB: &BasicBlock, I: &Value); + pub fn LLVMRustEraseInstUntilInclusive(BB: &BasicBlock, I: &Value); pub fn LLVMRustGetLastInstruction<'a>(BB: &BasicBlock) -> Option<&'a Value>; pub fn LLVMRustDIGetInstMetadata(I: &Value) -> Option<&Metadata>; pub fn LLVMRustEraseInstFromParent(V: &Value); pub fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value; + pub fn LLVMDumpModule(M: &Module); + pub fn LLVMDumpValue(V: &Value); pub fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool; pub fn LLVMGetFunctionCallConv(F: &Value) -> c_uint; diff --git a/compiler/rustc_codegen_ssa/messages.ftl b/compiler/rustc_codegen_ssa/messages.ftl index de37de09f5a57..4c5664f688886 100644 --- a/compiler/rustc_codegen_ssa/messages.ftl +++ b/compiler/rustc_codegen_ssa/messages.ftl @@ -16,6 +16,8 @@ codegen_ssa_archive_build_failure = failed to build archive at `{$path}`: {$erro codegen_ssa_atomic_compare_exchange = Atomic compare-exchange intrinsic missing failure memory ordering +codegen_ssa_autodiff_without_lto = using the autodiff feature requires using fat-lto + codegen_ssa_binary_output_to_tty = option `-o` or `--emit` is used to write binary output type `{$shorthand}` to stdout, but stdout is a tty codegen_ssa_cgu_not_recorded = diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index b40bb4ed5d2ac..914f2c21fa78c 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -7,6 +7,7 @@ use std::sync::mpsc::{Receiver, Sender, channel}; use std::{fs, io, mem, str, thread}; use rustc_ast::attr; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_data_structures::fx::{FxHashMap, FxIndexMap}; use rustc_data_structures::jobserver::{self, Acquired}; use rustc_data_structures::memmap::Mmap; @@ -40,7 +41,7 @@ use tracing::debug; use super::link::{self, ensure_removed}; use super::lto::{self, SerializedModule}; use super::symbol_export::symbol_name_for_instance_in_crate; -use crate::errors::ErrorCreatingRemarkDir; +use crate::errors::{AutodiffWithoutLto, ErrorCreatingRemarkDir}; use crate::traits::*; use crate::{ CachedModuleCodegen, CodegenResults, CompiledModule, CrateInfo, ModuleCodegen, ModuleKind, @@ -118,6 +119,7 @@ pub struct ModuleConfig { pub merge_functions: bool, pub emit_lifetime_markers: bool, pub llvm_plugins: Vec, + pub autodiff: Vec, } impl ModuleConfig { @@ -266,6 +268,7 @@ impl ModuleConfig { emit_lifetime_markers: sess.emit_lifetime_markers(), llvm_plugins: if_regular!(sess.opts.unstable_opts.llvm_plugins.clone(), vec![]), + autodiff: if_regular!(sess.opts.unstable_opts.autodiff.clone(), vec![]), } } @@ -389,6 +392,7 @@ impl CodegenContext { fn generate_lto_work( cgcx: &CodegenContext, + autodiff: Vec, needs_fat_lto: Vec>, needs_thin_lto: Vec<(String, B::ThinBuffer)>, import_only_modules: Vec<(SerializedModule, WorkProduct)>, @@ -397,11 +401,19 @@ fn generate_lto_work( if !needs_fat_lto.is_empty() { assert!(needs_thin_lto.is_empty()); - let module = + let mut module = B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise()); + if cgcx.lto == Lto::Fat { + let config = cgcx.config(ModuleKind::Regular); + module = unsafe { module.autodiff(cgcx, autodiff, config).unwrap() }; + } // We are adding a single work item, so the cost doesn't matter. vec![(WorkItem::LTO(module), 0)] } else { + if !autodiff.is_empty() { + let dcx = cgcx.create_dcx(); + dcx.handle().emit_fatal(AutodiffWithoutLto {}); + } assert!(needs_fat_lto.is_empty()); let (lto_modules, copy_jobs) = B::run_thin_lto(cgcx, needs_thin_lto, import_only_modules) .unwrap_or_else(|e| e.raise()); @@ -1021,6 +1033,9 @@ pub(crate) enum Message { /// Sent from a backend worker thread. WorkItem { result: Result, Option>, worker_id: usize }, + /// A vector containing all the AutoDiff tasks that we have to pass to Enzyme. + AddAutoDiffItems(Vec), + /// The frontend has finished generating something (backend IR or a /// post-LTO artifact) for a codegen unit, and it should be passed to the /// backend. Sent from the main thread. @@ -1348,6 +1363,7 @@ fn start_executing_work( // This is where we collect codegen units that have gone all the way // through codegen and LLVM. + let mut autodiff_items = Vec::new(); let mut compiled_modules = vec![]; let mut compiled_allocator_module = None; let mut needs_link = Vec::new(); @@ -1459,9 +1475,13 @@ fn start_executing_work( let needs_thin_lto = mem::take(&mut needs_thin_lto); let import_only_modules = mem::take(&mut lto_import_only_modules); - for (work, cost) in - generate_lto_work(&cgcx, needs_fat_lto, needs_thin_lto, import_only_modules) - { + for (work, cost) in generate_lto_work( + &cgcx, + autodiff_items.clone(), + needs_fat_lto, + needs_thin_lto, + import_only_modules, + ) { let insertion_index = work_items .binary_search_by_key(&cost, |&(_, cost)| cost) .unwrap_or_else(|e| e); @@ -1596,6 +1616,10 @@ fn start_executing_work( main_thread_state = MainThreadState::Idle; } + Message::AddAutoDiffItems(mut items) => { + autodiff_items.append(&mut items); + } + Message::CodegenComplete => { if codegen_state != Aborted { codegen_state = Completed; @@ -2070,6 +2094,10 @@ impl OngoingCodegen { drop(self.coordinator.sender.send(Box::new(Message::CodegenComplete::))); } + pub(crate) fn submit_autodiff_items(&self, items: Vec) { + drop(self.coordinator.sender.send(Box::new(Message::::AddAutoDiffItems(items)))); + } + pub(crate) fn check_for_errors(&self, sess: &Session) { self.shared_emitter_main.check(sess, false); } diff --git a/compiler/rustc_codegen_ssa/src/base.rs b/compiler/rustc_codegen_ssa/src/base.rs index e438bd70c510e..8a15acfcad7f2 100644 --- a/compiler/rustc_codegen_ssa/src/base.rs +++ b/compiler/rustc_codegen_ssa/src/base.rs @@ -18,7 +18,7 @@ use rustc_middle::middle::debugger_visualizer::{DebuggerVisualizerFile, Debugger use rustc_middle::middle::exported_symbols::SymbolExportKind; use rustc_middle::middle::{exported_symbols, lang_items}; use rustc_middle::mir::BinOp; -use rustc_middle::mir::mono::{CodegenUnit, CodegenUnitNameBuilder, MonoItem}; +use rustc_middle::mir::mono::{CodegenUnit, CodegenUnitNameBuilder, MonoItem, MonoItemPartitions}; use rustc_middle::query::Providers; use rustc_middle::ty::layout::{HasTyCtxt, HasTypingEnv, LayoutOf, TyAndLayout}; use rustc_middle::ty::{self, Instance, Ty, TyCtxt}; @@ -619,7 +619,9 @@ pub fn codegen_crate( // Run the monomorphization collector and partition the collected items into // codegen units. - let codegen_units = tcx.collect_and_partition_mono_items(()).codegen_units; + let MonoItemPartitions { codegen_units, autodiff_items, .. } = + tcx.collect_and_partition_mono_items(()); + let autodiff_fncs = autodiff_items.to_vec(); // Force all codegen_unit queries so they are already either red or green // when compile_codegen_unit accesses them. We are not able to re-execute @@ -690,6 +692,10 @@ pub fn codegen_crate( ); } + if !autodiff_fncs.is_empty() { + ongoing_codegen.submit_autodiff_items(autodiff_fncs); + } + // For better throughput during parallel processing by LLVM, we used to sort // CGUs largest to smallest. This would lead to better thread utilization // by, for example, preventing a large CGU from being processed last and diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index a0bc2d4ea48f1..4166387dad0c5 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -1,5 +1,10 @@ +use std::str::FromStr; + use rustc_ast::attr::list_contains_name; -use rustc_ast::{MetaItemInner, attr}; +use rustc_ast::expand::autodiff_attrs::{ + AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity, +}; +use rustc_ast::{MetaItem, MetaItemInner, attr}; use rustc_attr_parsing::{InlineAttr, InstructionSetAttr, OptimizeAttr}; use rustc_data_structures::fx::FxHashMap; use rustc_errors::codes::*; @@ -13,6 +18,7 @@ use rustc_middle::middle::codegen_fn_attrs::{ }; use rustc_middle::mir::mono::Linkage; use rustc_middle::query::Providers; +use rustc_middle::span_bug; use rustc_middle::ty::{self as ty, TyCtxt}; use rustc_session::parse::feature_err; use rustc_session::{Session, lint}; @@ -65,6 +71,13 @@ fn codegen_fn_attrs(tcx: TyCtxt<'_>, did: LocalDefId) -> CodegenFnAttrs { codegen_fn_attrs.flags |= CodegenFnAttrFlags::TRACK_CALLER; } + // If our rustc version supports autodiff/enzyme, then we call our handler + // to check for any `#[rustc_autodiff(...)]` attributes. + if cfg!(llvm_enzyme) { + let ad = autodiff_attrs(tcx, did.into()); + codegen_fn_attrs.autodiff_item = ad; + } + // When `no_builtins` is applied at the crate level, we should add the // `no-builtins` attribute to each function to ensure it takes effect in LTO. let crate_attrs = tcx.hir().attrs(rustc_hir::CRATE_HIR_ID); @@ -856,6 +869,109 @@ impl<'a> MixedExportNameAndNoMangleState<'a> { } } +/// We now check the #\[rustc_autodiff\] attributes which we generated from the #[autodiff(...)] +/// macros. There are two forms. The pure one without args to mark primal functions (the functions +/// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the +/// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never +/// panic, unless we introduced a bug when parsing the autodiff macro. +fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option { + let attrs = tcx.get_attrs(id, sym::rustc_autodiff); + + let attrs = + attrs.filter(|attr| attr.name_or_empty() == sym::rustc_autodiff).collect::>(); + + // check for exactly one autodiff attribute on placeholder functions. + // There should only be one, since we generate a new placeholder per ad macro. + // FIXME(ZuseZ4): re-enable this check. Currently we add multiple, which doesn't cause harm but + // looks strange e.g. under cargo-expand. + let attr = match &attrs[..] { + [] => return None, + [attr] => attr, + // These two attributes are the same and unfortunately duplicated due to a previous bug. + [attr, _attr2] => attr, + _ => { + //FIXME(ZuseZ4): Once we fixed our parser, we should also prohibit the two-attribute + //branch above. + span_bug!(attrs[1].span, "cg_ssa: rustc_autodiff should only exist once per source"); + } + }; + + let list = attr.meta_item_list().unwrap_or_default(); + + // empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions + if list.is_empty() { + return Some(AutoDiffAttrs::source()); + } + + let [mode, input_activities @ .., ret_activity] = &list[..] else { + span_bug!(attr.span, "rustc_autodiff attribute must contain mode and activities"); + }; + let mode = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = mode { + p1.segments.first().unwrap().ident + } else { + span_bug!(attr.span, "rustc_autodiff attribute must contain mode"); + }; + + // parse mode + let mode = match mode.as_str() { + "Forward" => DiffMode::Forward, + "Reverse" => DiffMode::Reverse, + "ForwardFirst" => DiffMode::ForwardFirst, + "ReverseFirst" => DiffMode::ReverseFirst, + _ => { + span_bug!(mode.span, "rustc_autodiff attribute contains invalid mode"); + } + }; + + // First read the ret symbol from the attribute + let ret_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = ret_activity { + p1.segments.first().unwrap().ident + } else { + span_bug!(attr.span, "rustc_autodiff attribute must contain the return activity"); + }; + + // Then parse it into an actual DiffActivity + let Ok(ret_activity) = DiffActivity::from_str(ret_symbol.as_str()) else { + span_bug!(ret_symbol.span, "invalid return activity"); + }; + + // Now parse all the intermediate (input) activities + let mut arg_activities: Vec = vec![]; + for arg in input_activities { + let arg_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p2, .. }) = arg { + match p2.segments.first() { + Some(x) => x.ident, + None => { + span_bug!( + arg.span(), + "rustc_autodiff attribute must contain the input activity" + ); + } + } + } else { + span_bug!(arg.span(), "rustc_autodiff attribute must contain the input activity"); + }; + + match DiffActivity::from_str(arg_symbol.as_str()) { + Ok(arg_activity) => arg_activities.push(arg_activity), + Err(_) => { + span_bug!(arg_symbol.span, "invalid input activity"); + } + } + } + + for &input in &arg_activities { + if !valid_input_activity(mode, input) { + span_bug!(attr.span, "Invalid input activity {} for {} mode", input, mode); + } + } + if !valid_ret_activity(mode, ret_activity) { + span_bug!(attr.span, "Invalid return activity {} for {} mode", ret_activity, mode); + } + + Some(AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities }) +} + pub(crate) fn provide(providers: &mut Providers) { *providers = Providers { codegen_fn_attrs, should_inherit_track_caller, ..*providers }; } diff --git a/compiler/rustc_codegen_ssa/src/errors.rs b/compiler/rustc_codegen_ssa/src/errors.rs index 5e684632fb247..d8e44b7481a77 100644 --- a/compiler/rustc_codegen_ssa/src/errors.rs +++ b/compiler/rustc_codegen_ssa/src/errors.rs @@ -39,6 +39,10 @@ pub(crate) struct CguNotRecorded<'a> { pub cgu_name: &'a str, } +#[derive(Diagnostic)] +#[diag(codegen_ssa_autodiff_without_lto)] +pub struct AutodiffWithoutLto; + #[derive(Diagnostic)] #[diag(codegen_ssa_unknown_reuse_kind)] pub(crate) struct UnknownReuseKind { diff --git a/compiler/rustc_interface/src/tests.rs b/compiler/rustc_interface/src/tests.rs index 74d02ac22276c..629be2cd96362 100644 --- a/compiler/rustc_interface/src/tests.rs +++ b/compiler/rustc_interface/src/tests.rs @@ -8,12 +8,12 @@ use rustc_data_structures::profiling::TimePassesFormat; use rustc_errors::emitter::HumanReadableErrorType; use rustc_errors::{ColorConfig, registry}; use rustc_session::config::{ - BranchProtection, CFGuard, Cfg, CollapseMacroDebuginfo, CoverageLevel, CoverageOptions, - DebugInfo, DumpMonoStatsFormat, ErrorOutputType, ExternEntry, ExternLocation, Externs, - FmtDebug, FunctionReturn, InliningThreshold, Input, InstrumentCoverage, InstrumentXRay, - LinkSelfContained, LinkerPluginLto, LocationDetail, LtoCli, MirIncludeSpans, NextSolverConfig, - OomStrategy, Options, OutFileName, OutputType, OutputTypes, PAuthKey, PacRet, Passes, - PatchableFunctionEntry, Polonius, ProcMacroExecutionStrategy, Strip, SwitchWithOptPath, + AutoDiff, BranchProtection, CFGuard, Cfg, CollapseMacroDebuginfo, CoverageLevel, + CoverageOptions, DebugInfo, DumpMonoStatsFormat, ErrorOutputType, ExternEntry, ExternLocation, + Externs, FmtDebug, FunctionReturn, InliningThreshold, Input, InstrumentCoverage, + InstrumentXRay, LinkSelfContained, LinkerPluginLto, LocationDetail, LtoCli, MirIncludeSpans, + NextSolverConfig, OomStrategy, Options, OutFileName, OutputType, OutputTypes, PAuthKey, PacRet, + Passes, PatchableFunctionEntry, Polonius, ProcMacroExecutionStrategy, Strip, SwitchWithOptPath, SymbolManglingVersion, WasiExecModel, build_configuration, build_session_options, rustc_optgroups, }; @@ -760,6 +760,7 @@ fn test_unstable_options_tracking_hash() { tracked!(allow_features, Some(vec![String::from("lang_items")])); tracked!(always_encode_mir, true); tracked!(assume_incomplete_release, true); + tracked!(autodiff, vec![AutoDiff::Print]); tracked!(binary_dep_depinfo, true); tracked!(box_noalias, false); tracked!( diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 35186778671bc..4cb2376db628c 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -955,7 +955,8 @@ extern "C" LLVMValueRef LLVMRustGetLastInstruction(LLVMBasicBlockRef BB) { return nullptr; } -extern "C" void LLVMRustEraseInstBefore(LLVMBasicBlockRef bb, LLVMValueRef I) { +extern "C" void LLVMRustEraseInstUntilInclusive(LLVMBasicBlockRef bb, + LLVMValueRef I) { auto &BB = *unwrap(bb); auto &Inst = *unwrap(I); auto It = BB.begin(); @@ -963,8 +964,6 @@ extern "C" void LLVMRustEraseInstBefore(LLVMBasicBlockRef bb, LLVMValueRef I) { ++It; // Make sure we found the Instruction. assert(It != BB.end()); - // We don't want to erase the instruction itself. - It--; // Delete in rev order to ensure no dangling references. while (It != BB.begin()) { auto Prev = std::prev(It); diff --git a/compiler/rustc_middle/messages.ftl b/compiler/rustc_middle/messages.ftl index 8dc529b4d7a21..cfb495df9d536 100644 --- a/compiler/rustc_middle/messages.ftl +++ b/compiler/rustc_middle/messages.ftl @@ -32,6 +32,8 @@ middle_assert_shl_overflow = middle_assert_shr_overflow = attempt to shift right by `{$val}`, which would overflow +middle_autodiff_unsafe_inner_const_ref = reading from a `Duplicated` const {$ty} is unsafe + middle_bounds_check = index out of bounds: the length is {$len} but the index is {$index} @@ -107,6 +109,8 @@ middle_type_length_limit = reached the type-length limit while instantiating `{$ middle_unknown_layout = the type `{$ty}` has an unknown layout +middle_unsupported_union = we don't support unions yet: '{$ty_name}' + middle_values_too_big = values of the type `{$ty}` are too big for the target architecture diff --git a/compiler/rustc_middle/src/arena.rs b/compiler/rustc_middle/src/arena.rs index 750531b638e4d..eaccd8c360ebd 100644 --- a/compiler/rustc_middle/src/arena.rs +++ b/compiler/rustc_middle/src/arena.rs @@ -87,6 +87,7 @@ macro_rules! arena_types { [] codegen_unit: rustc_middle::mir::mono::CodegenUnit<'tcx>, [decode] attribute: rustc_hir::Attribute, [] name_set: rustc_data_structures::unord::UnordSet, + [] autodiff_item: rustc_ast::expand::autodiff_attrs::AutoDiffItem, [] ordered_name_set: rustc_data_structures::fx::FxIndexSet, [] pats: rustc_middle::ty::PatternKind<'tcx>, diff --git a/compiler/rustc_middle/src/error.rs b/compiler/rustc_middle/src/error.rs index b0187a1848cf1..b30d3a950c6a3 100644 --- a/compiler/rustc_middle/src/error.rs +++ b/compiler/rustc_middle/src/error.rs @@ -37,6 +37,20 @@ pub struct OpaqueHiddenTypeMismatch<'tcx> { pub sub: TypeMismatchReason, } +#[derive(Diagnostic)] +#[diag(middle_unsupported_union)] +pub struct UnsupportedUnion { + pub ty_name: String, +} + +#[derive(Diagnostic)] +#[diag(middle_autodiff_unsafe_inner_const_ref)] +pub struct AutodiffUnsafeInnerConstRef { + #[primary_span] + pub span: Span, + pub ty: String, +} + #[derive(Subdiagnostic)] pub enum TypeMismatchReason { #[label(middle_conflict_types)] diff --git a/compiler/rustc_middle/src/middle/codegen_fn_attrs.rs b/compiler/rustc_middle/src/middle/codegen_fn_attrs.rs index cc980f6e62aee..1784665bcae8e 100644 --- a/compiler/rustc_middle/src/middle/codegen_fn_attrs.rs +++ b/compiler/rustc_middle/src/middle/codegen_fn_attrs.rs @@ -1,4 +1,5 @@ use rustc_abi::Align; +use rustc_ast::expand::autodiff_attrs::AutoDiffAttrs; use rustc_attr_parsing::{InlineAttr, InstructionSetAttr, OptimizeAttr}; use rustc_macros::{HashStable, TyDecodable, TyEncodable}; use rustc_span::Symbol; @@ -52,6 +53,8 @@ pub struct CodegenFnAttrs { /// The `#[patchable_function_entry(...)]` attribute. Indicates how many nops should be around /// the function entry. pub patchable_function_entry: Option, + /// For the `#[autodiff]` macros. + pub autodiff_item: Option, } #[derive(Copy, Clone, Debug, TyEncodable, TyDecodable, HashStable)] @@ -160,6 +163,7 @@ impl CodegenFnAttrs { instruction_set: None, alignment: None, patchable_function_entry: None, + autodiff_item: None, } } diff --git a/compiler/rustc_middle/src/mir/mono.rs b/compiler/rustc_middle/src/mir/mono.rs index 3eccf56d8c4dd..b93046ca8e22f 100644 --- a/compiler/rustc_middle/src/mir/mono.rs +++ b/compiler/rustc_middle/src/mir/mono.rs @@ -1,6 +1,7 @@ use std::fmt; use std::hash::Hash; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_attr_parsing::InlineAttr; use rustc_data_structures::base_n::{BaseNString, CASE_INSENSITIVE, ToBaseN}; use rustc_data_structures::fingerprint::Fingerprint; @@ -251,6 +252,7 @@ impl ToStableHashKey> for MonoItem<'_> { pub struct MonoItemPartitions<'tcx> { pub codegen_units: &'tcx [CodegenUnit<'tcx>], pub all_mono_items: &'tcx DefIdSet, + pub autodiff_items: &'tcx [AutoDiffItem], } #[derive(Debug, HashStable)] diff --git a/compiler/rustc_monomorphize/Cargo.toml b/compiler/rustc_monomorphize/Cargo.toml index 9bdaeb015cd5f..5462105e5e8e4 100644 --- a/compiler/rustc_monomorphize/Cargo.toml +++ b/compiler/rustc_monomorphize/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" [dependencies] # tidy-alphabetical-start rustc_abi = { path = "../rustc_abi" } +rustc_ast = { path = "../rustc_ast" } rustc_attr_parsing = { path = "../rustc_attr_parsing" } rustc_data_structures = { path = "../rustc_data_structures" } rustc_errors = { path = "../rustc_errors" } @@ -15,6 +16,7 @@ rustc_macros = { path = "../rustc_macros" } rustc_middle = { path = "../rustc_middle" } rustc_session = { path = "../rustc_session" } rustc_span = { path = "../rustc_span" } +rustc_symbol_mangling = { path = "../rustc_symbol_mangling" } rustc_target = { path = "../rustc_target" } serde = "1" serde_json = "1" diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs index bb603df112942..f3b5488e9751a 100644 --- a/compiler/rustc_monomorphize/src/collector.rs +++ b/compiler/rustc_monomorphize/src/collector.rs @@ -257,7 +257,7 @@ struct SharedState<'tcx> { pub(crate) struct UsageMap<'tcx> { // Maps every mono item to the mono items used by it. - used_map: UnordMap, Vec>>, + pub used_map: UnordMap, Vec>>, // Maps every mono item to the mono items that use it. user_map: UnordMap, Vec>>, diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index e08c348a64d74..c985ea04278d4 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -92,6 +92,8 @@ //! source-level module, functions from the same module will be available for //! inlining, even when they are not marked `#[inline]`. +mod autodiff; + use std::cmp; use std::collections::hash_map::Entry; use std::fs::{self, File}; @@ -251,7 +253,17 @@ where can_export_generics, always_export_generics, ); - if visibility == Visibility::Hidden && can_be_internalized { + + // We can't differentiate something that got inlined. + let autodiff_active = cfg!(llvm_enzyme) + && cx + .tcx + .codegen_fn_attrs(mono_item.def_id()) + .autodiff_item + .as_ref() + .is_some_and(|ad| ad.is_active()); + + if !autodiff_active && visibility == Visibility::Hidden && can_be_internalized { internalization_candidates.insert(mono_item); } let size_estimate = mono_item.size_estimate(cx.tcx); @@ -1176,6 +1188,18 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitio }) .collect(); + let autodiff_mono_items: Vec<_> = items + .iter() + .filter_map(|item| match *item { + MonoItem::Fn(ref instance) => Some((item, instance)), + _ => None, + }) + .collect(); + + let autodiff_items = + autodiff::find_autodiff_source_functions(tcx, &usage_map, autodiff_mono_items); + let autodiff_items = tcx.arena.alloc_from_iter(autodiff_items); + // Output monomorphization stats per def_id if let SwitchWithOptPath::Enabled(ref path) = tcx.sess.opts.unstable_opts.dump_mono_stats { if let Err(err) = @@ -1236,7 +1260,11 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitio } } - MonoItemPartitions { all_mono_items: tcx.arena.alloc(mono_items), codegen_units } + MonoItemPartitions { + all_mono_items: tcx.arena.alloc(mono_items), + codegen_units, + autodiff_items, + } } /// Outputs stats about instantiation counts and estimated size, per `MonoItem`'s diff --git a/compiler/rustc_monomorphize/src/partitioning/autodiff.rs b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs new file mode 100644 index 0000000000000..bce31bf0748e0 --- /dev/null +++ b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs @@ -0,0 +1,121 @@ +use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity}; +use rustc_hir::def_id::LOCAL_CRATE; +use rustc_middle::bug; +use rustc_middle::mir::mono::MonoItem; +use rustc_middle::ty::{self, Instance, Ty, TyCtxt}; +use rustc_symbol_mangling::symbol_name_for_instance_in_crate; +use tracing::{debug, trace}; + +use crate::partitioning::UsageMap; + +fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec) { + if !matches!(fn_ty.kind(), ty::FnDef(..)) { + bug!("expected fn def for autodiff, got {:?}", fn_ty); + } + let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx); + + // If rustc compiles the unmodified primal, we know that this copy of the function + // also has correct lifetimes. We know that Enzyme won't free the shadow too early + // (or actually at all), so let's strip lifetimes when computing the layout. + let x = tcx.instantiate_bound_regions_with_erased(fnc_binder); + let mut new_activities = vec![]; + let mut new_positions = vec![]; + for (i, ty) in x.inputs().iter().enumerate() { + if let Some(inner_ty) = ty.builtin_deref(true) { + if ty.is_fn_ptr() { + // FIXME(ZuseZ4): add a nicer error, or just figure out how to support them, + // since Enzyme itself can handle them. + tcx.dcx().err("function pointers are currently not supported in autodiff"); + } + if inner_ty.is_slice() { + // We know that the length will be passed as extra arg. + if !da.is_empty() { + // We are looking at a slice. The length of that slice will become an + // extra integer on llvm level. Integers are always const. + // However, if the slice get's duplicated, we want to know to later check the + // size. So we mark the new size argument as FakeActivitySize. + let activity = match da[i] { + DiffActivity::DualOnly + | DiffActivity::Dual + | DiffActivity::DuplicatedOnly + | DiffActivity::Duplicated => DiffActivity::FakeActivitySize, + DiffActivity::Const => DiffActivity::Const, + _ => bug!("unexpected activity for ptr/ref"), + }; + new_activities.push(activity); + new_positions.push(i + 1); + } + continue; + } + } + } + // now add the extra activities coming from slices + // Reverse order to not invalidate the indices + for _ in 0..new_activities.len() { + let pos = new_positions.pop().unwrap(); + let activity = new_activities.pop().unwrap(); + da.insert(pos, activity); + } +} + +pub(crate) fn find_autodiff_source_functions<'tcx>( + tcx: TyCtxt<'tcx>, + usage_map: &UsageMap<'tcx>, + autodiff_mono_items: Vec<(&MonoItem<'tcx>, &Instance<'tcx>)>, +) -> Vec { + let mut autodiff_items: Vec = vec![]; + for (item, instance) in autodiff_mono_items { + let target_id = instance.def_id(); + let cg_fn_attr = tcx.codegen_fn_attrs(target_id).autodiff_item.clone(); + let Some(target_attrs) = cg_fn_attr else { + continue; + }; + let mut input_activities: Vec = target_attrs.input_activity.clone(); + if target_attrs.is_source() { + trace!("source found: {:?}", target_id); + } + if !target_attrs.apply_autodiff() { + continue; + } + + let target_symbol = symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE); + + let source = + usage_map.used_map.get(&item).unwrap().into_iter().find_map(|item| match *item { + MonoItem::Fn(ref instance_s) => { + let source_id = instance_s.def_id(); + if let Some(ad) = &tcx.codegen_fn_attrs(source_id).autodiff_item + && ad.is_active() + { + return Some(instance_s); + } + None + } + _ => None, + }); + let inst = match source { + Some(source) => source, + None => continue, + }; + + debug!("source_id: {:?}", inst.def_id()); + let fn_ty = inst.ty(tcx, ty::TypingEnv::fully_monomorphized()); + assert!(fn_ty.is_fn()); + adjust_activity_to_abi(tcx, fn_ty, &mut input_activities); + let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE); + + let mut new_target_attrs = target_attrs.clone(); + new_target_attrs.input_activity = input_activities; + let itm = new_target_attrs.into_item(symb, target_symbol); + autodiff_items.push(itm); + } + + if !autodiff_items.is_empty() { + trace!("AUTODIFF ITEMS EXIST"); + for item in &mut *autodiff_items { + trace!("{}", &item); + } + } + + autodiff_items +} diff --git a/compiler/rustc_session/src/config.rs b/compiler/rustc_session/src/config.rs index 97bd2670aa640..c8a811985d5bf 100644 --- a/compiler/rustc_session/src/config.rs +++ b/compiler/rustc_session/src/config.rs @@ -189,6 +189,39 @@ pub enum CoverageLevel { Mcdc, } +/// The different settings that the `-Z autodiff` flag can have. +#[derive(Clone, Copy, PartialEq, Hash, Debug)] +pub enum AutoDiff { + /// Print TypeAnalysis information + PrintTA, + /// Print ActivityAnalysis Information + PrintAA, + /// Print Performance Warnings from Enzyme + PrintPerf, + /// Combines the three print flags above. + Print, + /// Print the whole module, before running opts. + PrintModBefore, + /// Print the whole module just before we pass it to Enzyme. + /// For Debug purpose, prefer the OPT flag below + PrintModAfterOpts, + /// Print the module after Enzyme differentiated everything. + PrintModAfterEnzyme, + + /// Enzyme's loose type debug helper (can cause incorrect gradients) + LooseTypes, + + /// More flags + NoModOptAfter, + /// Tell Enzyme to run LLVM Opts on each function it generated. By default off, + /// since we already optimize the whole module after Enzyme is done. + EnableFncOpt, + NoVecUnroll, + RuntimeActivity, + /// Runs Enzyme specific Inlining + Inline, +} + /// Settings for `-Z instrument-xray` flag. #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)] pub struct InstrumentXRay { @@ -2902,7 +2935,7 @@ pub(crate) mod dep_tracking { }; use super::{ - BranchProtection, CFGuard, CFProtection, CollapseMacroDebuginfo, CoverageOptions, + AutoDiff, BranchProtection, CFGuard, CFProtection, CollapseMacroDebuginfo, CoverageOptions, CrateType, DebugInfo, DebugInfoCompression, ErrorOutputType, FmtDebug, FunctionReturn, InliningThreshold, InstrumentCoverage, InstrumentXRay, LinkerPluginLto, LocationDetail, LtoCli, MirStripDebugInfo, NextSolverConfig, OomStrategy, OptLevel, OutFileName, @@ -2950,6 +2983,7 @@ pub(crate) mod dep_tracking { } impl_dep_tracking_hash_via_hash!( + AutoDiff, bool, usize, NonZero, diff --git a/compiler/rustc_session/src/options.rs b/compiler/rustc_session/src/options.rs index 63aaa3abc8e56..3b6d1a1201819 100644 --- a/compiler/rustc_session/src/options.rs +++ b/compiler/rustc_session/src/options.rs @@ -398,6 +398,7 @@ mod desc { pub(crate) const parse_list: &str = "a space-separated list of strings"; pub(crate) const parse_list_with_polarity: &str = "a comma-separated list of strings, with elements beginning with + or -"; + pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Print`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfterOpts`, `PrintModAfterEnzyme`, `LooseTypes`, `NoModOptAfter`, `EnableFncOpt`, `NoVecUnroll`, `Inline`"; pub(crate) const parse_comma_list: &str = "a comma-separated list of strings"; pub(crate) const parse_opt_comma_list: &str = parse_comma_list; pub(crate) const parse_number: &str = "a number"; @@ -1029,6 +1030,38 @@ pub mod parse { } } + pub(crate) fn parse_autodiff(slot: &mut Vec, v: Option<&str>) -> bool { + let Some(v) = v else { + *slot = vec![]; + return true; + }; + let mut v: Vec<&str> = v.split(",").collect(); + v.sort_unstable(); + for &val in v.iter() { + let variant = match val { + "PrintTA" => AutoDiff::PrintTA, + "PrintAA" => AutoDiff::PrintAA, + "PrintPerf" => AutoDiff::PrintPerf, + "Print" => AutoDiff::Print, + "PrintModBefore" => AutoDiff::PrintModBefore, + "PrintModAfterOpts" => AutoDiff::PrintModAfterOpts, + "PrintModAfterEnzyme" => AutoDiff::PrintModAfterEnzyme, + "LooseTypes" => AutoDiff::LooseTypes, + "NoModOptAfter" => AutoDiff::NoModOptAfter, + "EnableFncOpt" => AutoDiff::EnableFncOpt, + "NoVecUnroll" => AutoDiff::NoVecUnroll, + "Inline" => AutoDiff::Inline, + _ => { + // FIXME(ZuseZ4): print an error saying which value is not recognized + return false; + } + }; + slot.push(variant); + } + + true + } + pub(crate) fn parse_instrument_coverage( slot: &mut InstrumentCoverage, v: Option<&str>, @@ -1736,6 +1769,22 @@ options! { either `loaded` or `not-loaded`."), assume_incomplete_release: bool = (false, parse_bool, [TRACKED], "make cfg(version) treat the current version as incomplete (default: no)"), + autodiff: Vec = (Vec::new(), parse_autodiff, [TRACKED], + "a list of optional autodiff flags to enable + Optional extra settings: + `=PrintTA` + `=PrintAA` + `=PrintPerf` + `=Print` + `=PrintModBefore` + `=PrintModAfterOpts` + `=PrintModAfterEnzyme` + `=LooseTypes` + `=NoModOptAfter` + `=EnableFncOpt` + `=NoVecUnroll` + `=Inline` + Multiple options can be combined with commas."), #[rustc_lint_opt_deny_field_access("use `Session::binary_dep_depinfo` instead of this field")] binary_dep_depinfo: bool = (false, parse_bool, [TRACKED], "include artifacts (sysroot, crate dependencies) used during compilation in dep-info \ diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index 1fb15fe98000b..cf9f2148c24cc 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -502,7 +502,6 @@ symbols! { augmented_assignments, auto_traits, autodiff, - autodiff_fallback, automatically_derived, avx, avx512_target_feature, @@ -568,7 +567,6 @@ symbols! { cfg_accessible, cfg_attr, cfg_attr_multi, - cfg_autodiff_fallback, cfg_boolean_literals, cfg_doctest, cfg_emscripten_wasm_eh, diff --git a/src/bootstrap/src/core/build_steps/compile.rs b/src/bootstrap/src/core/build_steps/compile.rs index fd9bf47234c63..f447d186a5242 100644 --- a/src/bootstrap/src/core/build_steps/compile.rs +++ b/src/bootstrap/src/core/build_steps/compile.rs @@ -1049,9 +1049,12 @@ pub fn rustc_cargo( // . cargo.rustflag("-Zon-broken-pipe=kill"); - if builder.config.llvm_enzyme { - cargo.rustflag("-l").rustflag("Enzyme-19"); - } + // We temporarily disable linking here as part of some refactoring. + // This way, people can manually use -Z llvm-plugins and -C passes=enzyme for now. + // In a follow-up PR, we will re-enable linking here and load the pass for them. + //if builder.config.llvm_enzyme { + // cargo.rustflag("-l").rustflag("Enzyme-19"); + //} // Building with protected visibility reduces the number of dynamic relocations needed, giving // us a faster startup time. However GNU ld < 2.40 will error if we try to link a shared object diff --git a/src/doc/unstable-book/src/compiler-flags/autodiff.md b/src/doc/unstable-book/src/compiler-flags/autodiff.md new file mode 100644 index 0000000000000..4e55be0ad95a0 --- /dev/null +++ b/src/doc/unstable-book/src/compiler-flags/autodiff.md @@ -0,0 +1,23 @@ +# `autodiff` + +The tracking issue for this feature is: [#124509](https://github.com/rust-lang/rust/issues/124509). + +------------------------ + +This feature allows you to differentiate functions using automatic differentiation. +Set the `-Zautodiff=` compiler flag to adjust the behaviour of the autodiff feature. +Multiple options can be separated with a comma. Valid options are: + +`PrintTA` - print Type Analysis Information +`PrintAA` - print Activity Analysis Information +`PrintPerf` - print Performance Warnings from Enzyme +`Print` - prints all intermediate transformations +`PrintModBefore` - print the whole module, before running opts +`PrintModAfterOpts` - print the whole module just before we pass it to Enzyme +`PrintModAfterEnzyme` - print the module after Enzyme differentiated everything +`LooseTypes` - Enzyme's loose type debug helper (can cause incorrect gradients) +`Inline` - runs Enzyme specific Inlining +`NoModOptAfter` - do not optimize the module after Enzyme is done +`EnableFncOpt` - tell Enzyme to run LLVM Opts on each function it generated +`NoVecUnroll` - do not unroll vectorized loops +`RuntimeActivity` - allow specifying activity at runtime diff --git a/src/tools/enzyme b/src/tools/enzyme index 2fe5164a2423d..0e5fa4a3d475f 160000 --- a/src/tools/enzyme +++ b/src/tools/enzyme @@ -1 +1 @@ -Subproject commit 2fe5164a2423dd67ef25e2c4fb204fd06362494b +Subproject commit 0e5fa4a3d475f4dece489c9e06b11164f83789f5