Skip to content

Commit 3d28f83

Browse files
authored
Fix recursive global (rust-lang#704)
1 parent ca6fe58 commit 3d28f83

File tree

6 files changed

+185
-36
lines changed

6 files changed

+185
-36
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10864,8 +10864,30 @@ class AdjointGenerator
1086410864
(orig->mayWriteToMemory() ||
1086510865
!gutils->legalRecompute(orig, ValueToValueMapTy(), nullptr))) {
1086610866
if (!gutils->unnecessaryIntermediates.count(orig)) {
10867-
gutils->cacheForReverse(BuilderZ, newCall,
10868-
getIndex(orig, CacheType::Self));
10867+
10868+
std::map<UsageKey, bool> Seen;
10869+
bool primalNeededInReverse = false;
10870+
for (auto pair : gutils->knownRecomputeHeuristic)
10871+
if (!pair.second) {
10872+
if (pair.first == orig) {
10873+
primalNeededInReverse = true;
10874+
break;
10875+
} else {
10876+
Seen[UsageKey(pair.first, ValueType::Primal)] = false;
10877+
}
10878+
}
10879+
if (!primalNeededInReverse) {
10880+
10881+
auto minCutMode = (Mode == DerivativeMode::ReverseModePrimal)
10882+
? DerivativeMode::ReverseModeGradient
10883+
: Mode;
10884+
primalNeededInReverse =
10885+
is_value_needed_in_reverse<ValueType::Primal>(
10886+
gutils, orig, minCutMode, Seen, oldUnreachable);
10887+
}
10888+
if (primalNeededInReverse)
10889+
gutils->cacheForReverse(BuilderZ, newCall,
10890+
getIndex(orig, CacheType::Self));
1086910891
}
1087010892
eraseIfUnused(*orig);
1087110893
return;
@@ -11460,11 +11482,31 @@ class AdjointGenerator
1146011482
}
1146111483

1146211484
if (Mode == DerivativeMode::ReverseModePrimal &&
11463-
is_value_needed_in_reverse<ValueType::Primal>(gutils, orig, Mode,
11464-
oldUnreachable) &&
1146511485
!gutils->unnecessaryIntermediates.count(orig)) {
11466-
gutils->cacheForReverse(BuilderZ, dcall,
11467-
getIndex(orig, CacheType::Self));
11486+
11487+
std::map<UsageKey, bool> Seen;
11488+
bool primalNeededInReverse = false;
11489+
for (auto pair : gutils->knownRecomputeHeuristic)
11490+
if (!pair.second) {
11491+
if (pair.first == orig) {
11492+
primalNeededInReverse = true;
11493+
break;
11494+
} else {
11495+
Seen[UsageKey(pair.first, ValueType::Primal)] = false;
11496+
}
11497+
}
11498+
if (!primalNeededInReverse) {
11499+
11500+
auto minCutMode = (Mode == DerivativeMode::ReverseModePrimal)
11501+
? DerivativeMode::ReverseModeGradient
11502+
: Mode;
11503+
primalNeededInReverse =
11504+
is_value_needed_in_reverse<ValueType::Primal>(
11505+
gutils, orig, minCutMode, Seen, oldUnreachable);
11506+
}
11507+
if (primalNeededInReverse)
11508+
gutils->cacheForReverse(BuilderZ, dcall,
11509+
getIndex(orig, CacheType::Self));
1146811510
}
1146911511
BuilderZ.SetInsertPoint(newCall->getNextNode());
1147011512
gutils->erase(newCall);

enzyme/Enzyme/DifferentialUseAnalysis.h

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -275,23 +275,33 @@ static inline bool is_value_needed_in_reverse(
275275
return seen[idx] = true;
276276

277277
if (auto SI = dyn_cast<StoreInst>(user)) {
278-
// storing an active pointer into a location
279-
// doesn't require the shadow pointer for the
280-
// reverse pass
281-
if (SI->getValueOperand() == inst &&
282-
(mode == DerivativeMode::ReverseModeGradient ||
283-
mode == DerivativeMode::ForwardModeSplit)) {
284-
// Unless the store is into a backwards store, which would
285-
// would then be performed in the reverse if the stored value was
286-
// a possible pointer.
278+
if (mode == DerivativeMode::ReverseModeGradient ||
279+
mode == DerivativeMode::ForwardModeSplit) {
280+
287281
bool rematerialized = false;
288282
for (auto pair : gutils->backwardsOnlyShadows)
289283
if (pair.second.stores.count(SI)) {
290284
rematerialized = true;
291285
break;
292286
}
293-
if (!rematerialized)
294-
goto endShadow;
287+
288+
if (SI->getValueOperand() == inst) {
289+
// storing an active pointer into a location
290+
// doesn't require the shadow pointer for the
291+
// reverse pass
292+
// Unless the store is into a backwards store, which would
293+
// would then be performed in the reverse if the stored value was
294+
// a possible pointer.
295+
if (!rematerialized)
296+
goto endShadow;
297+
} else {
298+
// Likewise, if not rematerializing in reverse pass, you
299+
// don't need to keep the pointer operand for known pointers
300+
if (!rematerialized &&
301+
TR.query(const_cast<Value *>(SI->getValueOperand()))[{-1}] ==
302+
BaseType::Pointer)
303+
goto endShadow;
304+
}
295305
}
296306

297307
if (!gutils->isConstantValue(

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2472,8 +2472,36 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
24722472
}
24732473

24742474
SmallVector<CallInst *, 4> fnusers;
2475+
SmallVector<std::pair<GlobalVariable *, DerivativeMode>, 1> gfnusers;
24752476
for (auto user : AugmentedCachedFunctions.find(tup)->second.fn->users()) {
2476-
fnusers.push_back(cast<CallInst>(user));
2477+
if (auto CI = dyn_cast<CallInst>(user)) {
2478+
fnusers.push_back(CI);
2479+
} else {
2480+
if (auto CS = dyn_cast<ConstantStruct>(user)) {
2481+
for (auto cuser : CS->users()) {
2482+
if (auto G = dyn_cast<GlobalVariable>(cuser)) {
2483+
if (("_enzyme_reverse_" + todiff->getName() + "'").str() ==
2484+
G->getName()) {
2485+
gfnusers.emplace_back(G, DerivativeMode::ReverseModeGradient);
2486+
continue;
2487+
}
2488+
if (("_enzyme_forwardsplit_" + todiff->getName() + "'").str() ==
2489+
G->getName()) {
2490+
gfnusers.emplace_back(G, DerivativeMode::ForwardModeSplit);
2491+
continue;
2492+
}
2493+
}
2494+
llvm::errs() << *gutils->newFunc->getParent() << "\n";
2495+
llvm::errs() << *cuser << "\n";
2496+
llvm::errs() << *user << "\n";
2497+
llvm_unreachable("Bad cuser of staging augmented forward fn");
2498+
}
2499+
continue;
2500+
}
2501+
llvm::errs() << *gutils->newFunc->getParent() << "\n";
2502+
llvm::errs() << *user << "\n";
2503+
llvm_unreachable("Bad user of staging augmented forward fn");
2504+
}
24772505
}
24782506
for (auto user : fnusers) {
24792507
if (removeStruct) {
@@ -2514,6 +2542,25 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
25142542
AugmentedCachedFunctions.find(tup)->second.tapeType = tapeType;
25152543
insert_or_assign(AugmentedCachedFinished, tup, true);
25162544

2545+
for (auto pair : gfnusers) {
2546+
auto GV = pair.first;
2547+
GV->setName("_tmp");
2548+
auto R = gutils->GetOrCreateShadowFunction(
2549+
*this, TLI, TA, todiff, pair.second, width, gutils->AtomicAdd);
2550+
SmallVector<ConstantExpr *, 1> users;
2551+
for (auto U : GV->users()) {
2552+
if (auto CE = dyn_cast<ConstantExpr>(U)) {
2553+
if (CE->isCast()) {
2554+
users.push_back(CE);
2555+
}
2556+
}
2557+
}
2558+
for (auto U : users) {
2559+
U->replaceAllUsesWith(ConstantExpr::getPointerCast(R, U->getType()));
2560+
}
2561+
GV->eraseFromParent();
2562+
}
2563+
25172564
{
25182565
PreservedAnalyses PA;
25192566
PPC.FAM.invalidate(*gutils->newFunc, PA);

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3582,13 +3582,9 @@ Constant *GradientUtils::GetOrCreateShadowConstant(
35823582
Type *type = arg->getType()->getPointerElementType();
35833583
auto shadow = new GlobalVariable(
35843584
*arg->getParent(), type, arg->isConstant(), arg->getLinkage(),
3585-
arg->getInitializer()
3586-
? GetOrCreateShadowConstant(Logic, TLI, TA,
3587-
cast<Constant>(arg->getOperand(0)),
3588-
mode, width, AtomicAdd)
3589-
: Constant::getNullValue(type),
3590-
arg->getName() + "_shadow", arg, arg->getThreadLocalMode(),
3591-
arg->getType()->getAddressSpace(), arg->isExternallyInitialized());
3585+
Constant::getNullValue(type), arg->getName() + "_shadow", arg,
3586+
arg->getThreadLocalMode(), arg->getType()->getAddressSpace(),
3587+
arg->isExternallyInitialized());
35923588
arg->setMetadata("enzyme_shadow",
35933589
MDTuple::get(shadow->getContext(),
35943590
{ConstantAsMetadata::get(shadow)}));
@@ -3598,6 +3594,10 @@ Constant *GradientUtils::GetOrCreateShadowConstant(
35983594
shadow->setAlignment(arg->getAlignment());
35993595
#endif
36003596
shadow->setUnnamedAddr(arg->getUnnamedAddr());
3597+
if (arg->getInitializer())
3598+
shadow->setInitializer(GetOrCreateShadowConstant(
3599+
Logic, TLI, TA, cast<Constant>(arg->getOperand(0)), mode, width,
3600+
AtomicAdd));
36013601
return shadow;
36023602
}
36033603
}
@@ -4103,14 +4103,11 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,
41034103
Type *elemTy = arg->getType()->getPointerElementType();
41044104
Type *type = getShadowType(elemTy);
41054105
IRBuilder<> B(inversionAllocs);
4106-
auto ip = arg->getInitializer() ? invertPointerM(arg->getInitializer(),
4107-
B, /*nullShadow*/ true)
4108-
: Constant::getNullValue(type);
41094106

4110-
auto rule = [&](Value *ip) {
4107+
auto rule = [&]() {
41114108
auto shadow = new GlobalVariable(
41124109
*arg->getParent(), elemTy, arg->isConstant(), arg->getLinkage(),
4113-
cast<Constant>(ip), arg->getName() + "_shadow", arg,
4110+
Constant::getNullValue(type), arg->getName() + "_shadow", arg,
41144111
arg->getThreadLocalMode(), arg->getType()->getAddressSpace(),
41154112
arg->isExternallyInitialized());
41164113
arg->setMetadata("enzyme_shadow",
@@ -4126,7 +4123,18 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,
41264123
return shadow;
41274124
};
41284125

4129-
Value *shadow = applyChainRule(oval->getType(), BuilderM, rule, ip);
4126+
Value *shadow = applyChainRule(oval->getType(), BuilderM, rule);
4127+
4128+
if (arg->getInitializer()) {
4129+
applyChainRule(
4130+
BuilderM,
4131+
[&](Value *shadow, Value *ip) {
4132+
cast<GlobalVariable>(shadow)->setInitializer(
4133+
cast<Constant>(ip));
4134+
},
4135+
shadow,
4136+
invertPointerM(arg->getInitializer(), B, /*nullShadow*/ true));
4137+
}
41304138

41314139
invertedPointers.insert(std::make_pair(
41324140
(const Value *)oval, InvertedPointerVH(this, shadow)));
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s
2+
3+
@g = constant i8* bitcast (void (i8***)* @impl to i8*), align 8
4+
5+
declare i32 @offset()
6+
7+
define void @impl(i8*** %i) {
8+
%o = call i32 @offset()
9+
%g = getelementptr inbounds i8**, i8*** %i, i32 %o
10+
store i8** @g, i8*** %g, align 8
11+
ret void
12+
}
13+
14+
define i8** @caller() {
15+
%i6 = call i8** (...) @__enzyme_virtualreverse(i8** @g)
16+
ret i8** %i6
17+
}
18+
19+
declare i8** @__enzyme_virtualreverse(...)
20+
21+
; CHECK: @g_shadow = constant i8* bitcast ({ i8* (i8***, i8***)*, void (i8***, i8***, i8*)* }* @"_enzyme_reverse_impl'" to i8*), align 8
22+
; CHECK: @g = constant i8* bitcast (void (i8***)* @impl to i8*), align 8, !enzyme_shadow !0
23+
; CHECK: @"_enzyme_reverse_impl'" = internal constant { i8* (i8***, i8***)*, void (i8***, i8***, i8*)* } { i8* (i8***, i8***)* @augmented_impl, void (i8***, i8***, i8*)* @diffeimpl }
24+
25+
; CHECK: define i8** @caller() {
26+
; CHECK-NEXT: ret i8** @g_shadow
27+
; CHECK-NEXT: }
28+
29+
; CHECK: define internal i8* @augmented_impl(i8*** %i, i8*** %"i'")
30+
; CHECK-NEXT: %o = call i32 @offset()
31+
; CHECK-NEXT: %"g'ipg" = getelementptr inbounds i8**, i8*** %"i'", i32 %o
32+
; CHECK-NEXT: %g = getelementptr inbounds i8**, i8*** %i, i32 %o
33+
; CHECK-NEXT: store i8** @g_shadow, i8*** %"g'ipg", align 8
34+
; CHECK-NEXT: store i8** @g, i8*** %g, align 8
35+
; CHECK-NEXT: ret i8* null
36+
; CHECK-NEXT: }
37+
38+
; CHECK: define internal void @diffeimpl(i8*** %i, i8*** %"i'", i8* %tapeArg)
39+
; CHECK-NEXT: invert:
40+
; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg)
41+
; CHECK-NEXT: ret void
42+
; CHECK-NEXT: }
43+
44+
; CHECK: !0 = !{i8** @g_shadow}

enzyme/test/Enzyme/ReverseMode/unnecessaryalloc.ll

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,20 +51,18 @@ declare void @free(i8*)
5151
!5 = distinct !{}
5252

5353

54-
; CHECK: define internal { i64, double } @augmented_diffemy_sin2(double %x, double %differeturn)
54+
; CHECK: define internal double @augmented_diffemy_sin2(double %x, double %differeturn)
5555
; CHECK-NEXT: entry:
5656
; CHECK-NEXT: %i = add i64 14, 1
5757
; CHECK-NEXT: %i1 = add nuw i64 %i, 1
5858
; CHECK-NEXT: %mallocsize = mul nuw nsw i64 %i1, 8
5959
; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i64 %mallocsize)
6060
; CHECK-NEXT: tail call void @free(i8* nonnull %malloccall)
6161
; CHECK-NEXT: %0 = insertvalue { i8*, double } undef, double %x, 1
62-
; CHECK-NEXT: %.fca.0.insert = insertvalue { i64, double } {{(undef|poison)}}, i64 14, 0
63-
; CHECK-NEXT: %.fca.1.insert = insertvalue { i64, double } %.fca.0.insert, double %x, 1
64-
; CHECK-NEXT: ret { i64, double } %.fca.1.insert
62+
; CHECK-NEXT: ret double %x
6563
; CHECK-NEXT: }
6664

67-
; CHECK: define internal { double } @diffediffemy_sin2(double %x, double %differeturn1, double %differeturn, i64 %tapeArg)
65+
; CHECK: define internal { double } @diffediffemy_sin2(double %x, double %differeturn1, double %differeturn)
6866
; CHECK-NEXT: entry:
6967
; CHECK-NEXT: %0 = insertvalue { double } undef, double %differeturn, 0
7068
; CHECK-NEXT: ret { double } %0

0 commit comments

Comments
 (0)