Skip to content

Commit 3e5c6c0

Browse files
authored
Handle rematerialize successors outside loop (rust-lang#603)
1 parent 11deb1f commit 3e5c6c0

File tree

2 files changed

+100
-2
lines changed

2 files changed

+100
-2
lines changed

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2779,8 +2779,11 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB,
27792779
// the phi unwrapping
27802780
if (!notForAnalysis.count(B) &&
27812781
NB.GetInsertBlock() != origToNewForward[B]) {
2782-
for (auto S : successors(B)) {
2783-
S = origToNewForward[S];
2782+
for (auto S0 : successors(B)) {
2783+
if (!origToNewForward.count(S0))
2784+
continue;
2785+
auto S = origToNewForward[S0];
2786+
assert(S);
27842787
for (auto I = S->begin(), E = S->end(); I != E; ++I) {
27852788
PHINode *orig = dyn_cast<PHINode>(&*I);
27862789
if (orig == nullptr)
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -S | FileCheck %s
2+
3+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
4+
target triple = "x86_64-unknown-linux-gnu"
5+
6+
@enzyme_const = dso_local local_unnamed_addr global i32 0, align 4
7+
8+
declare nonnull i8* @malloc(i64)
9+
10+
declare void @free(i8*)
11+
12+
define double @_Z15integrate_imagedi(double %arg, i32 %arg1) {
13+
bb:
14+
%i7 = zext i32 %arg1 to i64
15+
%i4 = mul i64 %i7, 8
16+
br label %bb8
17+
18+
bb8: ; preds = %bb12, %bb
19+
%i9 = phi double [ 0.000000e+00, %bb ], [ %i13, %bb12 ]
20+
%i10 = tail call noalias nonnull i8* @malloc(i64 %i4)
21+
%i11 = bitcast i8* %i10 to double*
22+
br label %bb14
23+
24+
bb14: ; preds = %bb14, %bb8
25+
%i15 = phi i64 [ %i17, %bb14 ], [ 0, %bb8 ]
26+
%i16 = getelementptr inbounds double, double* %i11, i64 %i15
27+
store double %arg, double* %i16, align 8
28+
%i17 = add nuw nsw i64 %i15, 1
29+
%i18 = icmp eq i64 %i17, %i7
30+
br i1 %i18, label %bb12, label %bb14
31+
32+
bb12: ; preds = %bb14
33+
%i13 = load double, double* %i11, align 8
34+
tail call void @free(i8* nonnull %i10)
35+
%i21 = fsub double %i13, %i9
36+
%i22 = fcmp ogt double %i21, 1.000000e-04
37+
br i1 %i22, label %bb8, label %bb23
38+
39+
bb23: ; preds = %bb12
40+
ret double %i13
41+
}
42+
43+
define dso_local double @_Z3dondd(double %arg, double %arg1) {
44+
bb:
45+
%i = load i32, i32* @enzyme_const, align 4
46+
%i2 = tail call double (double (double, i32)*, ...) @_Z17__enzyme_autodiffPFddiEz(double (double, i32)* nonnull @_Z15integrate_imagedi, i32 %i, double %arg, i32 %i, i32 10)
47+
ret double %i2
48+
}
49+
50+
declare dso_local double @_Z17__enzyme_autodiffPFddiEz(double (double, i32)*, ...)
51+
52+
; CHECK: define internal void @diffe_Z15integrate_imagedi(double %arg, i32 %arg1, double %differeturn)
53+
; CHECK-NEXT: bb:
54+
; CHECK-NEXT: %i7 = zext i32 %arg1 to i64
55+
; CHECK-NEXT: %i4 = mul {{(nuw nsw )?}}i64 %i7, 8
56+
; CHECK-NEXT: %0 = add {{(nsw )?}}i64 %i7, -1
57+
; CHECK-NEXT: br label %bb8
58+
59+
; CHECK: bb8: ; preds = %bb12, %bb
60+
; CHECK-NEXT: %i9 = phi double [ 0.000000e+00, %bb ], [ %i13, %bb12 ]
61+
; CHECK-NEXT: %i10 = tail call noalias nonnull i8* @malloc(i64 %i4)
62+
; CHECK-NEXT: %i11 = bitcast i8* %i10 to double*
63+
; CHECK-NEXT: br label %bb14
64+
65+
; CHECK: bb14: ; preds = %bb14, %bb8
66+
; CHECK-NEXT: %iv1 = phi i64 [ %iv.next2, %bb14 ], [ 0, %bb8 ]
67+
; CHECK-NEXT: %iv.next2 = add nuw nsw i64 %iv1, 1
68+
; CHECK-NEXT: %i16 = getelementptr inbounds double, double* %i11, i64 %iv1
69+
; CHECK-NEXT: store double %arg, double* %i16, align 8
70+
; CHECK-NEXT: %i18 = icmp eq i64 %iv.next2, %i7
71+
; CHECK-NEXT: br i1 %i18, label %bb12, label %bb14
72+
73+
; CHECK: bb12: ; preds = %bb14
74+
; CHECK-NEXT: %i13 = load double, double* %i11, align 8, !invariant.group !0
75+
; CHECK-NEXT: tail call void @free(i8* nonnull %i10)
76+
; CHECK-NEXT: %i21 = fsub double %i13, %i9
77+
; CHECK-NEXT: %i22 = fcmp ogt double %i21, 1.000000e-04
78+
; CHECK-NEXT: br i1 %i22, label %bb8, label %[[remat_bb8_bb8:.+]]
79+
80+
; CHECK: [[remat_bb8_bb8]]: ; preds = %bb12
81+
; CHECK: %remat_i10 = tail call noalias nonnull i8* @malloc(i64 %i4)
82+
; CHECK-NEXT: br label %remat_bb8_bb14
83+
84+
; CHECK: remat_bb8_bb14:
85+
; CHECK-NEXT: %fiv = phi i64 [ %[[i1:.+]], %remat_bb8_bb14 ], [ 0, %[[remat_bb8_bb8]] ]
86+
; CHECK-NEXT: %[[i1]] = add i64 %fiv, 1
87+
; CHECK-NEXT: %i11_unwrap = bitcast i8* %remat_i10 to double*
88+
; CHECK-NEXT: %i16_unwrap = getelementptr inbounds double, double* %i11_unwrap, i64 %fiv
89+
; CHECK-NEXT: store double %arg, double* %i16_unwrap, align 8
90+
; CHECK-NEXT: %i18_unwrap = icmp eq i64 %[[i1]], %i7
91+
; CHECK-NEXT: br i1 %i18_unwrap, label %remat_bb8_bb12_phimerge, label %remat_bb8_bb14
92+
93+
; CHECK: remat_bb8_bb12_phimerge: ; preds = %remat_bb8_bb14
94+
; CHECK: tail call void @free(i8* nonnull %remat_i10)
95+

0 commit comments

Comments
 (0)