Skip to content

Commit d6467d3

Browse files
committed
handle sret for scalar autodiff
1 parent 2fa8b11 commit d6467d3

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

Diff for: compiler/rustc_ast/src/expand/autodiff_attrs.rs

+6
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,12 @@ pub struct AutoDiffAttrs {
9292
pub input_activity: Vec<DiffActivity>,
9393
}
9494

95+
impl AutoDiffAttrs {
96+
pub fn has_primal_ret(&self) -> bool {
97+
matches!(self.ret_activity, DiffActivity::Active | DiffActivity::Dual)
98+
}
99+
}
100+
95101
impl DiffMode {
96102
pub fn is_rev(&self) -> bool {
97103
matches!(self, DiffMode::Reverse)

Diff for: compiler/rustc_codegen_llvm/src/builder/autodiff.rs

+22-2
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,23 @@ fn compute_enzyme_fn_ty<'ll>(
201201
}
202202

203203
if attrs.width == 1 {
204-
todo!("Handle sret for scalar ad");
204+
// Enzyme returns a struct of style:
205+
// `{ original_ret(if requested), float, float, ... }`
206+
let mut struct_elements = vec![];
207+
if attrs.has_primal_ret() {
208+
struct_elements.push(inner_ret_ty);
209+
}
210+
// Next, we push the list of active floats, since they will be lowered to `enzyme_out`,
211+
// and therefore part of the return struct.
212+
let param_tys = cx.func_params_types(fn_ty);
213+
for (act, param_ty) in attrs.input_activity.iter().zip(param_tys) {
214+
if matches!(act, DiffActivity::Active) {
215+
// Now find the float type at position i based on the fn_ty,
216+
// to know what (f16/f32/f64/...) to add to the struct.
217+
struct_elements.push(param_ty);
218+
}
219+
}
220+
ret_ty = cx.type_struct(&struct_elements, false);
205221
} else {
206222
// First we check if we also have to deal with the primal return.
207223
match attrs.mode {
@@ -388,7 +404,11 @@ fn generate_enzyme_call<'ll>(
388404
// now store the result of the enzyme call into the sret pointer.
389405
let sret_ptr = outer_args[0];
390406
let call_ty = cx.val_ty(call);
391-
assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
407+
if attrs.width == 1 {
408+
assert_eq!(cx.type_kind(call_ty), TypeKind::Struct);
409+
} else {
410+
assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
411+
}
392412
llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr);
393413
}
394414
builder.ret_void();

0 commit comments

Comments
 (0)