@@ -2232,6 +2232,20 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
2232
2232
gutils->newFunc ->getEntryBlock ().getFirstNonPHIOrDbgOrLifetime ());
2233
2233
}
2234
2234
2235
+ // ! Keep track of inverted pointers we may need to return
2236
+ ValueToValueMapTy invertedRetPs;
2237
+ if (shadowReturnUsed) {
2238
+ for (BasicBlock &BB : *gutils->oldFunc ) {
2239
+ if (auto ri = dyn_cast<ReturnInst>(BB.getTerminator ())) {
2240
+ if (Value *orig_oldval = ri->getReturnValue ()) {
2241
+ auto newri = gutils->getNewFromOriginal (ri);
2242
+ IRBuilder<> BuilderZ (newri);
2243
+ invertedRetPs[newri] = gutils->invertPointerM (orig_oldval, BuilderZ);
2244
+ }
2245
+ }
2246
+ }
2247
+ }
2248
+
2235
2249
(IRBuilder<>(gutils->inversionAllocs )).CreateUnreachable ();
2236
2250
DeleteDeadBlock (gutils->inversionAllocs );
2237
2251
@@ -2290,20 +2304,6 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
2290
2304
#endif
2291
2305
}
2292
2306
2293
- // ! Keep track of inverted pointers we may need to return
2294
- ValueToValueMapTy invertedRetPs;
2295
- if (shadowReturnUsed) {
2296
- for (BasicBlock &BB : *gutils->oldFunc ) {
2297
- if (auto ri = dyn_cast<ReturnInst>(BB.getTerminator ())) {
2298
- if (Value *orig_oldval = ri->getReturnValue ()) {
2299
- auto newri = gutils->getNewFromOriginal (ri);
2300
- IRBuilder<> BuilderZ (newri);
2301
- invertedRetPs[newri] = gutils->invertPointerM (orig_oldval, BuilderZ);
2302
- }
2303
- }
2304
- }
2305
- }
2306
-
2307
2307
gutils->eraseFictiousPHIs ();
2308
2308
2309
2309
if (llvm::verifyFunction (*gutils->newFunc , &llvm::errs ())) {
@@ -2412,22 +2412,21 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
2412
2412
Function *NewF = Function::Create (
2413
2413
FTy, nf->getLinkage (), " augmented_" + todiff->getName (), nf->getParent ());
2414
2414
2415
- unsigned ii = 0 , jj = 0 ;
2415
+ unsigned attrIndex = 0 ;
2416
2416
auto i = nf->arg_begin (), j = NewF->arg_begin ();
2417
- for (; i != nf->arg_end (); ) {
2417
+ while ( i != nf->arg_end ()) {
2418
2418
VMap[i] = j;
2419
- if (nf->hasParamAttribute (ii , Attribute::NoCapture)) {
2420
- NewF->addParamAttr (jj , Attribute::NoCapture);
2419
+ if (nf->hasParamAttribute (attrIndex , Attribute::NoCapture)) {
2420
+ NewF->addParamAttr (attrIndex , Attribute::NoCapture);
2421
2421
}
2422
- if (nf->hasParamAttribute (ii , Attribute::NoAlias)) {
2423
- NewF->addParamAttr (jj , Attribute::NoAlias);
2422
+ if (nf->hasParamAttribute (attrIndex , Attribute::NoAlias)) {
2423
+ NewF->addParamAttr (attrIndex , Attribute::NoAlias);
2424
2424
}
2425
2425
2426
2426
j->setName (i->getName ());
2427
2427
++j;
2428
- ++jj;
2429
2428
++i;
2430
- ++ii ;
2429
+ ++attrIndex ;
2431
2430
}
2432
2431
2433
2432
SmallVector<ReturnInst *, 4 > Returns;
@@ -2617,9 +2616,10 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
2617
2616
if (auto ggep = dyn_cast<GetElementPtrInst>(gep)) {
2618
2617
ggep->setIsInBounds (true );
2619
2618
}
2620
- if (isa<ConstantData>(invertedRetPs[ri]))
2619
+ if (isa<ConstantExpr>(invertedRetPs[ri]) ||
2620
+ isa<ConstantData>(invertedRetPs[ri])) {
2621
2621
ib.CreateStore (invertedRetPs[ri], gep);
2622
- else {
2622
+ } else {
2623
2623
assert (VMap[invertedRetPs[ri]]);
2624
2624
ib.CreateStore (VMap[invertedRetPs[ri]], gep);
2625
2625
}
0 commit comments