Skip to content

Separate Builder methods from tcx #135581

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion compiler/rustc_codegen_gcc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,6 @@ impl WriteBackendMethods for GccCodegenBackend {
}
fn autodiff(
_cgcx: &CodegenContext<Self>,
_tcx: TyCtxt<'_>,
_module: &ModuleCodegen<Self::Module>,
_diff_fncs: Vec<AutoDiffItem>,
_config: &ModuleConfig,
Expand Down
150 changes: 138 additions & 12 deletions compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::borrow::Cow;
use std::borrow::{Borrow, Cow};
use std::ops::Deref;
use std::{iter, ptr};

Expand Down Expand Up @@ -31,27 +31,135 @@ use tracing::{debug, instrument};
use crate::abi::FnAbiLlvmExt;
use crate::attributes;
use crate::common::Funclet;
use crate::context::CodegenCx;
use crate::context::{CodegenCx, SimpleCx};
use crate::llvm::{self, AtomicOrdering, AtomicRmwBinOp, BasicBlock, False, True};
use crate::type_::Type;
use crate::type_of::LayoutLlvmExt;
use crate::value::Value;

// All Builders must have an llfn associated with them
#[must_use]
pub(crate) struct Builder<'a, 'll, 'tcx> {
pub(crate) struct GenericBuilder<'a, 'll, CX: Borrow<SimpleCx<'ll>>> {
pub llbuilder: &'ll mut llvm::Builder<'ll>,
pub cx: &'a CodegenCx<'ll, 'tcx>,
pub cx: &'a CX,
}

impl Drop for Builder<'_, '_, '_> {
pub(crate) type SBuilder<'a, 'll> = GenericBuilder<'a, 'll, SimpleCx<'ll>>;
pub(crate) type Builder<'a, 'll, 'tcx> = GenericBuilder<'a, 'll, CodegenCx<'ll, 'tcx>>;

impl<'a, 'll, CX: Borrow<SimpleCx<'ll>>> Drop for GenericBuilder<'a, 'll, CX> {
fn drop(&mut self) {
unsafe {
llvm::LLVMDisposeBuilder(&mut *(self.llbuilder as *mut _));
}
}
}

impl<'a, 'll> SBuilder<'a, 'll> {
fn call(
&mut self,
llty: &'ll Type,
llfn: &'ll Value,
args: &[&'ll Value],
funclet: Option<&Funclet<'ll>>,
) -> &'ll Value {
debug!("call {:?} with args ({:?})", llfn, args);

let args = self.check_call("call", llty, llfn, args);
let funclet_bundle = funclet.map(|funclet| funclet.bundle());
let mut bundles: SmallVec<[_; 2]> = SmallVec::new();
if let Some(funclet_bundle) = funclet_bundle {
bundles.push(funclet_bundle);
}

let call = unsafe {
llvm::LLVMBuildCallWithOperandBundles(
self.llbuilder,
llty,
llfn,
args.as_ptr() as *const &llvm::Value,
args.len() as c_uint,
bundles.as_ptr(),
bundles.len() as c_uint,
c"".as_ptr(),
)
};
call
}

fn with_scx(scx: &'a SimpleCx<'ll>) -> Self {
// Create a fresh builder from the simple context.
let llbuilder = unsafe { llvm::LLVMCreateBuilderInContext(scx.llcx) };
SBuilder { llbuilder, cx: scx }
}
}
impl<'a, 'll, CX: Borrow<SimpleCx<'ll>>> GenericBuilder<'a, 'll, CX> {
pub(crate) fn bitcast(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
unsafe { llvm::LLVMBuildBitCast(self.llbuilder, val, dest_ty, UNNAMED) }
}

fn ret_void(&mut self) {
unsafe {
llvm::LLVMBuildRetVoid(self.llbuilder);
}
}

fn ret(&mut self, v: &'ll Value) {
unsafe {
llvm::LLVMBuildRet(self.llbuilder, v);
}
}
}
impl<'a, 'll> SBuilder<'a, 'll> {
fn build(cx: &'a SimpleCx<'ll>, llbb: &'ll BasicBlock) -> SBuilder<'a, 'll> {
let bx = SBuilder::with_scx(cx);
unsafe {
llvm::LLVMPositionBuilderAtEnd(bx.llbuilder, llbb);
}
bx
}

fn check_call<'b>(
&mut self,
typ: &str,
fn_ty: &'ll Type,
llfn: &'ll Value,
args: &'b [&'ll Value],
) -> Cow<'b, [&'ll Value]> {
assert!(
self.cx.type_kind(fn_ty) == TypeKind::Function,
"builder::{typ} not passed a function, but {fn_ty:?}"
);

let param_tys = self.cx.func_params_types(fn_ty);

let all_args_match = iter::zip(&param_tys, args.iter().map(|&v| self.cx.val_ty(v)))
.all(|(expected_ty, actual_ty)| *expected_ty == actual_ty);

if all_args_match {
return Cow::Borrowed(args);
}

let casted_args: Vec<_> = iter::zip(param_tys, args)
.enumerate()
.map(|(i, (expected_ty, &actual_val))| {
let actual_ty = self.cx.val_ty(actual_val);
if expected_ty != actual_ty {
debug!(
"type mismatch in function call of {:?}. \
Expected {:?} for param {}, got {:?}; injecting bitcast",
llfn, expected_ty, i, actual_ty
);
self.bitcast(actual_val, expected_ty)
} else {
actual_val
}
})
.collect();

Cow::Owned(casted_args)
}
}

/// Empty string, to be used where LLVM expects an instruction name, indicating
/// that the instruction is to be left unnamed (i.e. numbered, in textual IR).
// FIXME(eddyb) pass `&CStr` directly to FFI once it's a thin pointer.
Expand Down Expand Up @@ -1222,6 +1330,14 @@ impl<'ll> StaticBuilderMethods for Builder<'_, 'll, '_> {
}

impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
fn build(cx: &'a CodegenCx<'ll, 'tcx>, llbb: &'ll BasicBlock) -> Builder<'a, 'll, 'tcx> {
let bx = Builder::with_cx(cx);
unsafe {
llvm::LLVMPositionBuilderAtEnd(bx.llbuilder, llbb);
}
bx
}

fn with_cx(cx: &'a CodegenCx<'ll, 'tcx>) -> Self {
// Create a fresh builder from the crate context.
let llbuilder = unsafe { llvm::LLVMCreateBuilderInContext(cx.llcx) };
Expand All @@ -1231,13 +1347,16 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
pub(crate) fn llfn(&self) -> &'ll Value {
unsafe { llvm::LLVMGetBasicBlockParent(self.llbb()) }
}
}

impl<'a, 'll, CX: Borrow<SimpleCx<'ll>>> GenericBuilder<'a, 'll, CX> {
fn position_at_start(&mut self, llbb: &'ll BasicBlock) {
unsafe {
llvm::LLVMRustPositionBuilderAtStart(self.llbuilder, llbb);
}
}

}
impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
fn align_metadata(&mut self, load: &'ll Value, align: Align) {
unsafe {
let md = [llvm::LLVMValueAsMetadata(self.cx.const_u64(align.bytes()))];
Expand All @@ -1259,7 +1378,8 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
self.set_metadata(inst, llvm::MD_unpredictable, md);
}
}

}
impl<'a, 'll, CX: Borrow<SimpleCx<'ll>>> GenericBuilder<'a, 'll, CX> {
pub(crate) fn minnum(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value {
unsafe { llvm::LLVMRustBuildMinNum(self.llbuilder, lhs, rhs) }
}
Expand Down Expand Up @@ -1360,7 +1480,9 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
let ret = unsafe { llvm::LLVMBuildCatchRet(self.llbuilder, funclet.cleanuppad(), unwind) };
ret.expect("LLVM does not have support for catchret")
}
}

impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
fn check_call<'b>(
&mut self,
typ: &str,
Expand Down Expand Up @@ -1401,11 +1523,13 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {

Cow::Owned(casted_args)
}

}
impl<'a, 'll, CX: Borrow<SimpleCx<'ll>>> GenericBuilder<'a, 'll, CX> {
pub(crate) fn va_arg(&mut self, list: &'ll Value, ty: &'ll Type) -> &'ll Value {
unsafe { llvm::LLVMBuildVAArg(self.llbuilder, list, ty, UNNAMED) }
}

}
impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
pub(crate) fn call_intrinsic(&mut self, intrinsic: &str, args: &[&'ll Value]) -> &'ll Value {
let (ty, f) = self.cx.get_intrinsic(intrinsic);
self.call(ty, None, None, f, args, None, None)
Expand All @@ -1423,7 +1547,8 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {

self.call_intrinsic(intrinsic, &[self.cx.const_u64(size), ptr]);
}

}
impl<'a, 'll, CX: Borrow<SimpleCx<'ll>>> GenericBuilder<'a, 'll, CX> {
pub(crate) fn phi(
&mut self,
ty: &'ll Type,
Expand All @@ -1443,7 +1568,8 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
llvm::LLVMAddIncoming(phi, &val, &bb, 1 as c_uint);
}
}

}
impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
fn fptoint_sat(&mut self, signed: bool, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
let src_ty = self.cx.val_ty(val);
let (float_ty, int_ty, vector_length) = if self.cx.type_kind(src_ty) == TypeKind::Vector {
Expand Down
25 changes: 11 additions & 14 deletions compiler/rustc_codegen_llvm/src/builder/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,19 @@ use std::ptr;
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
use rustc_codegen_ssa::ModuleCodegen;
use rustc_codegen_ssa::back::write::ModuleConfig;
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_errors::FatalError;
use rustc_middle::ty::TyCtxt;
use rustc_session::config::Lto;
use tracing::{debug, trace};

use crate::back::write::{llvm_err, llvm_optimize};
use crate::builder::Builder;
use crate::declare::declare_raw_fn;
use crate::builder::SBuilder;
use crate::context::SimpleCx;
use crate::declare::declare_simple_fn;
use crate::errors::LlvmError;
use crate::llvm::AttributePlace::Function;
use crate::llvm::{Metadata, True};
use crate::value::Value;
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, context, llvm};
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};

fn get_params(fnc: &Value) -> Vec<&Value> {
unsafe {
Expand All @@ -38,8 +37,8 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
/// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
fn generate_enzyme_call<'ll, 'tcx>(
cx: &context::CodegenCx<'ll, 'tcx>,
fn generate_enzyme_call<'ll>(
cx: &SimpleCx<'ll>,
fn_to_diff: &'ll Value,
outer_fn: &'ll Value,
attrs: AutoDiffAttrs,
Expand Down Expand Up @@ -112,7 +111,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
//FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
// think a bit more about what should go here.
let cc = llvm::LLVMGetFunctionCallConv(outer_fn);
let ad_fn = declare_raw_fn(
let ad_fn = declare_simple_fn(
cx,
&ad_name,
llvm::CallConv::try_from(cc).expect("invalid callconv"),
Expand All @@ -132,7 +131,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
llvm::LLVMRustEraseInstFromParent(br);

let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap();
let mut builder = Builder::build(cx, entry);
let mut builder = SBuilder::build(cx, entry);

let num_args = llvm::LLVMCountParams(&fn_to_diff);
let mut args = Vec::with_capacity(num_args as usize + 1);
Expand Down Expand Up @@ -236,7 +235,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
}
}

let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
let call = builder.call(enzyme_ty, ad_fn, &args, None);

// This part is a bit iffy. LLVM requires that a call to an inlineable function has some
// metadata attachted to it, but we just created this code oota. Given that the
Expand Down Expand Up @@ -274,10 +273,9 @@ fn generate_enzyme_call<'ll, 'tcx>(
}
}

pub(crate) fn differentiate<'ll, 'tcx>(
pub(crate) fn differentiate<'ll>(
module: &'ll ModuleCodegen<ModuleLlvm>,
cgcx: &CodegenContext<LlvmCodegenBackend>,
tcx: TyCtxt<'tcx>,
diff_items: Vec<AutoDiffItem>,
config: &ModuleConfig,
) -> Result<(), FatalError> {
Expand All @@ -286,8 +284,7 @@ pub(crate) fn differentiate<'ll, 'tcx>(
}

let diag_handler = cgcx.create_dcx();
let (_, cgus) = tcx.collect_and_partition_mono_items(());
let cx = context::CodegenCx::new(tcx, &cgus.first().unwrap(), &module.module_llvm);
let cx = SimpleCx { llmod: module.module_llvm.llmod(), llcx: module.module_llvm.llcx };

// Before dumping the module, we want all the TypeTrees to become part of the module.
for item in diff_items.iter() {
Expand Down
Loading
Loading