Skip to content

Commit ae648ea

Browse files
committed
Introduce SimpleCx and SBuilder as alternatives without 'tcx
1 parent a48e7b0 commit ae648ea

File tree

9 files changed

+218
-32
lines changed

9 files changed

+218
-32
lines changed

Diff for: compiler/rustc_codegen_gcc/src/lib.rs

-1
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,6 @@ impl WriteBackendMethods for GccCodegenBackend {
442442
}
443443
fn autodiff(
444444
_cgcx: &CodegenContext<Self>,
445-
_tcx: TyCtxt<'_>,
446445
_module: &ModuleCodegen<Self::Module>,
447446
_diff_fncs: Vec<AutoDiffItem>,
448447
_config: &ModuleConfig,

Diff for: compiler/rustc_codegen_llvm/src/builder.rs

+127-1
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,138 @@ use tracing::{debug, instrument};
3131
use crate::abi::FnAbiLlvmExt;
3232
use crate::attributes;
3333
use crate::common::Funclet;
34-
use crate::context::CodegenCx;
34+
use crate::context::{CodegenCx, SimpleCx};
3535
use crate::llvm::{self, AtomicOrdering, AtomicRmwBinOp, BasicBlock, False, True};
3636
use crate::type_::Type;
3737
use crate::type_of::LayoutLlvmExt;
3838
use crate::value::Value;
3939

40+
// All Builders must have an llfn associated with them
41+
#[must_use]
42+
pub(crate) struct SBuilder<'a, 'll> {
43+
pub llbuilder: &'ll mut llvm::Builder<'ll>,
44+
pub cx: &'a SimpleCx<'ll>,
45+
}
46+
47+
impl Drop for SBuilder<'_, '_> {
48+
fn drop(&mut self) {
49+
unsafe {
50+
llvm::LLVMDisposeBuilder(&mut *(self.llbuilder as *mut _));
51+
}
52+
}
53+
}
54+
55+
impl<'a, 'll> SBuilder<'a, 'll> {
56+
fn build(cx: &'a SimpleCx<'ll>, llbb: &'ll BasicBlock) -> SBuilder<'a, 'll> {
57+
let bx = SBuilder::with_scx(cx);
58+
unsafe {
59+
llvm::LLVMPositionBuilderAtEnd(bx.llbuilder, llbb);
60+
}
61+
bx
62+
}
63+
}
64+
65+
impl<'a, 'll> SBuilder<'a, 'll> {
66+
fn with_scx(scx: &'a SimpleCx<'ll>) -> Self {
67+
// Create a fresh builder from the simple context.
68+
let llbuilder = unsafe { llvm::LLVMCreateBuilderInContext(scx.llcx) };
69+
SBuilder { llbuilder, cx: scx }
70+
}
71+
72+
pub(crate) fn bitcast(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
73+
unsafe { llvm::LLVMBuildBitCast(self.llbuilder, val, dest_ty, UNNAMED) }
74+
}
75+
76+
fn ret_void(&mut self) {
77+
unsafe {
78+
llvm::LLVMBuildRetVoid(self.llbuilder);
79+
}
80+
}
81+
82+
fn ret(&mut self, v: &'ll Value) {
83+
unsafe {
84+
llvm::LLVMBuildRet(self.llbuilder, v);
85+
}
86+
}
87+
88+
fn check_call<'b>(
89+
&mut self,
90+
typ: &str,
91+
fn_ty: &'ll Type,
92+
llfn: &'ll Value,
93+
args: &'b [&'ll Value],
94+
) -> Cow<'b, [&'ll Value]> {
95+
assert!(
96+
self.cx.type_kind(fn_ty) == TypeKind::Function,
97+
"builder::{typ} not passed a function, but {fn_ty:?}"
98+
);
99+
100+
let param_tys = self.cx.func_params_types(fn_ty);
101+
102+
let all_args_match = iter::zip(&param_tys, args.iter().map(|&v| self.cx.val_ty(v)))
103+
.all(|(expected_ty, actual_ty)| *expected_ty == actual_ty);
104+
105+
if all_args_match {
106+
return Cow::Borrowed(args);
107+
}
108+
109+
let casted_args: Vec<_> = iter::zip(param_tys, args)
110+
.enumerate()
111+
.map(|(i, (expected_ty, &actual_val))| {
112+
let actual_ty = self.cx.val_ty(actual_val);
113+
if expected_ty != actual_ty {
114+
debug!(
115+
"type mismatch in function call of {:?}. \
116+
Expected {:?} for param {}, got {:?}; injecting bitcast",
117+
llfn, expected_ty, i, actual_ty
118+
);
119+
self.bitcast(actual_val, expected_ty)
120+
} else {
121+
actual_val
122+
}
123+
})
124+
.collect();
125+
126+
Cow::Owned(casted_args)
127+
}
128+
129+
// This is a simplified version of the call when using the full Builder.
130+
// It can not use any tcx related arguments
131+
fn call(
132+
&mut self,
133+
llty: &'ll Type,
134+
llfn: &'ll Value,
135+
args: &[&'ll Value],
136+
funclet: Option<&Funclet<'ll>>,
137+
) -> &'ll Value {
138+
debug!("call {:?} with args ({:?})", llfn, args);
139+
140+
let args = self.check_call("call", llty, llfn, args);
141+
let funclet_bundle = funclet.map(|funclet| funclet.bundle());
142+
let mut bundles: SmallVec<[_; 2]> = SmallVec::new();
143+
if let Some(funclet_bundle) = funclet_bundle {
144+
bundles.push(funclet_bundle);
145+
}
146+
147+
let call = unsafe {
148+
llvm::LLVMBuildCallWithOperandBundles(
149+
self.llbuilder,
150+
llty,
151+
llfn,
152+
args.as_ptr() as *const &llvm::Value,
153+
args.len() as c_uint,
154+
bundles.as_ptr(),
155+
bundles.len() as c_uint,
156+
c"".as_ptr(),
157+
)
158+
};
159+
call
160+
}
161+
162+
}
163+
164+
165+
40166
// All Builders must have an llfn associated with them
41167
#[must_use]
42168
pub(crate) struct Builder<'a, 'll, 'tcx> {

Diff for: compiler/rustc_codegen_llvm/src/builder/autodiff.rs

+11-14
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,19 @@ use std::ptr;
33
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
44
use rustc_codegen_ssa::ModuleCodegen;
55
use rustc_codegen_ssa::back::write::ModuleConfig;
6-
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
76
use rustc_errors::FatalError;
8-
use rustc_middle::ty::TyCtxt;
97
use rustc_session::config::Lto;
108
use tracing::{debug, trace};
119

1210
use crate::back::write::{llvm_err, llvm_optimize};
13-
use crate::builder::Builder;
14-
use crate::declare::declare_raw_fn;
11+
use crate::builder::SBuilder;
12+
use crate::context::SimpleCx;
13+
use crate::declare::declare_simple_fn;
1514
use crate::errors::LlvmError;
1615
use crate::llvm::AttributePlace::Function;
1716
use crate::llvm::{Metadata, True};
1817
use crate::value::Value;
19-
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, context, llvm};
18+
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};
2019

2120
fn get_params(fnc: &Value) -> Vec<&Value> {
2221
unsafe {
@@ -38,8 +37,8 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
3837
/// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
3938
// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
4039
// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
41-
fn generate_enzyme_call<'ll, 'tcx>(
42-
cx: &context::CodegenCx<'ll, 'tcx>,
40+
fn generate_enzyme_call<'ll>(
41+
cx: &SimpleCx<'ll>,
4342
fn_to_diff: &'ll Value,
4443
outer_fn: &'ll Value,
4544
attrs: AutoDiffAttrs,
@@ -112,7 +111,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
112111
//FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
113112
// think a bit more about what should go here.
114113
let cc = llvm::LLVMGetFunctionCallConv(outer_fn);
115-
let ad_fn = declare_raw_fn(
114+
let ad_fn = declare_simple_fn(
116115
cx,
117116
&ad_name,
118117
llvm::CallConv::try_from(cc).expect("invalid callconv"),
@@ -132,7 +131,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
132131
llvm::LLVMRustEraseInstFromParent(br);
133132

134133
let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap();
135-
let mut builder = Builder::build(cx, entry);
134+
let mut builder = SBuilder::build(cx, entry);
136135

137136
let num_args = llvm::LLVMCountParams(&fn_to_diff);
138137
let mut args = Vec::with_capacity(num_args as usize + 1);
@@ -236,7 +235,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
236235
}
237236
}
238237

239-
let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
238+
let call = builder.call(enzyme_ty, ad_fn, &args, None);
240239

241240
// This part is a bit iffy. LLVM requires that a call to an inlineable function has some
242241
// metadata attachted to it, but we just created this code oota. Given that the
@@ -274,10 +273,9 @@ fn generate_enzyme_call<'ll, 'tcx>(
274273
}
275274
}
276275

277-
pub(crate) fn differentiate<'ll, 'tcx>(
276+
pub(crate) fn differentiate<'ll>(
278277
module: &'ll ModuleCodegen<ModuleLlvm>,
279278
cgcx: &CodegenContext<LlvmCodegenBackend>,
280-
tcx: TyCtxt<'tcx>,
281279
diff_items: Vec<AutoDiffItem>,
282280
config: &ModuleConfig,
283281
) -> Result<(), FatalError> {
@@ -286,8 +284,7 @@ pub(crate) fn differentiate<'ll, 'tcx>(
286284
}
287285

288286
let diag_handler = cgcx.create_dcx();
289-
let (_, cgus) = tcx.collect_and_partition_mono_items(());
290-
let cx = context::CodegenCx::new(tcx, &cgus.first().unwrap(), &module.module_llvm);
287+
let cx = SimpleCx { llmod: module.module_llvm.llmod(), llcx: module.module_llvm.llcx };
291288

292289
// Before dumping the module, we want all the TypeTrees to become part of the module.
293290
for item in diff_items.iter() {

Diff for: compiler/rustc_codegen_llvm/src/context.rs

+45-5
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::str;
66
use rustc_abi::{HasDataLayout, TargetDataLayout, VariantIdx};
77
use rustc_codegen_ssa::back::versioned_llvm_target;
88
use rustc_codegen_ssa::base::{wants_msvc_seh, wants_wasm_eh};
9+
use rustc_codegen_ssa::common::TypeKind;
910
use rustc_codegen_ssa::errors as ssa_errors;
1011
use rustc_codegen_ssa::traits::*;
1112
use rustc_data_structures::base_n::{ALPHANUMERIC_ONLY, ToBaseN};
@@ -30,23 +31,42 @@ use smallvec::SmallVec;
3031

3132
use crate::back::write::to_llvm_code_model;
3233
use crate::callee::get_fn;
33-
use crate::common::AsCCharPtr;
34+
use crate::common::{self, AsCCharPtr};
3435
use crate::debuginfo::metadata::apply_vcall_visibility_metadata;
3536
use crate::llvm::{Metadata, MetadataType};
3637
use crate::type_::Type;
3738
use crate::value::Value;
3839
use crate::{attributes, coverageinfo, debuginfo, llvm, llvm_util};
3940

41+
use std::ops::Deref;
42+
// A tcx (and related cache datastructures) can't be move between threads.
43+
// However, there are various cx related functions which we want to be available to the builder and
44+
// other compiler pieces. Here we define a small subset which has enough information and can be
45+
// moved around more freely.
46+
pub(crate) struct SimpleCx<'ll> {
47+
pub llmod: &'ll llvm::Module,
48+
pub llcx: &'ll llvm::Context,
49+
}
50+
51+
impl<'ll, 'tcx> Deref for CodegenCx<'ll, 'tcx> {
52+
type Target = SimpleCx<'ll>;
53+
54+
#[inline]
55+
fn deref(&self) -> &Self::Target {
56+
&self.scx
57+
}
58+
}
59+
60+
4061
/// There is one `CodegenCx` per codegen unit. Each one has its own LLVM
4162
/// `llvm::Context` so that several codegen units may be processed in parallel.
4263
/// All other LLVM data structures in the `CodegenCx` are tied to that `llvm::Context`.
4364
pub(crate) struct CodegenCx<'ll, 'tcx> {
4465
pub tcx: TyCtxt<'tcx>,
66+
pub scx: SimpleCx<'ll>,
4567
pub use_dll_storage_attrs: bool,
4668
pub tls_model: llvm::ThreadLocalMode,
4769

48-
pub llmod: &'ll llvm::Module,
49-
pub llcx: &'ll llvm::Context,
5070
pub codegen_unit: &'tcx CodegenUnit<'tcx>,
5171

5272
/// Cache instances of monomorphic and polymorphic items
@@ -553,10 +573,9 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
553573

554574
CodegenCx {
555575
tcx,
576+
scx: SimpleCx { llcx, llmod },
556577
use_dll_storage_attrs,
557578
tls_model,
558-
llmod,
559-
llcx,
560579
codegen_unit,
561580
instances: Default::default(),
562581
vtables: Default::default(),
@@ -600,6 +619,22 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
600619
llvm::set_section(g, c"llvm.metadata");
601620
}
602621
}
622+
} impl<'ll> SimpleCx<'ll> {
623+
624+
625+
pub(crate) fn func_params_types(&self, ty: &'ll Type) -> Vec<&'ll Type> {
626+
unsafe {
627+
let n_args = llvm::LLVMCountParamTypes(ty) as usize;
628+
let mut args = Vec::with_capacity(n_args);
629+
llvm::LLVMGetParamTypes(ty, args.as_mut_ptr());
630+
args.set_len(n_args);
631+
args
632+
}
633+
}
634+
635+
pub(crate) fn val_ty(&self, v: &'ll Value) -> &'ll Type {
636+
common::val_ty(v)
637+
}
603638

604639
pub(crate) fn get_metadata_value(&self, metadata: &'ll Metadata) -> &'ll Value {
605640
unsafe { llvm::LLVMMetadataAsValue(self.llcx, metadata) }
@@ -625,6 +660,11 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
625660
llvm::LLVMMDStringInContext2(self.llcx, name.as_ptr() as *const c_char, name.len())
626661
})
627662
}
663+
664+
pub(crate) fn type_kind(&self, ty: &'ll Type) -> TypeKind {
665+
unsafe { llvm::LLVMRustGetTypeKind(ty).to_generic() }
666+
}
667+
628668
}
629669

630670
impl<'ll, 'tcx> MiscCodegenMethods<'tcx> for CodegenCx<'ll, 'tcx> {

Diff for: compiler/rustc_codegen_llvm/src/declare.rs

+30-3
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,46 @@ use tracing::debug;
2121

2222
use crate::abi::{FnAbi, FnAbiLlvmExt};
2323
use crate::common::AsCCharPtr;
24-
use crate::context::CodegenCx;
24+
use crate::context::{CodegenCx, SimpleCx};
2525
use crate::llvm::AttributePlace::Function;
2626
use crate::llvm::Visibility;
2727
use crate::type_::Type;
2828
use crate::value::Value;
2929
use crate::{attributes, llvm};
3030

31+
32+
33+
34+
/// Declare a function with a SimpleCx.
35+
///
36+
/// If there’s a value with the same name already declared, the function will
37+
/// update the declaration and return existing Value instead.
38+
pub(crate) fn declare_simple_fn<'ll>(
39+
cx: &SimpleCx<'ll>,
40+
name: &str,
41+
callconv: llvm::CallConv,
42+
unnamed: llvm::UnnamedAddr,
43+
visibility: llvm::Visibility,
44+
ty: &'ll Type,
45+
) -> &'ll Value {
46+
debug!("declare_raw_fn(name={:?}, ty={:?})", name, ty);
47+
let llfn = unsafe {
48+
llvm::LLVMRustGetOrInsertFunction(cx.llmod, name.as_c_char_ptr(), name.len(), ty)
49+
};
50+
51+
llvm::SetFunctionCallConv(llfn, callconv);
52+
llvm::SetUnnamedAddress(llfn, unnamed);
53+
llvm::set_visibility(llfn, visibility);
54+
55+
llfn
56+
}
57+
3158
/// Declare a function.
3259
///
3360
/// If there’s a value with the same name already declared, the function will
3461
/// update the declaration and return existing Value instead.
35-
pub(crate) fn declare_raw_fn<'ll>(
36-
cx: &CodegenCx<'ll, '_>,
62+
pub(crate) fn declare_raw_fn<'ll, 'tcx>(
63+
cx: &CodegenCx<'ll, 'tcx>,
3764
name: &str,
3865
callconv: llvm::CallConv,
3966
unnamed: llvm::UnnamedAddr,

0 commit comments

Comments
 (0)