Skip to content

Commit 189a8ff

Browse files
committed
Fix indexing and c tests pass
1 parent 13a79c4 commit 189a8ff

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,10 @@ static bool isParentOrSameContext(LoopContext & possibleChild, LoopContext & pos
164164

165165
// Case 2: The correct exiting block terminator unconditionally branches a different block, change to a conditional branch depending on if we are the first iteration
166166
} else if (succ.size() == 1) {
167+
lc.latchMerge->getTerminator()->eraseFromParent();
168+
mergeBuilder.SetInsertPoint(lc.latchMerge);
169+
170+
assert(mergeBuilder.GetInsertBlock()->size() == 0 || !isa<BranchInst>(mergeBuilder.GetInsertBlock()->back()));
167171

168172
// If first iteration, branch to the exiting block, otherwise the backlatch
169173
mergeBuilder.CreateCondBr(firstiter, succ[0], reverseBlocks[backlatch]);
@@ -187,6 +191,8 @@ static bool isParentOrSameContext(LoopContext & possibleChild, LoopContext & pos
187191

188192
lc.latchMerge->getTerminator()->eraseFromParent();
189193
mergeBuilder.SetInsertPoint(lc.latchMerge);
194+
195+
assert(mergeBuilder.GetInsertBlock()->size() == 0 || !isa<BranchInst>(mergeBuilder.GetInsertBlock()->back()));
190196
mergeBuilder.CreateCondBr(firstiter, splitBlock, reverseBlocks[backlatch]);
191197

192198
}
@@ -858,6 +864,7 @@ void GradientUtils::branchToCorrespondingTarget(BasicBlock* ctx, IRBuilder <>& B
858864

859865
if (targetToPreds.size() == 1) {
860866
if (replacePHIs == nullptr) {
867+
assert(BuilderM.GetInsertBlock()->size() == 0 || !isa<BranchInst>(BuilderM.GetInsertBlock()->back()));
861868
BuilderM.CreateBr( targetToPreds.begin()->first );
862869
} else {
863870
for (auto pair : *replacePHIs) {
@@ -962,6 +969,7 @@ void GradientUtils::branchToCorrespondingTarget(BasicBlock* ctx, IRBuilder <>& B
962969
Value* phi = lookupValueFromCache(BuilderM, ctx, cache);
963970

964971
if (replacePHIs == nullptr) {
972+
assert(BuilderM.GetInsertBlock()->size() == 0 || !isa<BranchInst>(BuilderM.GetInsertBlock()->back()));
965973
BuilderM.CreateCondBr(phi, *done[std::make_pair(block, branch->getSuccessor(0))].begin(), *done[std::make_pair(block, branch->getSuccessor(1))].begin());
966974
} else {
967975
for (auto pair : *replacePHIs) {
@@ -1076,6 +1084,7 @@ void GradientUtils::branchToCorrespondingTarget(BasicBlock* ctx, IRBuilder <>& B
10761084

10771085
if (replacePHIs == nullptr) {
10781086
if (targetToPreds.size() == 2) {
1087+
assert(BuilderM.GetInsertBlock()->size() == 0 || !isa<BranchInst>(BuilderM.GetInsertBlock()->back()));
10791088
BuilderM.CreateCondBr(which, /*true*/targets[1], /*false*/targets[0]);
10801089
} else {
10811090
auto swit = BuilderM.CreateSwitch(which, targets.back(), targets.size()-1);

enzyme/Enzyme/GradientUtils.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -966,12 +966,14 @@ class GradientUtils {
966966
llvm::errs() << "starting outermost ph at " << allocationPreheaders[i]->getName() << "|ctx=" << ctx->getName() <<"\n";
967967
sublimits.push_back(std::make_pair(size, lims));
968968
size = nullptr;
969+
lims.clear();
969970
}
970971
}
971972

972973
if (size != nullptr) {
973974
llvm::errs() << "starting final outermost ph at " << allocationPreheaders[contexts.size()-1]->getName()<<"|ctx=" << ctx->getName() << "\n";
974975
sublimits.push_back(std::make_pair(size, lims));
976+
lims.clear();
975977
}
976978
return sublimits;
977979
}
@@ -1118,6 +1120,7 @@ class GradientUtils {
11181120
indices.push_back(idx.var);
11191121
available[idx.var] = idx.var;
11201122
}
1123+
llvm::errs() << "W sl idx=" << i << " " << *idx.var << " header=" << idx.header->getName() << "\n";
11211124

11221125
Value* lim = unwrapM(riter->second, BuilderM, available, /*lookupIfAble*/true);
11231126
assert(lim);
@@ -1129,9 +1132,11 @@ class GradientUtils {
11291132
}
11301133

11311134
if (indices.size() > 0) {
1135+
llvm::errs() << "sl idx=" << i << " " << *indices[0] << "\n";
11321136
Value* idx = indices[0];
1133-
for(unsigned i=1; i<indices.size(); i++) {
1134-
idx = BuilderM.CreateNUWAdd(idx, BuilderM.CreateNUWMul(indices[i], limits[i-1]));
1137+
for(unsigned ind=1; ind<indices.size(); ind++) {
1138+
llvm::errs() << "sl idx=" << i << " " << *indices[ind] << "\n";
1139+
idx = BuilderM.CreateNUWAdd(idx, BuilderM.CreateNUWMul(indices[ind], limits[ind-1]));
11351140
}
11361141
next = BuilderM.CreateGEP(next, {idx});
11371142
}

0 commit comments

Comments
 (0)