Skip to content

Commit cfb9360

Browse files
committed
wip
1 parent a48e7b0 commit cfb9360

File tree

8 files changed

+262
-49
lines changed

8 files changed

+262
-49
lines changed

Diff for: compiler/rustc_codegen_llvm/src/builder.rs

+151-12
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,151 @@ use tracing::{debug, instrument};
3131
use crate::abi::FnAbiLlvmExt;
3232
use crate::attributes;
3333
use crate::common::Funclet;
34-
use crate::context::CodegenCx;
34+
use crate::context::{CodegenCx, SimpleCx};
3535
use crate::llvm::{self, AtomicOrdering, AtomicRmwBinOp, BasicBlock, False, True};
3636
use crate::type_::Type;
3737
use crate::type_of::LayoutLlvmExt;
3838
use crate::value::Value;
3939

40+
// All Builders must have an llfn associated with them
41+
#[must_use]
42+
pub(crate) struct SBuilder<'a, 'll> {
43+
pub llbuilder: &'ll mut llvm::Builder<'ll>,
44+
pub cx: &'a SimpleCx<'ll>,
45+
}
46+
47+
impl Drop for SBuilder<'_, '_> {
48+
fn drop(&mut self) {
49+
unsafe {
50+
llvm::LLVMDisposeBuilder(&mut *(self.llbuilder as *mut _));
51+
}
52+
}
53+
}
54+
55+
impl<'a, 'll> SBuilder<'a, 'll> {
56+
fn build(cx: &'a SimpleCx<'ll>, llbb: &'ll BasicBlock) -> SBuilder<'a, 'll> {
57+
let bx = SBuilder::with_scx(cx);
58+
unsafe {
59+
llvm::LLVMPositionBuilderAtEnd(bx.llbuilder, llbb);
60+
}
61+
bx
62+
}
63+
}
64+
65+
impl<'a, 'll> SBuilder<'a, 'll> {
66+
fn with_scx(scx: &'a SimpleCx<'ll>) -> Self {
67+
// Create a fresh builder from the simple context.
68+
let llbuilder = unsafe { llvm::LLVMCreateBuilderInContext(scx.llcx) };
69+
SBuilder { llbuilder, cx: scx }
70+
}
71+
72+
pub(crate) fn bitcast(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
73+
unsafe { llvm::LLVMBuildBitCast(self.llbuilder, val, dest_ty, UNNAMED) }
74+
}
75+
76+
fn ret_void(&mut self) {
77+
unsafe {
78+
llvm::LLVMBuildRetVoid(self.llbuilder);
79+
}
80+
}
81+
82+
fn ret(&mut self, v: &'ll Value) {
83+
unsafe {
84+
llvm::LLVMBuildRet(self.llbuilder, v);
85+
}
86+
}
87+
88+
fn check_call<'b>(
89+
&mut self,
90+
typ: &str,
91+
fn_ty: &'ll Type,
92+
llfn: &'ll Value,
93+
args: &'b [&'ll Value],
94+
) -> Cow<'b, [&'ll Value]> {
95+
assert!(
96+
self.cx.type_kind(fn_ty) == TypeKind::Function,
97+
"builder::{typ} not passed a function, but {fn_ty:?}"
98+
);
99+
100+
let param_tys = self.cx.func_params_types(fn_ty);
101+
102+
let all_args_match = iter::zip(&param_tys, args.iter().map(|&v| self.cx.val_ty(v)))
103+
.all(|(expected_ty, actual_ty)| *expected_ty == actual_ty);
104+
105+
if all_args_match {
106+
return Cow::Borrowed(args);
107+
}
108+
109+
let casted_args: Vec<_> = iter::zip(param_tys, args)
110+
.enumerate()
111+
.map(|(i, (expected_ty, &actual_val))| {
112+
let actual_ty = self.cx.val_ty(actual_val);
113+
if expected_ty != actual_ty {
114+
debug!(
115+
"type mismatch in function call of {:?}. \
116+
Expected {:?} for param {}, got {:?}; injecting bitcast",
117+
llfn, expected_ty, i, actual_ty
118+
);
119+
self.bitcast(actual_val, expected_ty)
120+
} else {
121+
actual_val
122+
}
123+
})
124+
.collect();
125+
126+
Cow::Owned(casted_args)
127+
}
128+
129+
fn call(
130+
&mut self,
131+
llty: &'ll Type,
132+
//fn_attrs: Option<&CodegenFnAttrs>,
133+
//fn_abi: Option<&FnAbi<'tcx, Ty<'tcx>>>,
134+
llfn: &'ll Value,
135+
args: &[&'ll Value],
136+
funclet: Option<&Funclet<'ll>>,
137+
//instance: Option<Instance<'tcx>>,
138+
) -> &'ll Value {
139+
debug!("call {:?} with args ({:?})", llfn, args);
140+
141+
let args = self.check_call("call", llty, llfn, args);
142+
let funclet_bundle = funclet.map(|funclet| funclet.bundle());
143+
let mut bundles: SmallVec<[_; 2]> = SmallVec::new();
144+
if let Some(funclet_bundle) = funclet_bundle {
145+
bundles.push(funclet_bundle);
146+
}
147+
148+
// Emit CFI pointer type membership test
149+
//self.cfi_type_test(fn_attrs, fn_abi, instance, llfn);
150+
151+
// Emit KCFI operand bundle
152+
//let kcfi_bundle = self.kcfi_operand_bundle(fn_attrs, fn_abi, instance, llfn);
153+
//if let Some(kcfi_bundle) = kcfi_bundle.as_deref() {
154+
// bundles.push(kcfi_bundle);
155+
//}
156+
157+
let call = unsafe {
158+
llvm::LLVMBuildCallWithOperandBundles(
159+
self.llbuilder,
160+
llty,
161+
llfn,
162+
args.as_ptr() as *const &llvm::Value,
163+
args.len() as c_uint,
164+
bundles.as_ptr(),
165+
bundles.len() as c_uint,
166+
c"".as_ptr(),
167+
)
168+
};
169+
//if let Some(fn_abi) = fn_abi {
170+
// fn_abi.apply_attrs_callsite(self, call);
171+
//}
172+
call
173+
}
174+
175+
}
176+
177+
178+
40179
// All Builders must have an llfn associated with them
41180
#[must_use]
42181
pub(crate) struct Builder<'a, 'll, 'tcx> {
@@ -55,7 +194,7 @@ impl Drop for Builder<'_, '_, '_> {
55194
/// Empty string, to be used where LLVM expects an instruction name, indicating
56195
/// that the instruction is to be left unnamed (i.e. numbered, in textual IR).
57196
// FIXME(eddyb) pass `&CStr` directly to FFI once it's a thin pointer.
58-
const UNNAMED: *const c_char = c"".as_ptr();
197+
pub(crate) const UNNAMED: *const c_char = c"".as_ptr();
59198

60199
impl<'ll, 'tcx> BackendTypes for Builder<'_, 'll, 'tcx> {
61200
type Value = <CodegenCx<'ll, 'tcx> as BackendTypes>::Value;
@@ -170,7 +309,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
170309
fn append_block(cx: &'a CodegenCx<'ll, 'tcx>, llfn: &'ll Value, name: &str) -> &'ll BasicBlock {
171310
unsafe {
172311
let name = SmallCStr::new(name);
173-
llvm::LLVMAppendBasicBlockInContext(cx.llcx, llfn, name.as_ptr())
312+
llvm::LLVMAppendBasicBlockInContext(cx.scx.llcx, llfn, name.as_ptr())
174313
}
175314
}
176315

@@ -621,14 +760,14 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
621760
llvm::LLVMValueAsMetadata(self.cx.const_uint_big(llty, range.start)),
622761
llvm::LLVMValueAsMetadata(self.cx.const_uint_big(llty, range.end.wrapping_add(1))),
623762
];
624-
let md = llvm::LLVMMDNodeInContext2(self.cx.llcx, md.as_ptr(), md.len());
763+
let md = llvm::LLVMMDNodeInContext2(self.cx.scx.llcx, md.as_ptr(), md.len());
625764
self.set_metadata(load, llvm::MD_range, md);
626765
}
627766
}
628767

629768
fn nonnull_metadata(&mut self, load: &'ll Value) {
630769
unsafe {
631-
let md = llvm::LLVMMDNodeInContext2(self.cx.llcx, ptr::null(), 0);
770+
let md = llvm::LLVMMDNodeInContext2(self.cx.scx.llcx, ptr::null(), 0);
632771
self.set_metadata(load, llvm::MD_nonnull, md);
633772
}
634773
}
@@ -678,7 +817,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
678817
//
679818
// [1]: https://llvm.org/docs/LangRef.html#store-instruction
680819
let one = llvm::LLVMValueAsMetadata(self.cx.const_i32(1));
681-
let md = llvm::LLVMMDNodeInContext2(self.cx.llcx, &one, 1);
820+
let md = llvm::LLVMMDNodeInContext2(self.cx.scx.llcx, &one, 1);
682821
self.set_metadata(store, llvm::MD_nontemporal, md);
683822
}
684823
}
@@ -1144,7 +1283,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
11441283

11451284
fn set_invariant_load(&mut self, load: &'ll Value) {
11461285
unsafe {
1147-
let md = llvm::LLVMMDNodeInContext2(self.cx.llcx, ptr::null(), 0);
1286+
let md = llvm::LLVMMDNodeInContext2(self.cx.scx.llcx, ptr::null(), 0);
11481287
self.set_metadata(load, llvm::MD_invariant_load, md);
11491288
}
11501289
}
@@ -1209,7 +1348,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
12091348

12101349
fn apply_attrs_to_cleanup_callsite(&mut self, llret: &'ll Value) {
12111350
// Cleanup is always the cold path.
1212-
let cold_inline = llvm::AttributeKind::Cold.create_attr(self.llcx);
1351+
let cold_inline = llvm::AttributeKind::Cold.create_attr(self.scx.llcx);
12131352
attributes::apply_to_callsite(llret, llvm::AttributePlace::Function, &[cold_inline]);
12141353
}
12151354
}
@@ -1224,7 +1363,7 @@ impl<'ll> StaticBuilderMethods for Builder<'_, 'll, '_> {
12241363
impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
12251364
fn with_cx(cx: &'a CodegenCx<'ll, 'tcx>) -> Self {
12261365
// Create a fresh builder from the crate context.
1227-
let llbuilder = unsafe { llvm::LLVMCreateBuilderInContext(cx.llcx) };
1366+
let llbuilder = unsafe { llvm::LLVMCreateBuilderInContext(cx.scx.llcx) };
12281367
Builder { llbuilder, cx }
12291368
}
12301369

@@ -1241,21 +1380,21 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
12411380
fn align_metadata(&mut self, load: &'ll Value, align: Align) {
12421381
unsafe {
12431382
let md = [llvm::LLVMValueAsMetadata(self.cx.const_u64(align.bytes()))];
1244-
let md = llvm::LLVMMDNodeInContext2(self.cx.llcx, md.as_ptr(), md.len());
1383+
let md = llvm::LLVMMDNodeInContext2(self.cx.scx.llcx, md.as_ptr(), md.len());
12451384
self.set_metadata(load, llvm::MD_align, md);
12461385
}
12471386
}
12481387

12491388
fn noundef_metadata(&mut self, load: &'ll Value) {
12501389
unsafe {
1251-
let md = llvm::LLVMMDNodeInContext2(self.cx.llcx, ptr::null(), 0);
1390+
let md = llvm::LLVMMDNodeInContext2(self.cx.scx.llcx, ptr::null(), 0);
12521391
self.set_metadata(load, llvm::MD_noundef, md);
12531392
}
12541393
}
12551394

12561395
pub(crate) fn set_unpredictable(&mut self, inst: &'ll Value) {
12571396
unsafe {
1258-
let md = llvm::LLVMMDNodeInContext2(self.cx.llcx, ptr::null(), 0);
1397+
let md = llvm::LLVMMDNodeInContext2(self.cx.scx.llcx, ptr::null(), 0);
12591398
self.set_metadata(inst, llvm::MD_unpredictable, md);
12601399
}
12611400
}

Diff for: compiler/rustc_codegen_llvm/src/builder/autodiff.rs

+19-14
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,22 @@ use std::ptr;
33
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
44
use rustc_codegen_ssa::ModuleCodegen;
55
use rustc_codegen_ssa::back::write::ModuleConfig;
6-
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
6+
//use rustc_codegen_ssa::traits::BuilderMethods;
77
use rustc_errors::FatalError;
8-
use rustc_middle::ty::TyCtxt;
8+
//use rustc_middle::ty::TyCtxt;
99
use rustc_session::config::Lto;
1010
use tracing::{debug, trace};
1111

1212
use crate::back::write::{llvm_err, llvm_optimize};
13-
use crate::builder::Builder;
14-
use crate::declare::declare_raw_fn;
13+
use crate::builder::SBuilder;
14+
//use crate::builder::{Builder, SBuilder};
15+
use crate::context::SimpleCx;
16+
use crate::declare::declare_simple_fn;
1517
use crate::errors::LlvmError;
1618
use crate::llvm::AttributePlace::Function;
1719
use crate::llvm::{Metadata, True};
1820
use crate::value::Value;
19-
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, context, llvm};
21+
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};
2022

2123
fn get_params(fnc: &Value) -> Vec<&Value> {
2224
unsafe {
@@ -38,8 +40,8 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
3840
/// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
3941
// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
4042
// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
41-
fn generate_enzyme_call<'ll, 'tcx>(
42-
cx: &context::CodegenCx<'ll, 'tcx>,
43+
fn generate_enzyme_call<'ll>(
44+
cx: &SimpleCx<'ll>,
4345
fn_to_diff: &'ll Value,
4446
outer_fn: &'ll Value,
4547
attrs: AutoDiffAttrs,
@@ -112,7 +114,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
112114
//FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
113115
// think a bit more about what should go here.
114116
let cc = llvm::LLVMGetFunctionCallConv(outer_fn);
115-
let ad_fn = declare_raw_fn(
117+
let ad_fn = declare_simple_fn(
116118
cx,
117119
&ad_name,
118120
llvm::CallConv::try_from(cc).expect("invalid callconv"),
@@ -132,7 +134,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
132134
llvm::LLVMRustEraseInstFromParent(br);
133135

134136
let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap();
135-
let mut builder = Builder::build(cx, entry);
137+
let mut builder = SBuilder::build(cx, entry);
136138

137139
let num_args = llvm::LLVMCountParams(&fn_to_diff);
138140
let mut args = Vec::with_capacity(num_args as usize + 1);
@@ -236,7 +238,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
236238
}
237239
}
238240

239-
let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
241+
let call = builder.call(enzyme_ty, ad_fn, &args, None);
240242

241243
// This part is a bit iffy. LLVM requires that a call to an inlineable function has some
242244
// metadata attachted to it, but we just created this code oota. Given that the
@@ -274,20 +276,23 @@ fn generate_enzyme_call<'ll, 'tcx>(
274276
}
275277
}
276278

277-
pub(crate) fn differentiate<'ll, 'tcx>(
279+
pub(crate) fn differentiate<'ll>(
278280
module: &'ll ModuleCodegen<ModuleLlvm>,
279281
cgcx: &CodegenContext<LlvmCodegenBackend>,
280-
tcx: TyCtxt<'tcx>,
282+
//cx: SimpleCx<'ll>,
281283
diff_items: Vec<AutoDiffItem>,
282284
config: &ModuleConfig,
283285
) -> Result<(), FatalError> {
284286
for item in &diff_items {
285287
trace!("{}", item);
286288
}
287289

290+
let cx = SimpleCx { llmod: module.module_llvm.llmod(), llcx: module.module_llvm.llcx };
291+
288292
let diag_handler = cgcx.create_dcx();
289-
let (_, cgus) = tcx.collect_and_partition_mono_items(());
290-
let cx = context::CodegenCx::new(tcx, &cgus.first().unwrap(), &module.module_llvm);
293+
//let cx = context::SimpleCx { llmod: module.module_llvm.llmod(), llcx: module.module_llvm.llcx };
294+
//let (_, cgus) = tcx.collect_and_partition_mono_items(());
295+
//let cx = context::CodegenCx::new(tcx, &cgus.first().unwrap(), &module.module_llvm);
291296

292297
// Before dumping the module, we want all the TypeTrees to become part of the module.
293298
for item in diff_items.iter() {

0 commit comments

Comments
 (0)