Skip to content

Commit 4eb8421

Browse files
ZuseZ4tgymnich
andauthored
fix ConstantExpr handling in CreateAugmentedPrimal (rust-lang#743)
* fix ConstantExpr handling in CreateAugmentedPrimal * add testcase * fix testcase * Update constexpr.ll * respect lifetime Co-authored-by: Tim Gymnich <[email protected]>
1 parent e8ed87c commit 4eb8421

File tree

2 files changed

+50
-24
lines changed

2 files changed

+50
-24
lines changed

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2232,6 +2232,20 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
22322232
gutils->newFunc->getEntryBlock().getFirstNonPHIOrDbgOrLifetime());
22332233
}
22342234

2235+
//! Keep track of inverted pointers we may need to return
2236+
ValueToValueMapTy invertedRetPs;
2237+
if (shadowReturnUsed) {
2238+
for (BasicBlock &BB : *gutils->oldFunc) {
2239+
if (auto ri = dyn_cast<ReturnInst>(BB.getTerminator())) {
2240+
if (Value *orig_oldval = ri->getReturnValue()) {
2241+
auto newri = gutils->getNewFromOriginal(ri);
2242+
IRBuilder<> BuilderZ(newri);
2243+
invertedRetPs[newri] = gutils->invertPointerM(orig_oldval, BuilderZ);
2244+
}
2245+
}
2246+
}
2247+
}
2248+
22352249
(IRBuilder<>(gutils->inversionAllocs)).CreateUnreachable();
22362250
DeleteDeadBlock(gutils->inversionAllocs);
22372251

@@ -2290,20 +2304,6 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
22902304
#endif
22912305
}
22922306

2293-
//! Keep track of inverted pointers we may need to return
2294-
ValueToValueMapTy invertedRetPs;
2295-
if (shadowReturnUsed) {
2296-
for (BasicBlock &BB : *gutils->oldFunc) {
2297-
if (auto ri = dyn_cast<ReturnInst>(BB.getTerminator())) {
2298-
if (Value *orig_oldval = ri->getReturnValue()) {
2299-
auto newri = gutils->getNewFromOriginal(ri);
2300-
IRBuilder<> BuilderZ(newri);
2301-
invertedRetPs[newri] = gutils->invertPointerM(orig_oldval, BuilderZ);
2302-
}
2303-
}
2304-
}
2305-
}
2306-
23072307
gutils->eraseFictiousPHIs();
23082308

23092309
if (llvm::verifyFunction(*gutils->newFunc, &llvm::errs())) {
@@ -2412,22 +2412,21 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
24122412
Function *NewF = Function::Create(
24132413
FTy, nf->getLinkage(), "augmented_" + todiff->getName(), nf->getParent());
24142414

2415-
unsigned ii = 0, jj = 0;
2415+
unsigned attrIndex = 0;
24162416
auto i = nf->arg_begin(), j = NewF->arg_begin();
2417-
for (; i != nf->arg_end();) {
2417+
while (i != nf->arg_end()) {
24182418
VMap[i] = j;
2419-
if (nf->hasParamAttribute(ii, Attribute::NoCapture)) {
2420-
NewF->addParamAttr(jj, Attribute::NoCapture);
2419+
if (nf->hasParamAttribute(attrIndex, Attribute::NoCapture)) {
2420+
NewF->addParamAttr(attrIndex, Attribute::NoCapture);
24212421
}
2422-
if (nf->hasParamAttribute(ii, Attribute::NoAlias)) {
2423-
NewF->addParamAttr(jj, Attribute::NoAlias);
2422+
if (nf->hasParamAttribute(attrIndex, Attribute::NoAlias)) {
2423+
NewF->addParamAttr(attrIndex, Attribute::NoAlias);
24242424
}
24252425

24262426
j->setName(i->getName());
24272427
++j;
2428-
++jj;
24292428
++i;
2430-
++ii;
2429+
++attrIndex;
24312430
}
24322431

24332432
SmallVector<ReturnInst *, 4> Returns;
@@ -2617,9 +2616,10 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
26172616
if (auto ggep = dyn_cast<GetElementPtrInst>(gep)) {
26182617
ggep->setIsInBounds(true);
26192618
}
2620-
if (isa<ConstantData>(invertedRetPs[ri]))
2619+
if (isa<ConstantExpr>(invertedRetPs[ri]) ||
2620+
isa<ConstantData>(invertedRetPs[ri])) {
26212621
ib.CreateStore(invertedRetPs[ri], gep);
2622-
else {
2622+
} else {
26232623
assert(VMap[invertedRetPs[ri]]);
26242624
ib.CreateStore(VMap[invertedRetPs[ri]], gep);
26252625
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -S | FileCheck %s
2+
3+
@_ZTId = external dso_local constant i8*
4+
5+
define i8* @_ZNK4implIdE4typeEv() {
6+
ret i8* bitcast (i8** @_ZTId to i8*)
7+
}
8+
9+
declare void @_Z17__enzyme_virtualreverse(i8*)
10+
11+
define void @_Z18wrapper_1body_intsv() {
12+
call void @_Z17__enzyme_virtualreverse(i8* bitcast (i8* ()* @_ZNK4implIdE4typeEv to i8*))
13+
ret void
14+
}
15+
16+
; CHECK: define internal { i8*, i8*, i8* } @augmented__ZNK4implIdE4typeEv()
17+
; CHECK-NEXT: %1 = alloca { i8*, i8*, i8* }
18+
; CHECK-NEXT: %2 = getelementptr inbounds { i8*, i8*, i8* }, { i8*, i8*, i8* }* %1, i32 0, i32 0
19+
; CHECK-NEXT: store i8* null, i8** %2
20+
; CHECK-NEXT: %3 = getelementptr inbounds { i8*, i8*, i8* }, { i8*, i8*, i8* }* %1, i32 0, i32 1
21+
; CHECK-NEXT: store i8* bitcast (i8** @_ZTId to i8*), i8** %3
22+
; CHECK-NEXT: %4 = getelementptr inbounds { i8*, i8*, i8* }, { i8*, i8*, i8* }* %1, i32 0, i32 2
23+
; CHECK-NEXT: store i8* bitcast (i8** @_ZTId_shadow to i8*), i8** %4
24+
; CHECK-NEXT: %5 = load { i8*, i8*, i8* }, { i8*, i8*, i8* }* %1
25+
; CHECK-NEXT: ret { i8*, i8*, i8* } %5
26+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)