@@ -201,7 +201,23 @@ fn compute_enzyme_fn_ty<'ll>(
201
201
}
202
202
203
203
if attrs. width == 1 {
204
- todo ! ( "Handle sret for scalar ad" ) ;
204
+ // Enzyme returns a struct of style:
205
+ // `{ original_ret(if requested), float, float, ... }`
206
+ let mut struct_elements = vec ! [ ] ;
207
+ if attrs. has_primal_ret ( ) {
208
+ struct_elements. push ( inner_ret_ty) ;
209
+ }
210
+ // Next, we push the list of active floats, since they will be lowered to `enzyme_out`,
211
+ // and therefore part of the return struct.
212
+ let param_tys = cx. func_params_types ( fn_ty) ;
213
+ for ( act, param_ty) in attrs. input_activity . iter ( ) . zip ( param_tys) {
214
+ if matches ! ( act, DiffActivity :: Active ) {
215
+ // Now find the float type at position i based on the fn_ty,
216
+ // to know what (f16/f32/f64/...) to add to the struct.
217
+ struct_elements. push ( param_ty) ;
218
+ }
219
+ }
220
+ ret_ty = cx. type_struct ( & struct_elements, false ) ;
205
221
} else {
206
222
// First we check if we also have to deal with the primal return.
207
223
match attrs. mode {
@@ -388,7 +404,11 @@ fn generate_enzyme_call<'ll>(
388
404
// now store the result of the enzyme call into the sret pointer.
389
405
let sret_ptr = outer_args[ 0 ] ;
390
406
let call_ty = cx. val_ty ( call) ;
391
- assert_eq ! ( cx. type_kind( call_ty) , TypeKind :: Array ) ;
407
+ if attrs. width == 1 {
408
+ assert_eq ! ( cx. type_kind( call_ty) , TypeKind :: Struct ) ;
409
+ } else {
410
+ assert_eq ! ( cx. type_kind( call_ty) , TypeKind :: Array ) ;
411
+ }
392
412
llvm:: LLVMBuildStore ( & builder. llbuilder , call, sret_ptr) ;
393
413
}
394
414
builder. ret_void ( ) ;
0 commit comments