Skip to content

Commit 3872f89

Browse files
authored
Handle dupnoneed in forward mode (rust-lang#396)
* Handle dupnoneed in forward mode * Fix allocs with more than one parameter (rust-lang#397) * Add test
1 parent 394992e commit 3872f89

File tree

3 files changed

+54
-2
lines changed

3 files changed

+54
-2
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7455,7 +7455,11 @@ class AdjointGenerator
74557455
IRBuilder<> Builder2(&call);
74567456
getForwardBuilder(Builder2);
74577457

7458-
SmallVector<Value *, 2> args = {orig->getArgOperand(0)};
7458+
SmallVector<Value *, 2> args;
7459+
for (unsigned i = 0; i < orig->getNumArgOperands(); ++i) {
7460+
auto arg = orig->getArgOperand(i);
7461+
args.push_back(gutils->getNewFromOriginal(arg));
7462+
}
74597463
CallInst *CI = Builder2.CreateCall(orig->getFunctionType(),
74607464
orig->getCalledFunction(), args);
74617465
CI->setAttributes(orig->getAttributes());

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3848,7 +3848,9 @@ Function *EnzymeLogic::CreateForwardDiff(
38483848

38493849
for (size_t i = 0; i < constant_args.size(); ++i) {
38503850
auto arg = constant_args[i];
3851-
if (arg == DIFFE_TYPE::DUP_ARG) {
3851+
switch (arg) {
3852+
case DIFFE_TYPE::DUP_ARG:
3853+
case DIFFE_TYPE::DUP_NONEED: {
38523854
newArgs += 1;
38533855
auto pri = gutils->oldFunc->arg_begin() + i;
38543856
auto dif = newArgs;
@@ -3857,6 +3859,12 @@ Function *EnzymeLogic::CreateForwardDiff(
38573859
IRBuilder<> Builder(&BB.front());
38583860

38593861
gutils->setDiffe(pri, dif, Builder);
3862+
break;
3863+
}
3864+
case DIFFE_TYPE::CONSTANT:
3865+
break;
3866+
case DIFFE_TYPE::OUT_DIFF:
3867+
report_fatal_error("unsupported DIFFE_TYPE");
38603868
}
38613869
newArgs += 1;
38623870
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -simplifycfg -S | FileCheck %s
2+
3+
4+
@enzyme_dupnoneed = dso_local global i32 0, align 4
5+
6+
define dso_local double @f(double %x, i64 %arg) {
7+
entry:
8+
%call = call noalias i8* @calloc(i64 8, i64 %arg)
9+
%0 = bitcast i8* %call to double*
10+
store double %x, double* %0, align 8
11+
%1 = load double, double* %0, align 8
12+
ret double %1
13+
}
14+
15+
declare dso_local noalias i8* @calloc(i64, i64)
16+
17+
define dso_local double @df(double %x) {
18+
entry:
19+
%x.addr = alloca double, align 8
20+
store double %x, double* %x.addr, align 8
21+
%0 = load i32, i32* @enzyme_dupnoneed, align 4
22+
%1 = load double, double* %x.addr, align 8
23+
%call = call double (i8*, ...) @__enzyme_fwddiff(i8* bitcast (double (double,i64)* @f to i8*), i32 %0, double %1, double 1.000000e+00, i64 1)
24+
ret double %call
25+
}
26+
27+
declare dso_local double @__enzyme_fwddiff(i8*, ...)
28+
29+
30+
; CHECK: define internal double @fwddiffef(double %x, double %"x'", i64 %arg)
31+
; CHECK-NEXT: entry:
32+
; CHECK-NEXT: %call = call noalias i8* @calloc(i64 8, i64 %arg)
33+
; CHECK-NEXT: %0 = call noalias i8* @calloc(i64 8, i64 %arg)
34+
; CHECK-NEXT: %"'ipc" = bitcast i8* %0 to double*
35+
; CHECK-NEXT: %1 = bitcast i8* %call to double*
36+
; CHECK-NEXT: store double %x, double* %1, align 8
37+
; CHECK-NEXT: store double %"x'", double* %"'ipc", align 8
38+
; CHECK-NEXT: %2 = load double, double* %"'ipc", align 8
39+
; CHECK-NEXT: ret double %2
40+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)