Skip to content

Commit 7eaa0cc

Browse files
committed
fix segfault
1 parent ffbe943 commit 7eaa0cc

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,19 @@ std::pair<Function*,StructType*> CreateAugmentedPrimal(Function* todiff, AAResul
398398
if (gutils->newFunc->hasFnAttribute(Attribute::OptimizeNone))
399399
gutils->newFunc->removeFnAttr(Attribute::OptimizeNone);
400400

401+
if (auto bytes = gutils->newFunc->getDereferenceableBytes(llvm::AttributeList::ReturnIndex)) {
402+
AttrBuilder ab;
403+
ab.addDereferenceableAttr(bytes);
404+
gutils->newFunc->removeAttributes(llvm::AttributeList::ReturnIndex, ab);
405+
}
406+
407+
if (gutils->newFunc->hasAttribute(llvm::AttributeList::ReturnIndex, llvm::Attribute::NoAlias)) {
408+
gutils->newFunc->removeAttribute(llvm::AttributeList::ReturnIndex, llvm::Attribute::NoAlias);
409+
}
410+
if (gutils->newFunc->hasAttribute(llvm::AttributeList::ReturnIndex, llvm::Attribute::ZExt)) {
411+
gutils->newFunc->removeAttribute(llvm::AttributeList::ReturnIndex, llvm::Attribute::ZExt);
412+
}
413+
401414
if (llvm::verifyFunction(*gutils->newFunc, &llvm::errs())) {
402415
llvm::errs() << *gutils->oldFunc << "\n";
403416
llvm::errs() << *gutils->newFunc << "\n";

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -955,6 +955,7 @@ void GradientUtils::branchToCorrespondingTarget(BasicBlock* ctx, IRBuilder <>& B
955955

956956
Value* phi = lookupValueFromCache(BuilderM, ctx, cache);
957957

958+
958959
if (replacePHIs == nullptr) {
959960
SwitchInst* swtch = BuilderM.CreateSwitch(phi, *done[std::make_pair(block, si->getDefaultDest())].begin());
960961
for (auto switchcase : si->cases()) {
@@ -974,6 +975,12 @@ void GradientUtils::branchToCorrespondingTarget(BasicBlock* ctx, IRBuilder <>& B
974975
}
975976
val = BuilderM.CreateNot(val);
976977
}
978+
if (&*BuilderM.GetInsertPoint() == pair.second) {
979+
if (pair.second->getNextNode())
980+
BuilderM.SetInsertPoint(pair.second->getNextNode());
981+
else
982+
BuilderM.SetInsertPoint(pair.second->getParent());
983+
}
977984
pair.second->replaceAllUsesWith(val);
978985
pair.second->eraseFromParent();
979986
}
@@ -1049,6 +1056,12 @@ void GradientUtils::branchToCorrespondingTarget(BasicBlock* ctx, IRBuilder <>& B
10491056
} else {
10501057
val = BuilderM.CreateICmpEQ(ConstantInt::get(T, i), which);
10511058
}
1059+
if (&*BuilderM.GetInsertPoint() == found->second) {
1060+
if (found->second->getNextNode())
1061+
BuilderM.SetInsertPoint(found->second->getNextNode());
1062+
else
1063+
BuilderM.SetInsertPoint(found->second->getParent());
1064+
}
10521065
found->second->replaceAllUsesWith(val);
10531066
found->second->eraseFromParent();
10541067
}

0 commit comments

Comments
 (0)