Skip to content

Commit 13a79c4

Browse files
committed
fixed dyn loop bug
1 parent 94428cd commit 13a79c4

File tree

2 files changed

+21
-12
lines changed

2 files changed

+21
-12
lines changed

enzyme/Enzyme/GradientUtils.h

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -903,12 +903,9 @@ class GradientUtils {
903903
if (!getContext(blk, idx)) {
904904
break;
905905
}
906+
llvm::errs() << " adding to contexts: " << idx.header->getName() << " starting ctx=" << ctx->getName() << "\n";
906907
contexts.emplace_back(idx);
907-
if (idx.parent) {
908-
blk = idx.parent->getHeader();
909-
} else {
910-
blk = nullptr;
911-
}
908+
blk = idx.preheader;
912909
}
913910

914911
std::vector<BasicBlock*> allocationPreheaders(contexts.size(), nullptr);
@@ -964,14 +961,16 @@ class GradientUtils {
964961
size = allocationBuilder.CreateNUWMul(size, limits[i]);
965962
}
966963

964+
llvm::errs() << "considering ctx " << ctx->getName() << " alph=" << allocationPreheaders[i]->getName() << " ctxheader=" << contexts[i].header->getName() << "\n";
967965
if (contexts[i].dynamic) {
966+
llvm::errs() << "starting outermost ph at " << allocationPreheaders[i]->getName() << "|ctx=" << ctx->getName() <<"\n";
968967
sublimits.push_back(std::make_pair(size, lims));
969968
size = nullptr;
970-
break;
971969
}
972970
}
973971

974972
if (size != nullptr) {
973+
llvm::errs() << "starting final outermost ph at " << allocationPreheaders[contexts.size()-1]->getName()<<"|ctx=" << ctx->getName() << "\n";
975974
sublimits.push_back(std::make_pair(size, lims));
976975
}
977976
return sublimits;
@@ -994,13 +993,21 @@ class GradientUtils {
994993
IRBuilder<> entryBuilder(inversionAllocs);
995994
entryBuilder.setFastMathFlags(getFast());
996995
AllocaInst* alloc = entryBuilder.CreateAlloca(types.back(), nullptr, name+"_cache");
996+
llvm::errs() << "alloc: "<< *alloc << "\n";
997997

998998
Type *BPTy = Type::getInt8PtrTy(ctx->getContext());
999999
auto realloc = newFunc->getParent()->getOrInsertFunction("realloc", BPTy, BPTy, Type::getInt64Ty(ctx->getContext()));
10001000

10011001
Value* storeInto = alloc;
1002+
ValueToValueMapTy antimap;
1003+
10021004
for(int i=sublimits.size()-1; i>=0; i--) {
10031005
const auto& containedloops = sublimits[i].second;
1006+
for(auto riter = containedloops.rbegin(), rend = containedloops.rend(); riter != rend; riter++) {
1007+
const auto& idx = riter->first;
1008+
antimap[idx.var] = idx.antivar;
1009+
}
1010+
10041011
Value* size = sublimits[i].first;
10051012
Type* myType = types[i];
10061013

@@ -1027,7 +1034,9 @@ class GradientUtils {
10271034
//allocationBuilder.GetInsertBlock()->getInstList().push_back(cast<Instruction>(allocation));
10281035
//cast<Instruction>(firstallocation)->moveBefore(allocationBuilder.GetInsertBlock()->getTerminator());
10291036
//mallocs.push_back(firstallocation);
1030-
} else {
1037+
} else {
1038+
llvm::errs() << "storeInto: " << *storeInto << "\n";
1039+
llvm::errs() << "myType: " << *myType << "\n";
10311040
allocationBuilder.CreateStore(ConstantPointerNull::get(PointerType::getUnqual(myType)), storeInto);
10321041

10331042
IRBuilder <> build(containedloops.back().first.header->getFirstNonPHI());
@@ -1060,7 +1069,7 @@ class GradientUtils {
10601069
tbuild.SetInsertPoint(tbuild.GetInsertBlock()->getFirstNonPHI());
10611070
}
10621071

1063-
auto ci = cast<CallInst>(CallInst::CreateFree(tbuild.CreatePointerCast(tbuild.CreateLoad(storeInto), Type::getInt8PtrTy(ctx->getContext())), tbuild.GetInsertBlock()));
1072+
auto ci = cast<CallInst>(CallInst::CreateFree(tbuild.CreatePointerCast(tbuild.CreateLoad(unwrapM(storeInto, tbuild, antimap, /*lookup*/false)), Type::getInt8PtrTy(ctx->getContext())), tbuild.GetInsertBlock()));
10641073
ci->addAttribute(AttributeList::FirstArgIndex, Attribute::NonNull);
10651074
if (ci->getParent()==nullptr) {
10661075
tbuild.Insert(ci);
@@ -1072,9 +1081,9 @@ class GradientUtils {
10721081
IRBuilder <>v(&sublimits[i-1].second.back().first.preheader->back());
10731082
//TODO
10741083
if (!sublimits[i].second.back().first.dynamic) {
1075-
storeInto = v.CreateLoad(v.CreateGEP(v.CreateLoad(storeInto), sublimits[i].second.back().first.var));
1084+
storeInto = v.CreateGEP(v.CreateLoad(storeInto), sublimits[i].second.back().first.var);
10761085
} else {
1077-
storeInto = v.CreateLoad(v.CreateGEP(v.CreateLoad(storeInto), sublimits[i].second.back().first.var));
1086+
storeInto = v.CreateGEP(v.CreateLoad(storeInto), sublimits[i].second.back().first.var);
10781087
}
10791088
}
10801089
}

enzyme/functional_tests_c/setup.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#!/bin/bash
22

33
# NOTE(TFK): Uncomment for local testing.
4-
export CLANG_BIN_PATH=./../../llvm/build/bin
5-
export ENZYME_PLUGIN=./../build/Enzyme/LLVMEnzyme-7.so
4+
export CLANG_BIN_PATH=./../../build-dbg/bin
5+
export ENZYME_PLUGIN=./../mkdebug/Enzyme/LLVMEnzyme-7.so
66

77
mkdir -p build
88
$@

0 commit comments

Comments
 (0)