Skip to content

Commit d6a5b1a

Browse files
authored
Don't erase custom call for primal (rust-lang#247)
* Don't erase custom call for primal * Error on token type storage
1 parent 534c38f commit d6a5b1a

File tree

5 files changed

+34
-14
lines changed

5 files changed

+34
-14
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3942,7 +3942,19 @@ class AdjointGenerator
39423942
}
39433943
}
39443944

3945-
if (subretused) {
3945+
bool primalNeededInReverse;
3946+
3947+
if (gutils->knownRecomputeHeuristic.count(orig)) {
3948+
primalNeededInReverse = !gutils->knownRecomputeHeuristic[orig];
3949+
} else {
3950+
std::map<UsageKey, bool> Seen;
3951+
for (auto pair : gutils->knownRecomputeHeuristic)
3952+
Seen[UsageKey(pair.first, ValueType::Primal)] = false;
3953+
primalNeededInReverse = is_value_needed_in_reverse<ValueType::Primal>(
3954+
TR, gutils, orig, Mode, Seen, oldUnreachable);
3955+
}
3956+
3957+
if (subretused && primalNeededInReverse) {
39463958
if (normalReturn != newCall) {
39473959
assert(normalReturn->getType() == newCall->getType());
39483960
gutils->replaceAWithB(newCall, normalReturn);
@@ -3952,7 +3964,9 @@ class AdjointGenerator
39523964
normalReturn = gutils->cacheForReverse(BuilderZ, normalReturn,
39533965
getIndex(orig, CacheType::Self));
39543966
} else {
3955-
eraseIfUnused(*orig, /*erase*/ true, /*check*/ false);
3967+
if (!orig->mayWriteToMemory() ||
3968+
Mode == DerivativeMode::ReverseModeGradient)
3969+
eraseIfUnused(*orig, /*erase*/ true, /*check*/ false);
39563970
}
39573971
return;
39583972
}
@@ -4439,17 +4453,16 @@ class AdjointGenerator
44394453
return;
44404454
}
44414455
SmallVector<Value *, 1> iargs;
4442-
IRBuilder<> Builder2(call.getParent());
4443-
getReverseBuilder(Builder2);
4456+
IRBuilder<> BuilderZ(gutils->getNewFromOriginal(&call));
44444457
for (size_t i = 0, end = orig->getNumArgOperands(); i < end; ++i) {
44454458
auto arg = orig->getArgOperand(i);
44464459
if (!gutils->isConstantValue(arg)) {
4447-
Value *ptrshadow = gutils->invertPointerM(arg, Builder2);
4460+
Value *ptrshadow = gutils->invertPointerM(arg, BuilderZ);
44484461
iargs.push_back(ptrshadow);
44494462
}
44504463
}
44514464
if (iargs.size()) {
4452-
Builder2.CreateCall(called, iargs);
4465+
BuilderZ.CreateCall(called, iargs);
44534466
}
44544467
return;
44554468
}

enzyme/Enzyme/DifferentialUseAnalysis.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,11 @@ static inline bool is_value_needed_in_reverse(
349349
if (!direct)
350350
continue;
351351

352+
if (inst->getType()->isTokenTy()) {
353+
llvm::errs() << " need " << *inst << " via " << *user << "\n";
354+
}
355+
assert(!inst->getType()->isTokenTy());
356+
352357
return seen[idx] = true;
353358
}
354359
return false;

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -658,8 +658,9 @@ struct CacheAnalysis {
658658

659659
// We do not need uncacheable args for intrinsic functions. So skip such
660660
// callsites.
661-
if (isa<IntrinsicInst>(&inst)) {
662-
continue;
661+
if (auto II = dyn_cast<IntrinsicInst>(&inst)) {
662+
if (!II->getCalledFunction()->getName().startswith("llvm.julia"))
663+
continue;
663664
}
664665

665666
// For all other calls, we compute the uncacheable args for this

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1349,6 +1349,13 @@ Value *GradientUtils::cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc,
13491349
}
13501350
}
13511351

1352+
if (malloc->getType()->isTokenTy()) {
1353+
llvm::errs() << " oldFunc: " << *oldFunc << "\n";
1354+
llvm::errs() << " newFunc: " << *newFunc << "\n";
1355+
llvm::errs() << " malloc: " << *malloc << "\n";
1356+
}
1357+
assert(!malloc->getType()->isTokenTy());
1358+
13521359
if (tape) {
13531360
if (idx >= 0 && !tape->getType()->isStructTy()) {
13541361
llvm::errs() << "cacheForReverse incorrect tape type: " << *tape

enzyme/Enzyme/GradientUtils.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -700,12 +700,6 @@ class GradientUtils : public CacheUtility {
700700
bb.CreateCall(oldFunc->getParent()->getOrInsertFunction(
701701
"julia.write_barrier", FT),
702702
anti);
703-
if (mode != DerivativeMode::ReverseModeCombined) {
704-
EmitFailure("SplitGCAllocation", orig->getDebugLoc(), orig,
705-
"Not handling Julia shadow GC allocation in split mode ",
706-
*orig);
707-
return anti;
708-
}
709703
}
710704

711705
if (orig->getCalledFunction()->getName() == "swift_allocObject") {

0 commit comments

Comments
 (0)