Skip to content

Commit a449451

Browse files
authored
ForwardMode phi (rust-lang#198)
* added derivative of phi for forward mode * stop using addToDiffe in forward mode * added tests
1 parent bc04d6f commit a449451

File tree

19 files changed

+1328
-52
lines changed

19 files changed

+1328
-52
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 82 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,50 @@ class AdjointGenerator
593593
eraseIfUnused(gep);
594594
}
595595

596-
void visitPHINode(llvm::PHINode &phi) { eraseIfUnused(phi); }
596+
void visitPHINode(llvm::PHINode &phi) {
597+
eraseIfUnused(phi);
598+
if (gutils->isConstantInstruction(&phi))
599+
return;
600+
601+
switch (Mode) {
602+
case DerivativeMode::ReverseModePrimal:
603+
case DerivativeMode::ReverseModeGradient:
604+
case DerivativeMode::ReverseModeCombined: {
605+
return;
606+
}
607+
case DerivativeMode::ForwardMode: {
608+
break;
609+
}
610+
}
611+
612+
BasicBlock *oBB = phi.getParent();
613+
BasicBlock *nBB = gutils->getNewFromOriginal(oBB);
614+
615+
IRBuilder<> diffeBuilder(nBB->getFirstNonPHI());
616+
diffeBuilder.setFastMathFlags(getFast());
617+
618+
IRBuilder<> phiBuilder(&phi);
619+
getForwardBuilder(phiBuilder);
620+
621+
auto newPhi = phiBuilder.CreatePHI(phi.getType(), 1, phi.getName() + "'");
622+
for (unsigned int i = 0; i < phi.getNumIncomingValues(); ++i) {
623+
auto val = phi.getIncomingValue(i);
624+
auto block = phi.getIncomingBlock(i);
625+
626+
auto newBlock = gutils->getNewFromOriginal(block);
627+
IRBuilder<> pBuilder(newBlock->getTerminator());
628+
pBuilder.setFastMathFlags(getFast());
629+
630+
if (gutils->isConstantValue(val)) {
631+
newPhi->addIncoming(Constant::getNullValue(val->getType()), newBlock);
632+
} else {
633+
auto diff = diffe(val, pBuilder);
634+
newPhi->addIncoming(diff, newBlock);
635+
}
636+
}
637+
638+
setDiffe(&phi, newPhi, diffeBuilder);
639+
}
597640

598641
void visitCastInst(llvm::CastInst &I) {
599642
eraseIfUnused(I);
@@ -1078,8 +1121,7 @@ class AdjointGenerator
10781121
std::vector<SelectInst *> addToDiffe(Value *val, Value *dif,
10791122
IRBuilder<> &Builder, Type *T) {
10801123
assert(Mode == DerivativeMode::ReverseModeGradient ||
1081-
Mode == DerivativeMode::ReverseModeCombined ||
1082-
Mode == DerivativeMode::ForwardMode);
1124+
Mode == DerivativeMode::ReverseModeCombined);
10831125
return ((DiffeGradientUtils *)gutils)->addToDiffe(val, dif, Builder, T);
10841126
}
10851127

@@ -1510,61 +1552,66 @@ class AdjointGenerator
15101552
Value *dif0 = constantval0 ? nullptr : diffe(orig_op0, Builder2);
15111553
Value *dif1 = constantval1 ? nullptr : diffe(orig_op1, Builder2);
15121554

1513-
Type *addingType = BO.getType();
1514-
15151555
switch (BO.getOpcode()) {
15161556
case Instruction::FMul: {
1517-
if (!constantval0) {
1557+
if (!constantval0 && !constantval1) {
1558+
Value *idiff0 =
1559+
Builder2.CreateFMul(dif0, gutils->getNewFromOriginal(orig_op1));
1560+
Value *idiff1 =
1561+
Builder2.CreateFMul(dif1, gutils->getNewFromOriginal(orig_op0));
1562+
Value *diff = Builder2.CreateFAdd(idiff0, idiff1);
1563+
setDiffe(&BO, diff, Builder2);
1564+
} else if (!constantval0) {
15181565
Value *idiff0 =
15191566
Builder2.CreateFMul(dif0, gutils->getNewFromOriginal(orig_op1));
15201567
setDiffe(&BO, idiff0, Builder2);
1521-
}
1522-
1523-
if (!constantval1) {
1568+
} else if (!constantval1) {
15241569
Value *idiff1 =
15251570
Builder2.CreateFMul(dif1, gutils->getNewFromOriginal(orig_op0));
1526-
addToDiffe(&BO, idiff1, Builder2, addingType);
1571+
setDiffe(&BO, idiff1, Builder2);
15271572
}
15281573
break;
15291574
}
15301575
case Instruction::FAdd: {
1531-
if (!constantval0) {
1532-
addToDiffe(&BO, dif0, Builder2, addingType);
1533-
}
1534-
1535-
if (!constantval1) {
1536-
addToDiffe(&BO, dif1, Builder2, addingType);
1576+
if (!constantval0 && !constantval1) {
1577+
Value *diff = Builder2.CreateFAdd(dif0, dif1);
1578+
setDiffe(&BO, diff, Builder2);
1579+
} else if (!constantval0) {
1580+
setDiffe(&BO, dif0, Builder2);
1581+
} else if (!constantval1) {
1582+
setDiffe(&BO, dif1, Builder2);
15371583
}
15381584
break;
15391585
}
15401586
case Instruction::FSub: {
1541-
if (!constantval0) {
1542-
addToDiffe(&BO, dif0, Builder2, addingType);
1543-
}
1544-
1545-
if (!constantval1) {
1546-
addToDiffe(&BO, Builder2.CreateFNeg(dif1), Builder2, addingType);
1587+
if (!constantval0 && !constantval1) {
1588+
Value *diff = Builder2.CreateFAdd(dif0, Builder2.CreateFNeg(dif1));
1589+
setDiffe(&BO, diff, Builder2);
1590+
} else if (!constantval0) {
1591+
setDiffe(&BO, dif0, Builder2);
1592+
} else if (!constantval1) {
1593+
setDiffe(&BO, Builder2.CreateFNeg(dif1), Builder2);
15471594
}
15481595
break;
15491596
}
15501597
case Instruction::FDiv: {
1551-
Value *idiff1;
1552-
if (!constantval0) {
1553-
idiff1 =
1598+
Value *idiff3 = nullptr;
1599+
if (!constantval0 && !constantval1) {
1600+
Value *idiff1 =
15541601
Builder2.CreateFMul(dif0, gutils->getNewFromOriginal(orig_op1));
1555-
} else {
1556-
idiff1 = ConstantFP::get(addingType, 0.0);
1557-
}
1558-
1559-
Value *idiff2;
1560-
if (!constantval1) {
1561-
idiff2 =
1602+
Value *idiff2 =
15621603
Builder2.CreateFMul(gutils->getNewFromOriginal(orig_op0), dif1);
1563-
} else {
1564-
idiff2 = ConstantFP::get(addingType, 0.0);
1604+
idiff3 = Builder2.CreateFSub(idiff1, idiff2);
1605+
} else if (!constantval0) {
1606+
Value *idiff1 =
1607+
Builder2.CreateFMul(dif0, gutils->getNewFromOriginal(orig_op1));
1608+
idiff3 = idiff1;
1609+
} else if (!constantval1) {
1610+
Value *idiff2 =
1611+
Builder2.CreateFMul(gutils->getNewFromOriginal(orig_op0), dif1);
1612+
idiff3 = Builder2.CreateFNeg(idiff2);
15651613
}
15661614

1567-
Value *idiff3 = Builder2.CreateFSub(idiff1, idiff2);
15681615
Value *idiff4 = Builder2.CreateFMul(gutils->getNewFromOriginal(orig_op1),
15691616
gutils->getNewFromOriginal(orig_op1));
15701617
Value *idiff5 = Builder2.CreateFDiv(idiff3, idiff4);
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -inline -mem2reg -gvn -instsimplify -correlated-propagation -adce -simplifycfg -S | FileCheck %s
2+
3+
; Function Attrs: noinline norecurse nounwind uwtable
4+
define double @f(double* nocapture %x, i64 %n) #0 {
5+
entry:
6+
br label %loop
7+
8+
loop:
9+
%j = phi i64 [ %nj, %end ], [ 0, %entry ]
10+
%sum = phi double [ %nsum, %end ], [ 0.000000e+00, %entry ]
11+
%nj = add nsw nuw i64 %j, 1
12+
%g0 = getelementptr inbounds double, double* %x, i64 %j
13+
br label %body
14+
15+
body: ; preds = %entry, %for.cond.cleanup6
16+
%i = phi i64 [ %next, %body ], [ 0, %loop ]
17+
%gep = getelementptr inbounds double, double* %g0, i64 %i
18+
%ld = load double, double* %gep, align 8
19+
%cmp = fcmp oeq double %ld, 3.141592e+00
20+
%next = add nuw i64 %i, 1
21+
br i1 %cmp, label %body, label %end
22+
23+
end:
24+
%gep2 = getelementptr inbounds double, double* %x, i64 %i
25+
%ld2 = load double, double* %gep2, align 8
26+
%nsum = fadd double %ld2, %sum
27+
%cmp2 = icmp ne i64 %nj, 10
28+
br i1 %cmp2, label %loop, label %exit
29+
30+
exit:
31+
ret double %nsum
32+
}
33+
34+
; Function Attrs: noinline nounwind uwtable
35+
define dso_local double @dsumsquare(double* %x, double* %xp, i64 %n) local_unnamed_addr #1 {
36+
entry:
37+
%call = tail call fast double @__enzyme_fwddiff(i8* bitcast (double (double*, i64)* @f to i8*), double* %x, double* %xp, i64 %n)
38+
ret double %call
39+
}
40+
41+
declare dso_local double @__enzyme_fwddiff(i8*, double*, double*, i64) local_unnamed_addr
42+
43+
attributes #0 = { noinline norecurse nounwind uwtable }
44+
attributes #1 = { noinline nounwind uwtable }
45+
46+
47+
; CHECK: define internal { double } @diffef(double* nocapture %x, double* nocapture %"x'", i64 %n)
48+
; CHECK-NEXT: entry:
49+
; CHECK-NEXT: br label %loop
50+
51+
; CHECK: loop: ; preds = %end, %entry
52+
; CHECK-NEXT: %iv = phi i64 [ %iv.next, %end ], [ 0, %entry ]
53+
; CHECK-NEXT: %"sum'" = phi {{(fast )?}}double [ %1, %end ], [ 0.000000e+00, %entry ]
54+
; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1
55+
; CHECK-NEXT: %g0 = getelementptr inbounds double, double* %x, i64 %iv
56+
; CHECK-NEXT: br label %body
57+
58+
; CHECK: body: ; preds = %body, %loop
59+
; CHECK-NEXT: %iv1 = phi i64 [ %iv.next2, %body ], [ 0, %loop ]
60+
; CHECK-NEXT: %iv.next2 = add nuw nsw i64 %iv1, 1
61+
; CHECK-NEXT: %gep = getelementptr inbounds double, double* %g0, i64 %iv1
62+
; CHECK-NEXT: %ld = load double, double* %gep, align 8
63+
; CHECK-NEXT: %cmp = fcmp oeq double %ld, 0x400921FAFC8B007A
64+
; CHECK-NEXT: br i1 %cmp, label %body, label %end
65+
66+
; CHECK: end: ; preds = %body
67+
; CHECK-NEXT: %"gep2'ipg" = getelementptr inbounds double, double* %"x'", i64 %iv1
68+
; CHECK-NEXT: %0 = load double, double* %"gep2'ipg"
69+
; CHECK-NEXT: %1 = fadd fast double %0, %"sum'"
70+
; CHECK-NEXT: %cmp2 = icmp ne i64 %iv.next, 10
71+
; CHECK-NEXT: br i1 %cmp2, label %loop, label %exit
72+
73+
; CHECK: exit: ; preds = %end
74+
; CHECK-NEXT: %2 = insertvalue { double } undef, double %1, 0
75+
; CHECK-NEXT: ret { double } %2
76+
; CHECK-NEXT: }
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -inline -mem2reg -gvn -instsimplify -adce -correlated-propagation -simplifycfg -S | FileCheck %s
2+
3+
; Function Attrs: noinline norecurse nounwind uwtable
4+
define double @f(double* nocapture %x, i64 %n) #0 {
5+
entry:
6+
br label %loop
7+
8+
loop:
9+
%j = phi i64 [ %nj, %end ], [ 0, %entry ]
10+
%sum = phi double [ %nsum, %end ], [ 0.000000e+00, %entry ]
11+
%nj = add nsw nuw i64 %j, 1
12+
%g0 = getelementptr inbounds double, double* %x, i64 %j
13+
br label %body
14+
15+
body: ; preds = %entry, %for.cond.cleanup6
16+
%i = phi i64 [ %next, %body ], [ 0, %loop ]
17+
%idx = phi i64 [ %nidx, %body ], [ 0, %loop ]
18+
%gep = getelementptr inbounds double, double* %g0, i64 %i
19+
%ld = load double, double* %gep, align 8
20+
%cmp = fcmp oeq double %ld, 3.141592e+00
21+
%next = add nuw i64 %i, 1
22+
%int = fptoui double %ld to i64
23+
%nidx = add nuw i64 %idx, %int
24+
br i1 %cmp, label %body, label %end
25+
26+
end:
27+
%gep2 = getelementptr inbounds double, double* %x, i64 %idx
28+
%ld2 = load double, double* %gep2, align 8
29+
%nsum = fadd double %ld2, %sum
30+
%cmp2 = icmp ne i64 %nj, 10
31+
br i1 %cmp2, label %loop, label %exit
32+
33+
exit:
34+
ret double %nsum
35+
}
36+
37+
; Function Attrs: noinline nounwind uwtable
38+
define dso_local double @dsumsquare(double* %x, double* %xp, i64 %n) local_unnamed_addr #1 {
39+
entry:
40+
%call = tail call fast double @__enzyme_fwddiff(i8* bitcast (double (double*, i64)* @f to i8*), double* %x, double* %xp, i64 %n)
41+
ret double %call
42+
}
43+
44+
declare dso_local double @__enzyme_fwddiff(i8*, double*, double*, i64) local_unnamed_addr
45+
46+
attributes #0 = { noinline norecurse nounwind uwtable }
47+
attributes #1 = { noinline nounwind uwtable }
48+
49+
50+
; CHECK: define internal { double } @diffef(double* nocapture %x, double* nocapture %"x'", i64 %n)
51+
; CHECK-NEXT: entry:
52+
; CHECK-NEXT: br label %loop
53+
54+
; CHECK: loop: ; preds = %end, %entry
55+
; CHECK-NEXT: %iv = phi i64 [ %iv.next, %end ], [ 0, %entry ]
56+
; CHECK-NEXT: %"sum'" = phi {{(fast )?}}double [ %1, %end ], [ 0.000000e+00, %entry ]
57+
; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1
58+
; CHECK-NEXT: %g0 = getelementptr inbounds double, double* %x, i64 %iv
59+
; CHECK-NEXT: br label %body
60+
61+
; CHECK: body: ; preds = %body, %loop
62+
; CHECK-NEXT: %iv1 = phi i64 [ %iv.next2, %body ], [ 0, %loop ]
63+
; CHECK-NEXT: %idx = phi i64 [ %nidx, %body ], [ 0, %loop ]
64+
; CHECK-NEXT: %iv.next2 = add nuw nsw i64 %iv1, 1
65+
; CHECK-NEXT: %gep = getelementptr inbounds double, double* %g0, i64 %iv1
66+
; CHECK-NEXT: %ld = load double, double* %gep, align 8
67+
; CHECK-NEXT: %cmp = fcmp oeq double %ld, 0x400921FAFC8B007A
68+
; CHECK-NEXT: %int = fptoui double %ld to i64
69+
; CHECK-NEXT: %nidx = add nuw i64 %idx, %int
70+
; CHECK-NEXT: br i1 %cmp, label %body, label %end
71+
72+
; CHECK: end: ; preds = %body
73+
; CHECK-NEXT: %"gep2'ipg" = getelementptr inbounds double, double* %"x'", i64 %idx
74+
; CHECK-NEXT: %0 = load double, double* %"gep2'ipg"
75+
; CHECK-NEXT: %1 = fadd fast double %0, %"sum'"
76+
; CHECK-NEXT: %cmp2 = icmp ne i64 %iv.next, 10
77+
; CHECK-NEXT: br i1 %cmp2, label %loop, label %exit
78+
79+
; CHECK: exit: ; preds = %end
80+
; CHECK-NEXT: %2 = insertvalue { double } undef, double %1, 0
81+
; CHECK-NEXT: ret { double } %2
82+
; CHECK-NEXT: }

enzyme/test/Enzyme/ForwardMode/divreduce.ll

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -60,56 +60,57 @@ declare double @__enzyme_fwddiff2(i8*, double*, double*, i64)
6060
!7 = !{!"any pointer", !4, i64 0}
6161

6262

63-
6463
; CHECK: define internal { double } @diffealldiv(double* nocapture readonly %A, double* nocapture %"A'", i64 %N, double %start, double %"start'")
6564
; CHECK-NEXT: entry:
66-
; CHECK-NEXT: br label %loop
65+
; CHECK-NEXT: br label %loop
6766

6867
; CHECK: loop: ; preds = %loop, %entry
6968
; CHECK-NEXT: %iv = phi i64 [ %iv.next, %loop ], [ 0, %entry ]
7069
; CHECK-NEXT: %reduce = phi double [ %start, %entry ], [ %div, %loop ]
70+
; CHECK-NEXT: %"reduce'" = phi {{(fast )?}}double [ %"start'", %entry ], [ %5, %loop ]
7171
; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1
7272
; CHECK-NEXT: %"gep'ipg" = getelementptr inbounds double, double* %"A'", i64 %iv
7373
; CHECK-NEXT: %gep = getelementptr inbounds double, double* %A, i64 %iv
74-
; CHECK-NEXT: %ld = load double, double* %gep, align 8, !tbaa !2
74+
; CHECK-NEXT: %ld = load double, double* %gep, align 8, !tbaa !2
7575
; CHECK-NEXT: %0 = load double, double* %"gep'ipg"
7676
; CHECK-NEXT: %div = fdiv double %reduce, %ld
77-
; CHECK-NEXT: %1 = fmul fast double %reduce, %0
78-
; CHECK-NEXT: %2 = fsub fast double 0.000000e+00, %1
79-
; CHECK-NEXT: %3 = fmul fast double %ld, %ld
80-
; CHECK-NEXT: %4 = fdiv fast double %2, %3
77+
; CHECK-NEXT: %1 = fmul fast double %"reduce'", %ld
78+
; CHECK-NEXT: %2 = fmul fast double %reduce, %0
79+
; CHECK-NEXT: %3 = fsub fast double %1, %2
80+
; CHECK-NEXT: %4 = fmul fast double %ld, %ld
81+
; CHECK-NEXT: %5 = fdiv fast double %3, %4
8182
; CHECK-NEXT: %cmp = icmp eq i64 %iv.next, %N
8283
; CHECK-NEXT: br i1 %cmp, label %end, label %loop
8384

8485
; CHECK: end: ; preds = %loop
85-
; CHECK-NEXT: %5 = insertvalue { double } undef, double %4, 0
86-
; CHECK-NEXT: ret { double } %5
86+
; CHECK-NEXT: %6 = insertvalue { double } undef, double %5, 0
87+
; CHECK-NEXT: ret { double } %6
8788
; CHECK-NEXT: }
8889

8990

90-
91-
9291
; CHECK: define internal { double } @diffealldiv2(double* nocapture readonly %A, double* nocapture %"A'", i64 %N)
9392
; CHECK-NEXT: entry:
9493
; CHECK-NEXT: br label %loop
9594

9695
; CHECK: loop: ; preds = %loop, %entry
9796
; CHECK-NEXT: %iv = phi i64 [ %iv.next, %loop ], [ 0, %entry ]
9897
; CHECK-NEXT: %reduce = phi double [ 2.000000e+00, %entry ], [ %div, %loop ]
98+
; CHECK-NEXT: %"reduce'" = phi {{(fast )?}}double [ 0.000000e+00, %entry ], [ %5, %loop ]
9999
; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1
100100
; CHECK-NEXT: %"gep'ipg" = getelementptr inbounds double, double* %"A'", i64 %iv
101101
; CHECK-NEXT: %gep = getelementptr inbounds double, double* %A, i64 %iv
102102
; CHECK-NEXT: %ld = load double, double* %gep, align 8, !tbaa !2
103103
; CHECK-NEXT: %0 = load double, double* %"gep'ipg"
104104
; CHECK-NEXT: %div = fdiv double %reduce, %ld
105-
; CHECK-NEXT: %1 = fmul fast double %reduce, %0
106-
; CHECK-NEXT: %2 = fsub fast double 0.000000e+00, %1
107-
; CHECK-NEXT: %3 = fmul fast double %ld, %ld
108-
; CHECK-NEXT: %4 = fdiv fast double %2, %3
105+
; CHECK-NEXT: %1 = fmul fast double %"reduce'", %ld
106+
; CHECK-NEXT: %2 = fmul fast double %reduce, %0
107+
; CHECK-NEXT: %3 = fsub fast double %1, %2
108+
; CHECK-NEXT: %4 = fmul fast double %ld, %ld
109+
; CHECK-NEXT: %5 = fdiv fast double %3, %4
109110
; CHECK-NEXT: %cmp = icmp eq i64 %iv.next, %N
110111
; CHECK-NEXT: br i1 %cmp, label %end, label %loop
111112

112113
; CHECK: end: ; preds = %loop
113-
; CHECK-NEXT: %5 = insertvalue { double } undef, double %4, 0
114-
; CHECK-NEXT: ret { double } %5
114+
; CHECK-NEXT: %6 = insertvalue { double } undef, double %5, 0
115+
; CHECK-NEXT: ret { double } %6
115116
; CHECK-NEXT: }

0 commit comments

Comments
 (0)