Skip to content

Commit 0666132

Browse files
authored
Permit limited circumstance recomputation of phi headers (rust-lang#524)
* Permit limited circumstance recomputation of phi headers * Handle placeholder unwraps * Fix MD build
1 parent 6382c2a commit 0666132

File tree

9 files changed

+471
-149
lines changed

9 files changed

+471
-149
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5463,10 +5463,6 @@ class AdjointGenerator
54635463
&call, {ValueType::Shadow, ValueType::Shadow}, Builder2,
54645464
/*lookup*/ false);
54655465

5466-
Type *types[sizeof(args) / sizeof(*args)];
5467-
for (size_t i = 0; i < sizeof(args) / sizeof(*args); i++)
5468-
types[i] = args[i]->getType();
5469-
54705466
#if LLVM_VERSION_MAJOR >= 11
54715467
auto callval = call.getCalledOperand();
54725468
#else

enzyme/Enzyme/CacheUtility.cpp

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,7 +1020,8 @@ AllocaInst *CacheUtility::createCacheForScope(LimitContext ctx, Type *T,
10201020
IRBuilder<> v(&sublimits[i - 1].second.back().first.preheader->back());
10211021

10221022
Value *idx = computeIndexOfChunk(
1023-
/*inForwardPass*/ true, v, containedloops);
1023+
/*inForwardPass*/ true, v, containedloops,
1024+
/*available*/ ValueToValueMapTy());
10241025

10251026
#if LLVM_VERSION_MAJOR > 7
10261027
storeInto = v.CreateLoad(
@@ -1046,17 +1047,14 @@ AllocaInst *CacheUtility::createCacheForScope(LimitContext ctx, Type *T,
10461047

10471048
Value *CacheUtility::computeIndexOfChunk(
10481049
bool inForwardPass, IRBuilder<> &v,
1049-
const std::vector<std::pair<LoopContext, llvm::Value *>> &containedloops) {
1050+
const std::vector<std::pair<LoopContext, llvm::Value *>> &containedloops,
1051+
const ValueToValueMapTy &available) {
10501052
// List of loop indices in chunk from innermost to outermost
10511053
SmallVector<Value *, 3> indices;
10521054
// List of cumulative indices in chunk from innermost to outermost
10531055
// where limit[i] = prod(loop limit[0..i])
10541056
SmallVector<Value *, 3> limits;
10551057

1056-
// list of contained loop induction variables available for limit
1057-
// computation
1058-
ValueToValueMapTy available;
1059-
10601058
// Iterate from innermost loop to outermost loop within a chunk
10611059
for (size_t i = 0; i < containedloops.size(); ++i) {
10621060
const auto &pair = containedloops[i];
@@ -1066,18 +1064,18 @@ Value *CacheUtility::computeIndexOfChunk(
10661064

10671065
// In the SingleIteration, var may be null (since there's no legal phinode)
10681066
// In that case the current iteration is simply the constant Zero
1069-
if (var == nullptr)
1067+
if (idx.var == nullptr)
10701068
var = ConstantInt::get(Type::getInt64Ty(newFunc->getContext()), 0);
1071-
else if (!inForwardPass) {
1069+
else if (available.count(var)) {
1070+
var = available.find(var)->second;
1071+
} else if (!inForwardPass) {
10721072
#if LLVM_VERSION_MAJOR > 7
10731073
var = v.CreateLoad(idx.var->getType(), idx.antivaralloc);
10741074
#else
10751075
var = v.CreateLoad(idx.antivaralloc);
10761076
#endif
1077-
available[idx.var] = var;
10781077
} else {
10791078
var = idx.var;
1080-
available[idx.var] = var;
10811079
}
10821080
if (idx.offset) {
10831081
var = v.CreateAdd(var, lookupM(idx.offset, v), "", /*NUW*/ true,
@@ -1391,7 +1389,9 @@ void CacheUtility::storeInstructionInCache(LimitContext ctx,
13911389
bool isi1 = val->getType()->isIntegerTy() &&
13921390
cast<IntegerType>(val->getType())->getBitWidth() == 1;
13931391
Value *loc = getCachePointer(/*inForwardPass*/ true, v, ctx, cache, isi1,
1394-
/*storeInInstructionsMap*/ true);
1392+
/*storeInInstructionsMap*/ true,
1393+
/*available*/ llvm::ValueToValueMapTy(),
1394+
/*extraSize*/ nullptr);
13951395

13961396
Value *tostore = val;
13971397

@@ -1495,14 +1495,13 @@ void CacheUtility::storeInstructionInCache(LimitContext ctx,
14951495
Value *CacheUtility::getCachePointer(bool inForwardPass, IRBuilder<> &BuilderM,
14961496
LimitContext ctx, Value *cache, bool isi1,
14971497
bool storeInInstructionsMap,
1498+
const ValueToValueMapTy &available,
14981499
Value *extraSize) {
14991500
assert(ctx.Block);
15001501
assert(cache);
15011502

15021503
auto sublimits = getSubLimits(inForwardPass, &BuilderM, ctx, extraSize);
15031504

1504-
ValueToValueMapTy available;
1505-
15061505
Value *next = cache;
15071506
assert(next->getType()->isPointerTy());
15081507

@@ -1558,7 +1557,8 @@ Value *CacheUtility::getCachePointer(bool inForwardPass, IRBuilder<> &BuilderM,
15581557
const auto &containedloops = sublimits[i].second;
15591558

15601559
if (containedloops.size() > 0) {
1561-
Value *idx = computeIndexOfChunk(inForwardPass, BuilderM, containedloops);
1560+
Value *idx = computeIndexOfChunk(inForwardPass, BuilderM, containedloops,
1561+
available);
15621562
if (EfficientBoolCache && isi1 && i == 0)
15631563
idx = BuilderM.CreateLShr(
15641564
idx, ConstantInt::get(Type::getInt64Ty(newFunc->getContext()), 3));
@@ -1621,14 +1621,15 @@ llvm::Value *CacheUtility::loadFromCachePointer(llvm::IRBuilder<> &BuilderM,
16211621

16221622
/// Given an allocation specified by the LimitContext ctx and cache, lookup the
16231623
/// underlying cached value.
1624-
Value *CacheUtility::lookupValueFromCache(bool inForwardPass,
1625-
IRBuilder<> &BuilderM,
1626-
LimitContext ctx, Value *cache,
1627-
bool isi1, Value *extraSize,
1628-
Value *extraOffset) {
1624+
Value *
1625+
CacheUtility::lookupValueFromCache(bool inForwardPass, IRBuilder<> &BuilderM,
1626+
LimitContext ctx, Value *cache, bool isi1,
1627+
const ValueToValueMapTy &available,
1628+
Value *extraSize, Value *extraOffset) {
16291629
// Get the underlying cache pointer
1630-
auto cptr = getCachePointer(inForwardPass, BuilderM, ctx, cache, isi1,
1631-
/*storeInInstructionsMap*/ false, extraSize);
1630+
auto cptr =
1631+
getCachePointer(inForwardPass, BuilderM, ctx, cache, isi1,
1632+
/*storeInInstructionsMap*/ false, available, extraSize);
16321633

16331634
// Optionally apply the additional offset
16341635
if (extraOffset) {

enzyme/Enzyme/CacheUtility.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,8 @@ class CacheUtility {
266266
/// the IRBuilder<>
267267
llvm::Value *computeIndexOfChunk(
268268
bool inForwardPass, llvm::IRBuilder<> &v,
269-
const std::vector<std::pair<LoopContext, llvm::Value *>> &containedloops);
269+
const std::vector<std::pair<LoopContext, llvm::Value *>> &containedloops,
270+
const llvm::ValueToValueMapTy &available);
270271

271272
private:
272273
/// Given a cache allocation and an index denoting how many Chunks deep the
@@ -376,15 +377,15 @@ class CacheUtility {
376377
llvm::Value *getCachePointer(bool inForwardPass, llvm::IRBuilder<> &BuilderM,
377378
LimitContext ctx, llvm::Value *cache, bool isi1,
378379
bool storeInInstructionsMap,
379-
llvm::Value *extraSize = nullptr);
380+
const llvm::ValueToValueMapTy &available,
381+
llvm::Value *extraSize);
380382

381383
/// Given an allocation specified by the LimitContext ctx and cache, lookup
382384
/// the underlying cached value.
383-
llvm::Value *lookupValueFromCache(bool inForwardPass,
384-
llvm::IRBuilder<> &BuilderM,
385-
LimitContext ctx, llvm::Value *cache,
386-
bool isi1, llvm::Value *extraSize = nullptr,
387-
llvm::Value *extraOffset = nullptr);
385+
llvm::Value *lookupValueFromCache(
386+
bool inForwardPass, llvm::IRBuilder<> &BuilderM, LimitContext ctx,
387+
llvm::Value *cache, bool isi1, const llvm::ValueToValueMapTy &available,
388+
llvm::Value *extraSize = nullptr, llvm::Value *extraOffset = nullptr);
388389

389390
protected:
390391
// List of values loaded from the cache

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3535,13 +3535,22 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
35353535
auto newi = pair.first;
35363536
auto nexti = pair.second;
35373537
for (auto V : unwrapToOrig[newi]) {
3538-
ValueToValueMapTy empty;
3538+
ValueToValueMapTy available;
3539+
if (auto MD = hasMetadata(V, "enzyme_available")) {
3540+
for (auto &pair : MD->operands()) {
3541+
auto tup = cast<MDNode>(pair);
3542+
auto val = cast<ValueAsMetadata>(tup->getOperand(1))->getValue();
3543+
assert(val);
3544+
available[cast<ValueAsMetadata>(tup->getOperand(0))->getValue()] =
3545+
val;
3546+
}
3547+
}
35393548
IRBuilder<> lb(V);
35403549
// This must disallow caching here as otherwise performing the loop in
35413550
// the wrong order may result in first replacing the later unwrapped
35423551
// value, caching it, then attempting to reuse it for an earlier
35433552
// replacement.
3544-
Value *nval = gutils->unwrapM(nexti, lb, empty,
3553+
Value *nval = gutils->unwrapM(nexti, lb, available,
35453554
UnwrapMode::LegalFullUnwrapNoTapeReplace,
35463555
/*scope*/ nullptr, /*permitCache*/ false);
35473556
assert(nval);

0 commit comments

Comments
 (0)