Skip to content

Commit 26d31e0

Browse files
wsmoses and Ubuntu authored
Fix remat debug info (rust-lang#529)
* Fix remat debug info
* mpi fix
* Fix Julia
* Preserve for reverse bundles
* Fix format

Co-authored-by: Ubuntu <[email protected]>
1 parent 79efad6 commit 26d31e0

File tree

6 files changed

+104
-24
lines changed

6 files changed

+104
-24
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 44 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -5362,12 +5362,12 @@ class AdjointGenerator
53625362
tysize = C;
53635363
} else {
53645364
#if LLVM_VERSION_MAJOR > 7
5365-
len_arg = Builder2.CreateLoad(
5365+
tysize = Builder2.CreateLoad(
53665366
Type::getInt8PtrTy(call.getContext()),
5367-
getMPIMemberPtr<MPI_Elem::Count>(Builder2, helper));
5367+
getMPIMemberPtr<MPI_Elem::DataType>(Builder2, helper));
53685368
#else
5369-
len_arg = Builder2.CreateLoad(
5370-
getMPIMemberPtr<MPI_Elem::Count>(Builder2, helper));
5369+
tysize = Builder2.CreateLoad(
5370+
getMPIMemberPtr<MPI_Elem::DataType>(Builder2, helper));
53715371
#endif
53725372
}
53735373

@@ -5386,6 +5386,7 @@ class AdjointGenerator
53865386

53875387
assert(shouldFree());
53885388

5389+
assert(tysize);
53895390
tysize = MPI_TYPE_SIZE(tysize, Builder2, call.getType());
53905391

53915392
Value *args[] = {/*req*/ req,
@@ -5633,10 +5634,10 @@ class AdjointGenerator
56335634

56345635
BasicBlock *currentBlock = Builder2.GetInsertBlock();
56355636
BasicBlock *nonnullBlock = gutils->addReverseBlock(
5636-
currentBlock, currentBlock->getName() + "_nonnull",
5637-
gutils->newFunc);
5637+
currentBlock, currentBlock->getName() + "_nonnull");
56385638
BasicBlock *endBlock = gutils->addReverseBlock(
5639-
nonnullBlock, currentBlock->getName() + "_end", gutils->newFunc);
5639+
nonnullBlock, currentBlock->getName() + "_end",
5640+
/*fork*/ true, /*push*/ false);
56405641

56415642
Builder2.CreateCondBr(isNull, endBlock, nonnullBlock);
56425643
Builder2.SetInsertPoint(nonnullBlock);
@@ -5680,7 +5681,13 @@ class AdjointGenerator
56805681
Attribute::AlwaysInline);
56815682
#endif
56825683
Builder2.CreateBr(endBlock);
5683-
5684+
{
5685+
auto found = gutils->reverseBlockToPrimal.find(endBlock);
5686+
assert(found != gutils->reverseBlockToPrimal.end());
5687+
std::vector<BasicBlock *> &vec = gutils->reverseBlocks[found->second];
5688+
assert(vec.size());
5689+
vec.push_back(endBlock);
5690+
}
56845691
Builder2.SetInsertPoint(endBlock);
56855692
} else if (Mode == DerivativeMode::ForwardMode) {
56865693
IRBuilder<> Builder2(&call);
@@ -5778,13 +5785,14 @@ class AdjointGenerator
57785785

57795786
BasicBlock *currentBlock = Builder2.GetInsertBlock();
57805787
BasicBlock *loopBlock = gutils->addReverseBlock(
5781-
currentBlock, currentBlock->getName() + "_loop", gutils->newFunc);
5788+
currentBlock, currentBlock->getName() + "_loop");
57825789
BasicBlock *nonnullBlock = gutils->addReverseBlock(
5783-
loopBlock, currentBlock->getName() + "_nonnull", gutils->newFunc);
5790+
loopBlock, currentBlock->getName() + "_nonnull");
57845791
BasicBlock *eloopBlock = gutils->addReverseBlock(
5785-
nonnullBlock, currentBlock->getName() + "_eloop", gutils->newFunc);
5792+
nonnullBlock, currentBlock->getName() + "_eloop");
57865793
BasicBlock *endBlock = gutils->addReverseBlock(
5787-
eloopBlock, currentBlock->getName() + "_end", gutils->newFunc);
5794+
eloopBlock, currentBlock->getName() + "_end",
5795+
/*fork*/ true, /*push*/ false);
57885796

57895797
Builder2.CreateCondBr(
57905798
Builder2.CreateICmpNE(count,
@@ -5872,6 +5880,13 @@ class AdjointGenerator
58725880
Builder2.SetInsertPoint(eloopBlock);
58735881
Builder2.CreateCondBr(Builder2.CreateICmpEQ(inc, count), endBlock,
58745882
loopBlock);
5883+
{
5884+
auto found = gutils->reverseBlockToPrimal.find(endBlock);
5885+
assert(found != gutils->reverseBlockToPrimal.end());
5886+
std::vector<BasicBlock *> &vec = gutils->reverseBlocks[found->second];
5887+
assert(vec.size());
5888+
vec.push_back(endBlock);
5889+
}
58755890
Builder2.SetInsertPoint(endBlock);
58765891
if (shouldFree()) {
58775892
auto ci = cast<CallInst>(CallInst::CreateFree(
@@ -9596,15 +9611,28 @@ class AdjointGenerator
95969611
return;
95979612
}
95989613

9614+
auto ifound = gutils->invertedPointers.find(orig);
9615+
assert(ifound != gutils->invertedPointers.end());
9616+
9617+
auto placeholder = cast<PHINode>(&*ifound->second);
9618+
9619+
bool needShadow = (Mode == DerivativeMode::ForwardMode ||
9620+
Mode == DerivativeMode::ForwardModeSplit)
9621+
? true
9622+
: is_value_needed_in_reverse<ValueType::Shadow>(
9623+
TR, gutils, orig, Mode, oldUnreachable);
9624+
if (!needShadow) {
9625+
gutils->invertedPointers.erase(ifound);
9626+
gutils->erase(placeholder);
9627+
eraseIfUnused(*orig);
9628+
return;
9629+
}
9630+
95999631
Value *ptrshadow =
96009632
gutils->invertPointerM(call.getArgOperand(0), BuilderZ);
96019633
Value *val =
96029634
BuilderZ.CreateCall(called, std::vector<Value *>({ptrshadow}));
96039635

9604-
auto ifound = gutils->invertedPointers.find(orig);
9605-
assert(ifound != gutils->invertedPointers.end());
9606-
9607-
auto placeholder = cast<PHINode>(&*ifound->second);
96089636
gutils->replaceAWithB(placeholder, val);
96099637
gutils->erase(placeholder);
96109638
eraseIfUnused(*orig);

enzyme/Enzyme/CApi.cpp

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -555,4 +555,18 @@ void EnzymeAddAttributorLegacyPass(LLVMPassManagerRef PM) {
555555
unwrap(PM)->add(createAttributorLegacyPass());
556556
}
557557
#endif
558+
// Returns a variant of the metadata node MD whose operand at index 3 is a
// zero constant, when that operand is currently a constant equal to one.
// MD is returned unchanged when the node does not have exactly four operands,
// when operand 3 is not a ConstantAsMetadata, or when its value is not one.
// MD must unwrap to an MDNode (cast<> is used rather than dyn_cast<>).
// NOTE(review): judging by the function name, operand 3 is presumably the
// "constant memory" flag of a TBAA access tag -- confirm against the LLVM
// TBAA documentation.
LLVMMetadataRef EnzymeMakeNonConstTBAA(LLVMMetadataRef MD) {
559+
auto M = cast<MDNode>(unwrap(MD));
560+
// Only four-operand nodes are rewritten; anything else passes through.
if (M->getNumOperands() != 4)
561+
return MD;
562+
auto CAM = dyn_cast<ConstantAsMetadata>(M->getOperand(3));
563+
if (!CAM)
564+
return MD;
565+
// Nothing to do unless the fourth operand is the constant one.
if (!CAM->getValue()->isOneValue())
566+
return MD;
567+
// Copy the operand list and build a new node with operand 3 set to a zero
// of the same type as the original constant.
SmallVector<Metadata *, 4> MDs(M->operands());
568+
MDs[3] =
569+
ConstantAsMetadata::get(ConstantInt::get(CAM->getValue()->getType(), 0));
570+
return wrap(MDNode::get(M->getContext(), MDs));
571+
}
558572
}

enzyme/Enzyme/DifferentialUseAnalysis.h

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -306,6 +306,17 @@ static inline bool is_value_needed_in_reverse(
306306
}
307307

308308
if (auto CI = dyn_cast<CallInst>(user)) {
309+
{
310+
SmallVector<OperandBundleDef, 2> OrigDefs;
311+
CI->getOperandBundlesAsDefs(OrigDefs);
312+
SmallVector<OperandBundleDef, 2> Defs;
313+
for (auto bund : OrigDefs) {
314+
for (auto inp : bund.inputs()) {
315+
if (inp == inst)
316+
return seen[idx] = true;
317+
}
318+
}
319+
}
309320
if (auto F = getFunctionFromCall(const_cast<CallInst *>(CI))) {
310321
StringRef funcName = F->getName();
311322
if (F->hasFnAttribute("enzyme_math"))
@@ -331,9 +342,8 @@ static inline bool is_value_needed_in_reverse(
331342
if (inst == CI->getArgOperand(6))
332343
return seen[idx] = true;
333344
// Need shadow buffer in reverse pass or forward mode
334-
if (mode != DerivativeMode::ReverseModePrimal)
335-
if (inst == CI->getArgOperand(0))
336-
return seen[idx] = true;
345+
if (inst == CI->getArgOperand(0))
346+
return seen[idx] = true;
337347
continue;
338348
}
339349

@@ -342,7 +352,7 @@ static inline bool is_value_needed_in_reverse(
342352
if (funcName == "MPI_Wait" || funcName == "PMPI_Wait") {
343353
if (gutils->isConstantInstruction(const_cast<Instruction *>(user)))
344354
continue;
345-
// Need shadow request in forward pass
355+
// Need shadow request in forward pass only
346356
if (mode != DerivativeMode::ReverseModeGradient)
347357
if (inst == CI->getArgOperand(0))
348358
return seen[idx] = true;

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2444,6 +2444,7 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB,
24442444
ts->setVolatile(SI->isVolatile());
24452445
ts->setOrdering(SI->getOrdering());
24462446
ts->setSyncScopeID(SI->getSyncScopeID());
2447+
ts->setDebugLoc(getNewFromOriginal(I.getDebugLoc()));
24472448
} else if (auto CI = dyn_cast<CallInst>(&I)) {
24482449
Function *called = getFunctionFromCall(CI);
24492450
assert(called);
@@ -2469,6 +2470,7 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB,
24692470
cal->setAttributes(CI->getAttributes());
24702471
cal->setCallingConv(CI->getCallingConv());
24712472
cal->setTailCallKind(CI->getTailCallKind());
2473+
cal->setDebugLoc(getNewFromOriginal(I.getDebugLoc()));
24722474
} else {
24732475
assert(isDeallocationFunction(*called, TLI));
24742476
continue;
@@ -2508,6 +2510,7 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB,
25082510
#else
25092511
replacement->setAlignment(Alignment);
25102512
#endif
2513+
replacement->setDebugLoc(getNewFromOriginal(I.getDebugLoc()));
25112514
storeInstructionInCache(lctx, NB, replacement, cache);
25122515
} else if (auto CI = dyn_cast<CallInst>(&I)) {
25132516
SmallVector<Value *, 2> args;
@@ -2530,6 +2533,7 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB,
25302533
cal->setAttributes(CI->getAttributes());
25312534
cal->setCallingConv(CI->getCallingConv());
25322535
cal->setTailCallKind(CI->getTailCallKind());
2536+
cal->setDebugLoc(getNewFromOriginal(I.getDebugLoc()));
25332537
storeInstructionInCache(lctx, NB, cal, cache);
25342538
} else {
25352539
llvm::errs() << " realloc: " << I << "\n";
@@ -2636,6 +2640,7 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB,
26362640
cal->setAttributes(MS->getAttributes());
26372641
cal->setCallingConv(MS->getCallingConv());
26382642
cal->setTailCallKind(MS->getTailCallKind());
2643+
cal->setDebugLoc(getNewFromOriginal(I.getDebugLoc()));
26392644
}
26402645
} else if (auto CI = dyn_cast<CallInst>(&I)) {
26412646
Function *called = getFunctionFromCall(CI);
@@ -2664,6 +2669,7 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB,
26642669
cal->setAttributes(CI->getAttributes());
26652670
cal->setCallingConv(CI->getCallingConv());
26662671
cal->setTailCallKind(CI->getTailCallKind());
2672+
cal->setDebugLoc(getNewFromOriginal(I.getDebugLoc()));
26672673
}
26682674
} else {
26692675
assert(isDeallocationFunction(*called, TLI));
@@ -2731,7 +2737,8 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB,
27312737
orig->getCallingConv());
27322738
cast<CallInst>(anti)->setTailCallKind(
27332739
orig->getTailCallKind());
2734-
cast<CallInst>(anti)->setDebugLoc(dbgLoc);
2740+
cast<CallInst>(anti)->setDebugLoc(
2741+
getNewFromOriginal(I.getDebugLoc()));
27352742

27362743
#if LLVM_VERSION_MAJOR >= 14
27372744
cast<CallInst>(anti)->addAttributeAtIndex(
@@ -2758,6 +2765,8 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB,
27582765
#else
27592766
replacement->setAlignment(Alignment);
27602767
#endif
2768+
replacement->setDebugLoc(
2769+
getNewFromOriginal(I.getDebugLoc()));
27612770
replaceAWithB(cast<Instruction>(anti), replacement);
27622771
erase(cast<Instruction>(anti));
27632772
anti = replacement;
@@ -6001,6 +6010,11 @@ fast:;
60016010
assert(branch->getCondition()->getType() == T);
60026011

60036012
if (replacePHIs == nullptr) {
6013+
if (!(BuilderM.GetInsertBlock()->size() == 0 ||
6014+
!isa<BranchInst>(BuilderM.GetInsertBlock()->back()))) {
6015+
llvm::errs() << "newFunc : " << *newFunc << "\n";
6016+
llvm::errs() << "blk : " << *BuilderM.GetInsertBlock() << "\n";
6017+
}
60046018
assert(BuilderM.GetInsertBlock()->size() == 0 ||
60056019
!isa<BranchInst>(BuilderM.GetInsertBlock()->back()));
60066020
BuilderM.CreateCondBr(

enzyme/Enzyme/GradientUtils.h

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -887,7 +887,7 @@ class GradientUtils : public CacheUtility {
887887

888888
public:
889889
BasicBlock *addReverseBlock(BasicBlock *currentBlock, Twine name,
890-
bool forkCache = true) {
890+
bool forkCache = true, bool push = true) {
891891
assert(reverseBlocks.size());
892892
auto found = reverseBlockToPrimal.find(currentBlock);
893893
assert(found != reverseBlockToPrimal.end());
@@ -899,7 +899,8 @@ class GradientUtils : public CacheUtility {
899899
BasicBlock *rev =
900900
BasicBlock::Create(currentBlock->getContext(), name, newFunc);
901901
rev->moveAfter(currentBlock);
902-
vec.push_back(rev);
902+
if (push)
903+
vec.push_back(rev);
903904
reverseBlockToPrimal[rev] = found->second;
904905
if (forkCache) {
905906
for (auto pair : unwrap_cache[currentBlock])

enzyme/Enzyme/Utils.cpp

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -455,13 +455,26 @@ llvm::Function *getOrInsertDifferentialMPI_Wait(llvm::Module &M,
455455
Value *d_req = buff + 7;
456456
d_req->setName("d_req");
457457

458+
bool pmpi = true;
458459
auto isendfn = M.getFunction("PMPI_Isend");
459-
if (!isendfn)
460+
if (!isendfn) {
460461
isendfn = M.getFunction("MPI_Isend");
462+
pmpi = false;
463+
}
461464
assert(isendfn);
462465
auto irecvfn = M.getFunction("PMPI_Irecv");
463466
if (!irecvfn)
464467
irecvfn = M.getFunction("MPI_Irecv");
468+
if (!irecvfn) {
469+
FunctionType *FuT = isendfn->getFunctionType();
470+
std::string name = pmpi ? "PMPI_Irecv" : "MPI_Irecv";
471+
#if LLVM_VERSION_MAJOR >= 9
472+
irecvfn = cast<Function>(M.getOrInsertFunction(name, FuT).getCallee());
473+
474+
#else
475+
irecvfn = cast<Function>(M.getOrInsertFunction(name, FuT));
476+
#endif
477+
}
465478
assert(irecvfn);
466479

467480
IRBuilder<> B(entry);

0 commit comments

Comments
 (0)