Skip to content

Commit 9a5312e

Browse files
committed
SCEV-based recompute (#185)
1 parent 1a7445e commit 9a5312e

File tree

7 files changed

+419
-36
lines changed

7 files changed

+419
-36
lines changed

enzyme/Enzyme/CacheUtility.cpp

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -223,22 +223,6 @@ void RemoveRedundantIVs(BasicBlock *Header, PHINode *CanonicalIV,
223223
if (NewIV == PN) {
224224
continue;
225225
}
226-
if (auto BO = dyn_cast<BinaryOperator>(NewIV)) {
227-
if (BO->getOpcode() == BinaryOperator::Add ||
228-
BO->getOpcode() == BinaryOperator::Mul) {
229-
BO->setHasNoSignedWrap(true);
230-
BO->setHasNoUnsignedWrap(true);
231-
}
232-
for (int i = 0; i < 2; ++i) {
233-
if (auto BO2 = dyn_cast<BinaryOperator>(BO->getOperand(i))) {
234-
if (BO2->getOpcode() == BinaryOperator::Add ||
235-
BO2->getOpcode() == BinaryOperator::Mul) {
236-
BO2->setHasNoSignedWrap(true);
237-
BO2->setHasNoUnsignedWrap(true);
238-
}
239-
}
240-
}
241-
}
242226

243227
replacer(PN, NewIV);
244228
IVsToRemove.push_back(PN);

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 130 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "SCEV/ScalarEvolution.h"
3333
#include "SCEV/ScalarEvolutionExpander.h"
3434

35+
#include "llvm/Analysis/DependenceAnalysis.h"
3536
#include <deque>
3637

3738
#include "llvm/IR/BasicBlock.h"
@@ -92,6 +93,7 @@ bool is_load_uncacheable(
9293
struct CacheAnalysis {
9394
AAResults &AA;
9495
Function *oldFunc;
96+
ScalarEvolution &SE;
9597
LoopInfo &OrigLI;
9698
DominatorTree &DT;
9799
TargetLibraryInfo &TLI;
@@ -100,11 +102,11 @@ struct CacheAnalysis {
100102
bool topLevel;
101103
std::map<Value *, bool> seen;
102104
CacheAnalysis(
103-
AAResults &AA, Function *oldFunc, LoopInfo &OrigLI, DominatorTree &OrigDT,
104-
TargetLibraryInfo &TLI,
105+
AAResults &AA, Function *oldFunc, ScalarEvolution &SE, LoopInfo &OrigLI,
106+
DominatorTree &OrigDT, TargetLibraryInfo &TLI,
105107
const SmallPtrSetImpl<const Instruction *> &unnecessaryInstructions,
106108
const std::map<Argument *, bool> &uncacheable_args, bool topLevel)
107-
: AA(AA), oldFunc(oldFunc), OrigLI(OrigLI), DT(OrigDT), TLI(TLI),
109+
: AA(AA), oldFunc(oldFunc), SE(SE), OrigLI(OrigLI), DT(OrigDT), TLI(TLI),
108110
unnecessaryInstructions(unnecessaryInstructions),
109111
uncacheable_args(uncacheable_args), topLevel(topLevel) {}
110112

@@ -252,6 +254,123 @@ struct CacheAnalysis {
252254
if (!writesToMemoryReadBy(AA, &li, inst2)) {
253255
return false;
254256
}
257+
258+
if (auto SI = dyn_cast<StoreInst>(inst2)) {
259+
260+
const SCEV *LS = SE.getSCEV(li.getPointerOperand());
261+
const SCEV *SS = SE.getSCEV(SI->getPointerOperand());
262+
if (SS != SE.getCouldNotCompute()) {
263+
264+
// llvm::errs() << *inst2 << " - " << li << "\n";
265+
// llvm::errs() << *SS << " - " << *LS << "\n";
266+
const auto &DL = li.getModule()->getDataLayout();
267+
268+
#if LLVM_VERSION_MAJOR >= 10
269+
auto TS = SE.getConstant(
270+
APInt(64, DL.getTypeStoreSize(li.getType()).getFixedSize()));
271+
#else
272+
auto TS = SE.getConstant(
273+
APInt(64, DL.getTypeStoreSize(li.getType())));
274+
#endif
275+
for (auto lim = LS; lim != SE.getCouldNotCompute();) {
276+
// [start load, L+Size] [S, S+Size]
277+
for (auto slim = SS; slim != SE.getCouldNotCompute();) {
278+
auto lsub = SE.getMinusSCEV(slim, SE.getAddExpr(lim, TS));
279+
// llvm::errs() << " *** " << *lsub << "|" << *slim << "|" <<
280+
// *lim << "\n";
281+
if (SE.isKnownNonNegative(lsub)) {
282+
return false;
283+
}
284+
if (auto arL = dyn_cast<SCEVAddRecExpr>(slim)) {
285+
if (SE.isKnownNonNegative(arL->getStepRecurrence(SE))) {
286+
slim = arL->getStart();
287+
continue;
288+
} else if (SE.isKnownNonPositive(
289+
arL->getStepRecurrence(SE))) {
290+
#if LLVM_VERSION_MAJOR >= 12
291+
auto bd =
292+
SE.getSymbolicMaxBackedgeTakenCount(arL->getLoop());
293+
#else
294+
auto bd = SE.getBackedgeTakenCount(arL->getLoop());
295+
#endif
296+
if (bd == SE.getCouldNotCompute())
297+
break;
298+
slim = arL->evaluateAtIteration(bd, SE);
299+
continue;
300+
}
301+
}
302+
break;
303+
}
304+
305+
if (auto arL = dyn_cast<SCEVAddRecExpr>(lim)) {
306+
if (SE.isKnownNonNegative(arL->getStepRecurrence(SE))) {
307+
#if LLVM_VERSION_MAJOR >= 12
308+
auto bd = SE.getSymbolicMaxBackedgeTakenCount(arL->getLoop());
309+
#else
310+
auto bd = SE.getBackedgeTakenCount(arL->getLoop());
311+
#endif
312+
if (bd == SE.getCouldNotCompute())
313+
break;
314+
lim = arL->evaluateAtIteration(bd, SE);
315+
continue;
316+
} else if (SE.isKnownNonPositive(arL->getStepRecurrence(SE))) {
317+
lim = arL->getStart();
318+
continue;
319+
}
320+
}
321+
break;
322+
}
323+
for (auto lim = LS; lim != SE.getCouldNotCompute();) {
324+
// [S, S+Size][start load, L+Size]
325+
for (auto slim = SS; slim != SE.getCouldNotCompute();) {
326+
auto lsub = SE.getMinusSCEV(lim, SE.getAddExpr(slim, TS));
327+
// llvm::errs() << " $$$ " << *lsub << "|" << *slim << "|" <<
328+
// *lim << "\n";
329+
if (SE.isKnownNonNegative(lsub)) {
330+
return false;
331+
}
332+
if (auto arL = dyn_cast<SCEVAddRecExpr>(slim)) {
333+
if (SE.isKnownNonNegative(arL->getStepRecurrence(SE))) {
334+
#if LLVM_VERSION_MAJOR >= 12
335+
auto bd =
336+
SE.getSymbolicMaxBackedgeTakenCount(arL->getLoop());
337+
#else
338+
auto bd = SE.getBackedgeTakenCount(arL->getLoop());
339+
#endif
340+
if (bd == SE.getCouldNotCompute())
341+
break;
342+
slim = arL->evaluateAtIteration(bd, SE);
343+
continue;
344+
} else if (SE.isKnownNonPositive(
345+
arL->getStepRecurrence(SE))) {
346+
slim = arL->getStart();
347+
continue;
348+
}
349+
}
350+
break;
351+
}
352+
353+
if (auto arL = dyn_cast<SCEVAddRecExpr>(lim)) {
354+
if (SE.isKnownNonNegative(arL->getStepRecurrence(SE))) {
355+
lim = arL->getStart();
356+
continue;
357+
} else if (SE.isKnownNonPositive(arL->getStepRecurrence(SE))) {
358+
#if LLVM_VERSION_MAJOR >= 12
359+
auto bd = SE.getSymbolicMaxBackedgeTakenCount(arL->getLoop());
360+
#else
361+
auto bd = SE.getBackedgeTakenCount(arL->getLoop());
362+
#endif
363+
if (bd == SE.getCouldNotCompute())
364+
break;
365+
lim = arL->evaluateAtIteration(bd, SE);
366+
continue;
367+
}
368+
}
369+
break;
370+
}
371+
}
372+
}
373+
255374
if (auto II = dyn_cast<IntrinsicInst>(inst2)) {
256375
if (II->getIntrinsicID() == Intrinsic::nvvm_barrier0 ||
257376
II->getIntrinsicID() == Intrinsic::amdgcn_s_barrier) {
@@ -1309,9 +1428,10 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
13091428
for (auto &I : *BB)
13101429
unnecessaryInstructionsTmp.insert(&I);
13111430
}
1312-
CacheAnalysis CA(gutils->OrigAA, gutils->oldFunc, gutils->OrigLI,
1313-
gutils->OrigDT, TLI, unnecessaryInstructionsTmp,
1314-
_uncacheable_argsPP,
1431+
CacheAnalysis CA(gutils->OrigAA, gutils->oldFunc,
1432+
PPC.FAM.getResult<ScalarEvolutionAnalysis>(*gutils->oldFunc),
1433+
gutils->OrigLI, gutils->OrigDT, TLI,
1434+
unnecessaryInstructionsTmp, _uncacheable_argsPP,
13151435
/*topLevel*/ false);
13161436
const std::map<CallInst *, const std::map<Argument *, bool>>
13171437
uncacheable_args_map = CA.compute_uncacheable_args_for_callsites();
@@ -2434,9 +2554,10 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
24342554
for (auto &I : *BB)
24352555
unnecessaryInstructionsTmp.insert(&I);
24362556
}
2437-
CacheAnalysis CA(gutils->OrigAA, gutils->oldFunc, gutils->OrigLI,
2438-
gutils->OrigDT, TLI, unnecessaryInstructionsTmp,
2439-
_uncacheable_argsPP, topLevel);
2557+
CacheAnalysis CA(gutils->OrigAA, gutils->oldFunc,
2558+
PPC.FAM.getResult<ScalarEvolutionAnalysis>(*gutils->oldFunc),
2559+
gutils->OrigLI, gutils->OrigDT, TLI,
2560+
unnecessaryInstructionsTmp, _uncacheable_argsPP, topLevel);
24402561
const std::map<CallInst *, const std::map<Argument *, bool>>
24412562
uncacheable_args_map =
24422563
(augmenteddata) ? augmenteddata->uncacheable_args_map

enzyme/Enzyme/FunctionUtils.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@
4747
#include "llvm/Analysis/MemorySSA.h"
4848
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
4949

50-
#include "llvm/CodeGen/UnreachableBlockElim.h"
51-
50+
#include "llvm/Analysis/DependenceAnalysis.h"
5251
#include "llvm/Analysis/TypeBasedAliasAnalysis.h"
52+
#include "llvm/CodeGen/UnreachableBlockElim.h"
5353

5454
#include "llvm/Analysis/CFLSteensAliasAnalysis.h"
5555

@@ -98,6 +98,8 @@
9898
#include "llvm/IR/LegacyPassManager.h"
9999
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
100100

101+
#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
102+
101103
#include "CacheUtility.h"
102104

103105
#define DEBUG_TYPE "enzyme"
@@ -669,11 +671,14 @@ PreProcessCache::PreProcessCache() {
669671
FAM.registerPass([] { return PhiValuesAnalysis(); });
670672
#endif
671673

674+
FAM.registerPass([] { return DependenceAnalysis(); });
675+
672676
// Explicitly chose AA passes that are stateless
673677
// and will not be invalidated
674678
FAM.registerPass([] { return TypeBasedAA(); });
675679
FAM.registerPass([] { return BasicAA(); });
676680
MAM.registerPass([] { return GlobalsAA(); });
681+
677682
// SCEVAA causes some breakage/segfaults
678683
// disable for now, consider enabling in future
679684
// FAM.registerPass([] { return SCEVAA(); });

0 commit comments

Comments
 (0)