Skip to content

Commit 1e454fe

Browse files
authored
Rollup merge of #135581 - EnzymeAD:refactor-codgencx, r=oli-obk
Separate Builder methods from tcx As part of the autodiff upstreaming we noticed, that it would be nice to have various builder methods available without the TypeContext, which prevents the normal CodegenCx to be passed around between threads. We introduce a SimpleCx which just owns the llvm module and llvm context, to encapsulate them. The previous CodegenCx now implements deref and forwards access to the llvm module or context to it's SimpleCx sub-struct. This gives us a bit more flexibility, because now we can pass (or construct) the SimpleCx in locations where we don't have enough information to construct a CodegenCx, or are not able to pass it around due to the tcx lifetimes (and it not implementing send/sync). This also introduces an SBuilder, similar to the SimpleCx. The SBuilder uses a SimpleCx, whereas the existing Builder uses the larger CodegenCx. I will push updates to make implementations generic (where possible) to be implemented once and work for either of the two. I'll also clean up the leftover code. `call` is a bit tricky, because it requires a tcx, I probably need to duplicate it after all. Tracking: - #124509
2 parents 0741cc0 + 386c233 commit 1e454fe

File tree

10 files changed

+239
-56
lines changed

10 files changed

+239
-56
lines changed

Diff for: compiler/rustc_codegen_gcc/src/lib.rs

-1
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,6 @@ impl WriteBackendMethods for GccCodegenBackend {
444444
}
445445
fn autodiff(
446446
_cgcx: &CodegenContext<Self>,
447-
_tcx: TyCtxt<'_>,
448447
_module: &ModuleCodegen<Self::Module>,
449448
_diff_fncs: Vec<AutoDiffItem>,
450449
_config: &ModuleConfig,

Diff for: compiler/rustc_codegen_llvm/src/builder.rs

+138-12
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::borrow::Cow;
1+
use std::borrow::{Borrow, Cow};
22
use std::ops::Deref;
33
use std::{iter, ptr};
44

@@ -31,27 +31,135 @@ use tracing::{debug, instrument};
3131
use crate::abi::FnAbiLlvmExt;
3232
use crate::attributes;
3333
use crate::common::Funclet;
34-
use crate::context::CodegenCx;
34+
use crate::context::{CodegenCx, SimpleCx};
3535
use crate::llvm::{self, AtomicOrdering, AtomicRmwBinOp, BasicBlock, False, True};
3636
use crate::type_::Type;
3737
use crate::type_of::LayoutLlvmExt;
3838
use crate::value::Value;
3939

40-
// All Builders must have an llfn associated with them
4140
#[must_use]
42-
pub(crate) struct Builder<'a, 'll, 'tcx> {
41+
pub(crate) struct GenericBuilder<'a, 'll, CX: Borrow<SimpleCx<'ll>>> {
4342
pub llbuilder: &'ll mut llvm::Builder<'ll>,
44-
pub cx: &'a CodegenCx<'ll, 'tcx>,
43+
pub cx: &'a CX,
4544
}
4645

47-
impl Drop for Builder<'_, '_, '_> {
46+
pub(crate) type SBuilder<'a, 'll> = GenericBuilder<'a, 'll, SimpleCx<'ll>>;
47+
pub(crate) type Builder<'a, 'll, 'tcx> = GenericBuilder<'a, 'll, CodegenCx<'ll, 'tcx>>;
48+
49+
impl<'a, 'll, CX: Borrow<SimpleCx<'ll>>> Drop for GenericBuilder<'a, 'll, CX> {
4850
fn drop(&mut self) {
4951
unsafe {
5052
llvm::LLVMDisposeBuilder(&mut *(self.llbuilder as *mut _));
5153
}
5254
}
5355
}
5456

57+
impl<'a, 'll> SBuilder<'a, 'll> {
58+
fn call(
59+
&mut self,
60+
llty: &'ll Type,
61+
llfn: &'ll Value,
62+
args: &[&'ll Value],
63+
funclet: Option<&Funclet<'ll>>,
64+
) -> &'ll Value {
65+
debug!("call {:?} with args ({:?})", llfn, args);
66+
67+
let args = self.check_call("call", llty, llfn, args);
68+
let funclet_bundle = funclet.map(|funclet| funclet.bundle());
69+
let mut bundles: SmallVec<[_; 2]> = SmallVec::new();
70+
if let Some(funclet_bundle) = funclet_bundle {
71+
bundles.push(funclet_bundle);
72+
}
73+
74+
let call = unsafe {
75+
llvm::LLVMBuildCallWithOperandBundles(
76+
self.llbuilder,
77+
llty,
78+
llfn,
79+
args.as_ptr() as *const &llvm::Value,
80+
args.len() as c_uint,
81+
bundles.as_ptr(),
82+
bundles.len() as c_uint,
83+
c"".as_ptr(),
84+
)
85+
};
86+
call
87+
}
88+
89+
fn with_scx(scx: &'a SimpleCx<'ll>) -> Self {
90+
// Create a fresh builder from the simple context.
91+
let llbuilder = unsafe { llvm::LLVMCreateBuilderInContext(scx.llcx) };
92+
SBuilder { llbuilder, cx: scx }
93+
}
94+
}
95+
impl<'a, 'll, CX: Borrow<SimpleCx<'ll>>> GenericBuilder<'a, 'll, CX> {
96+
pub(crate) fn bitcast(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
97+
unsafe { llvm::LLVMBuildBitCast(self.llbuilder, val, dest_ty, UNNAMED) }
98+
}
99+
100+
fn ret_void(&mut self) {
101+
unsafe {
102+
llvm::LLVMBuildRetVoid(self.llbuilder);
103+
}
104+
}
105+
106+
fn ret(&mut self, v: &'ll Value) {
107+
unsafe {
108+
llvm::LLVMBuildRet(self.llbuilder, v);
109+
}
110+
}
111+
}
112+
impl<'a, 'll> SBuilder<'a, 'll> {
113+
fn build(cx: &'a SimpleCx<'ll>, llbb: &'ll BasicBlock) -> SBuilder<'a, 'll> {
114+
let bx = SBuilder::with_scx(cx);
115+
unsafe {
116+
llvm::LLVMPositionBuilderAtEnd(bx.llbuilder, llbb);
117+
}
118+
bx
119+
}
120+
121+
fn check_call<'b>(
122+
&mut self,
123+
typ: &str,
124+
fn_ty: &'ll Type,
125+
llfn: &'ll Value,
126+
args: &'b [&'ll Value],
127+
) -> Cow<'b, [&'ll Value]> {
128+
assert!(
129+
self.cx.type_kind(fn_ty) == TypeKind::Function,
130+
"builder::{typ} not passed a function, but {fn_ty:?}"
131+
);
132+
133+
let param_tys = self.cx.func_params_types(fn_ty);
134+
135+
let all_args_match = iter::zip(&param_tys, args.iter().map(|&v| self.cx.val_ty(v)))
136+
.all(|(expected_ty, actual_ty)| *expected_ty == actual_ty);
137+
138+
if all_args_match {
139+
return Cow::Borrowed(args);
140+
}
141+
142+
let casted_args: Vec<_> = iter::zip(param_tys, args)
143+
.enumerate()
144+
.map(|(i, (expected_ty, &actual_val))| {
145+
let actual_ty = self.cx.val_ty(actual_val);
146+
if expected_ty != actual_ty {
147+
debug!(
148+
"type mismatch in function call of {:?}. \
149+
Expected {:?} for param {}, got {:?}; injecting bitcast",
150+
llfn, expected_ty, i, actual_ty
151+
);
152+
self.bitcast(actual_val, expected_ty)
153+
} else {
154+
actual_val
155+
}
156+
})
157+
.collect();
158+
159+
Cow::Owned(casted_args)
160+
}
161+
}
162+
55163
/// Empty string, to be used where LLVM expects an instruction name, indicating
56164
/// that the instruction is to be left unnamed (i.e. numbered, in textual IR).
57165
// FIXME(eddyb) pass `&CStr` directly to FFI once it's a thin pointer.
@@ -1222,6 +1330,14 @@ impl<'ll> StaticBuilderMethods for Builder<'_, 'll, '_> {
12221330
}
12231331

12241332
impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
1333+
fn build(cx: &'a CodegenCx<'ll, 'tcx>, llbb: &'ll BasicBlock) -> Builder<'a, 'll, 'tcx> {
1334+
let bx = Builder::with_cx(cx);
1335+
unsafe {
1336+
llvm::LLVMPositionBuilderAtEnd(bx.llbuilder, llbb);
1337+
}
1338+
bx
1339+
}
1340+
12251341
fn with_cx(cx: &'a CodegenCx<'ll, 'tcx>) -> Self {
12261342
// Create a fresh builder from the crate context.
12271343
let llbuilder = unsafe { llvm::LLVMCreateBuilderInContext(cx.llcx) };
@@ -1231,13 +1347,16 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
12311347
pub(crate) fn llfn(&self) -> &'ll Value {
12321348
unsafe { llvm::LLVMGetBasicBlockParent(self.llbb()) }
12331349
}
1350+
}
12341351

1352+
impl<'a, 'll, CX: Borrow<SimpleCx<'ll>>> GenericBuilder<'a, 'll, CX> {
12351353
fn position_at_start(&mut self, llbb: &'ll BasicBlock) {
12361354
unsafe {
12371355
llvm::LLVMRustPositionBuilderAtStart(self.llbuilder, llbb);
12381356
}
12391357
}
1240-
1358+
}
1359+
impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
12411360
fn align_metadata(&mut self, load: &'ll Value, align: Align) {
12421361
unsafe {
12431362
let md = [llvm::LLVMValueAsMetadata(self.cx.const_u64(align.bytes()))];
@@ -1259,7 +1378,8 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
12591378
self.set_metadata(inst, llvm::MD_unpredictable, md);
12601379
}
12611380
}
1262-
1381+
}
1382+
impl<'a, 'll, CX: Borrow<SimpleCx<'ll>>> GenericBuilder<'a, 'll, CX> {
12631383
pub(crate) fn minnum(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value {
12641384
unsafe { llvm::LLVMRustBuildMinNum(self.llbuilder, lhs, rhs) }
12651385
}
@@ -1360,7 +1480,9 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
13601480
let ret = unsafe { llvm::LLVMBuildCatchRet(self.llbuilder, funclet.cleanuppad(), unwind) };
13611481
ret.expect("LLVM does not have support for catchret")
13621482
}
1483+
}
13631484

1485+
impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
13641486
fn check_call<'b>(
13651487
&mut self,
13661488
typ: &str,
@@ -1401,11 +1523,13 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
14011523

14021524
Cow::Owned(casted_args)
14031525
}
1404-
1526+
}
1527+
impl<'a, 'll, CX: Borrow<SimpleCx<'ll>>> GenericBuilder<'a, 'll, CX> {
14051528
pub(crate) fn va_arg(&mut self, list: &'ll Value, ty: &'ll Type) -> &'ll Value {
14061529
unsafe { llvm::LLVMBuildVAArg(self.llbuilder, list, ty, UNNAMED) }
14071530
}
1408-
1531+
}
1532+
impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
14091533
pub(crate) fn call_intrinsic(&mut self, intrinsic: &str, args: &[&'ll Value]) -> &'ll Value {
14101534
let (ty, f) = self.cx.get_intrinsic(intrinsic);
14111535
self.call(ty, None, None, f, args, None, None)
@@ -1423,7 +1547,8 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
14231547

14241548
self.call_intrinsic(intrinsic, &[self.cx.const_u64(size), ptr]);
14251549
}
1426-
1550+
}
1551+
impl<'a, 'll, CX: Borrow<SimpleCx<'ll>>> GenericBuilder<'a, 'll, CX> {
14271552
pub(crate) fn phi(
14281553
&mut self,
14291554
ty: &'ll Type,
@@ -1443,7 +1568,8 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
14431568
llvm::LLVMAddIncoming(phi, &val, &bb, 1 as c_uint);
14441569
}
14451570
}
1446-
1571+
}
1572+
impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
14471573
fn fptoint_sat(&mut self, signed: bool, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
14481574
let src_ty = self.cx.val_ty(val);
14491575
let (float_ty, int_ty, vector_length) = if self.cx.type_kind(src_ty) == TypeKind::Vector {

Diff for: compiler/rustc_codegen_llvm/src/builder/autodiff.rs

+11-14
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,19 @@ use std::ptr;
33
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
44
use rustc_codegen_ssa::ModuleCodegen;
55
use rustc_codegen_ssa::back::write::ModuleConfig;
6-
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
76
use rustc_errors::FatalError;
8-
use rustc_middle::ty::TyCtxt;
97
use rustc_session::config::Lto;
108
use tracing::{debug, trace};
119

1210
use crate::back::write::{llvm_err, llvm_optimize};
13-
use crate::builder::Builder;
14-
use crate::declare::declare_raw_fn;
11+
use crate::builder::SBuilder;
12+
use crate::context::SimpleCx;
13+
use crate::declare::declare_simple_fn;
1514
use crate::errors::LlvmError;
1615
use crate::llvm::AttributePlace::Function;
1716
use crate::llvm::{Metadata, True};
1817
use crate::value::Value;
19-
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, context, llvm};
18+
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};
2019

2120
fn get_params(fnc: &Value) -> Vec<&Value> {
2221
unsafe {
@@ -38,8 +37,8 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
3837
/// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
3938
// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
4039
// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
41-
fn generate_enzyme_call<'ll, 'tcx>(
42-
cx: &context::CodegenCx<'ll, 'tcx>,
40+
fn generate_enzyme_call<'ll>(
41+
cx: &SimpleCx<'ll>,
4342
fn_to_diff: &'ll Value,
4443
outer_fn: &'ll Value,
4544
attrs: AutoDiffAttrs,
@@ -112,7 +111,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
112111
//FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
113112
// think a bit more about what should go here.
114113
let cc = llvm::LLVMGetFunctionCallConv(outer_fn);
115-
let ad_fn = declare_raw_fn(
114+
let ad_fn = declare_simple_fn(
116115
cx,
117116
&ad_name,
118117
llvm::CallConv::try_from(cc).expect("invalid callconv"),
@@ -132,7 +131,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
132131
llvm::LLVMRustEraseInstFromParent(br);
133132

134133
let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap();
135-
let mut builder = Builder::build(cx, entry);
134+
let mut builder = SBuilder::build(cx, entry);
136135

137136
let num_args = llvm::LLVMCountParams(&fn_to_diff);
138137
let mut args = Vec::with_capacity(num_args as usize + 1);
@@ -236,7 +235,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
236235
}
237236
}
238237

239-
let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
238+
let call = builder.call(enzyme_ty, ad_fn, &args, None);
240239

241240
// This part is a bit iffy. LLVM requires that a call to an inlineable function has some
242241
// metadata attachted to it, but we just created this code oota. Given that the
@@ -274,10 +273,9 @@ fn generate_enzyme_call<'ll, 'tcx>(
274273
}
275274
}
276275

277-
pub(crate) fn differentiate<'ll, 'tcx>(
276+
pub(crate) fn differentiate<'ll>(
278277
module: &'ll ModuleCodegen<ModuleLlvm>,
279278
cgcx: &CodegenContext<LlvmCodegenBackend>,
280-
tcx: TyCtxt<'tcx>,
281279
diff_items: Vec<AutoDiffItem>,
282280
config: &ModuleConfig,
283281
) -> Result<(), FatalError> {
@@ -286,8 +284,7 @@ pub(crate) fn differentiate<'ll, 'tcx>(
286284
}
287285

288286
let diag_handler = cgcx.create_dcx();
289-
let (_, cgus) = tcx.collect_and_partition_mono_items(());
290-
let cx = context::CodegenCx::new(tcx, &cgus.first().unwrap(), &module.module_llvm);
287+
let cx = SimpleCx { llmod: module.module_llvm.llmod(), llcx: module.module_llvm.llcx };
291288

292289
// Before dumping the module, we want all the TypeTrees to become part of the module.
293290
for item in diff_items.iter() {

0 commit comments

Comments
 (0)