32
32
#include " SCEV/ScalarEvolution.h"
33
33
#include " SCEV/ScalarEvolutionExpander.h"
34
34
35
+ #include " llvm/Analysis/DependenceAnalysis.h"
35
36
#include < deque>
36
37
37
38
#include " llvm/IR/BasicBlock.h"
@@ -92,6 +93,7 @@ bool is_load_uncacheable(
92
93
struct CacheAnalysis {
93
94
AAResults &AA;
94
95
Function *oldFunc;
96
+ ScalarEvolution &SE;
95
97
LoopInfo &OrigLI;
96
98
DominatorTree &DT;
97
99
TargetLibraryInfo &TLI;
@@ -100,11 +102,11 @@ struct CacheAnalysis {
100
102
bool topLevel;
101
103
std::map<Value *, bool > seen;
102
104
CacheAnalysis (
103
- AAResults &AA, Function *oldFunc, LoopInfo &OrigLI, DominatorTree &OrigDT ,
104
- TargetLibraryInfo &TLI,
105
+ AAResults &AA, Function *oldFunc, ScalarEvolution &SE, LoopInfo &OrigLI ,
106
+ DominatorTree &OrigDT, TargetLibraryInfo &TLI,
105
107
const SmallPtrSetImpl<const Instruction *> &unnecessaryInstructions,
106
108
const std::map<Argument *, bool > &uncacheable_args, bool topLevel)
107
- : AA(AA), oldFunc(oldFunc), OrigLI(OrigLI), DT(OrigDT), TLI(TLI),
109
+ : AA(AA), oldFunc(oldFunc), SE(SE), OrigLI(OrigLI), DT(OrigDT), TLI(TLI),
108
110
unnecessaryInstructions (unnecessaryInstructions),
109
111
uncacheable_args(uncacheable_args), topLevel(topLevel) {}
110
112
@@ -252,6 +254,123 @@ struct CacheAnalysis {
252
254
if (!writesToMemoryReadBy (AA, &li, inst2)) {
253
255
return false ;
254
256
}
257
+
258
+ if (auto SI = dyn_cast<StoreInst>(inst2)) {
259
+
260
+ const SCEV *LS = SE.getSCEV (li.getPointerOperand ());
261
+ const SCEV *SS = SE.getSCEV (SI->getPointerOperand ());
262
+ if (SS != SE.getCouldNotCompute ()) {
263
+
264
+ // llvm::errs() << *inst2 << " - " << li << "\n";
265
+ // llvm::errs() << *SS << " - " << *LS << "\n";
266
+ const auto &DL = li.getModule ()->getDataLayout ();
267
+
268
+ #if LLVM_VERSION_MAJOR >= 10
269
+ auto TS = SE.getConstant (
270
+ APInt (64 , DL.getTypeStoreSize (li.getType ()).getFixedSize ()));
271
+ #else
272
+ auto TS = SE.getConstant (
273
+ APInt (64 , DL.getTypeStoreSize (li.getType ())));
274
+ #endif
275
+ for (auto lim = LS; lim != SE.getCouldNotCompute ();) {
276
+ // Case 1: load range [L, L+Size] ends at or before the store range [S, S+Size] begins, i.e. S - (L+Size) >= 0
277
+ for (auto slim = SS; slim != SE.getCouldNotCompute ();) {
278
+ auto lsub = SE.getMinusSCEV (slim, SE.getAddExpr (lim, TS));
279
+ // llvm::errs() << " *** " << *lsub << "|" << *slim << "|" <<
280
+ // *lim << "\n";
281
+ if (SE.isKnownNonNegative (lsub)) {
282
+ return false ;
283
+ }
284
+ if (auto arL = dyn_cast<SCEVAddRecExpr>(slim)) {
285
+ if (SE.isKnownNonNegative (arL->getStepRecurrence (SE))) {
286
+ slim = arL->getStart ();
287
+ continue ;
288
+ } else if (SE.isKnownNonPositive (
289
+ arL->getStepRecurrence (SE))) {
290
+ #if LLVM_VERSION_MAJOR >= 12
291
+ auto bd =
292
+ SE.getSymbolicMaxBackedgeTakenCount (arL->getLoop ());
293
+ #else
294
+ auto bd = SE.getBackedgeTakenCount (arL->getLoop ());
295
+ #endif
296
+ if (bd == SE.getCouldNotCompute ())
297
+ break ;
298
+ slim = arL->evaluateAtIteration (bd, SE);
299
+ continue ;
300
+ }
301
+ }
302
+ break ;
303
+ }
304
+
305
+ if (auto arL = dyn_cast<SCEVAddRecExpr>(lim)) {
306
+ if (SE.isKnownNonNegative (arL->getStepRecurrence (SE))) {
307
+ #if LLVM_VERSION_MAJOR >= 12
308
+ auto bd = SE.getSymbolicMaxBackedgeTakenCount (arL->getLoop ());
309
+ #else
310
+ auto bd = SE.getBackedgeTakenCount (arL->getLoop ());
311
+ #endif
312
+ if (bd == SE.getCouldNotCompute ())
313
+ break ;
314
+ lim = arL->evaluateAtIteration (bd, SE);
315
+ continue ;
316
+ } else if (SE.isKnownNonPositive (arL->getStepRecurrence (SE))) {
317
+ lim = arL->getStart ();
318
+ continue ;
319
+ }
320
+ }
321
+ break ;
322
+ }
323
+ for (auto lim = LS; lim != SE.getCouldNotCompute ();) {
324
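+ // Case 2: store range [S, S+Size] ends at or before the load range [L, L+Size] begins, i.e. L - (S+Size) >= 0; NOTE(review): Size (TS) comes from the load's type -- confirm the store is not wider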
325
+ for (auto slim = SS; slim != SE.getCouldNotCompute ();) {
326
+ auto lsub = SE.getMinusSCEV (lim, SE.getAddExpr (slim, TS));
327
+ // llvm::errs() << " $$$ " << *lsub << "|" << *slim << "|" <<
328
+ // *lim << "\n";
329
+ if (SE.isKnownNonNegative (lsub)) {
330
+ return false ;
331
+ }
332
+ if (auto arL = dyn_cast<SCEVAddRecExpr>(slim)) {
333
+ if (SE.isKnownNonNegative (arL->getStepRecurrence (SE))) {
334
+ #if LLVM_VERSION_MAJOR >= 12
335
+ auto bd =
336
+ SE.getSymbolicMaxBackedgeTakenCount (arL->getLoop ());
337
+ #else
338
+ auto bd = SE.getBackedgeTakenCount (arL->getLoop ());
339
+ #endif
340
+ if (bd == SE.getCouldNotCompute ())
341
+ break ;
342
+ slim = arL->evaluateAtIteration (bd, SE);
343
+ continue ;
344
+ } else if (SE.isKnownNonPositive (
345
+ arL->getStepRecurrence (SE))) {
346
+ slim = arL->getStart ();
347
+ continue ;
348
+ }
349
+ }
350
+ break ;
351
+ }
352
+
353
+ if (auto arL = dyn_cast<SCEVAddRecExpr>(lim)) {
354
+ if (SE.isKnownNonNegative (arL->getStepRecurrence (SE))) {
355
+ lim = arL->getStart ();
356
+ continue ;
357
+ } else if (SE.isKnownNonPositive (arL->getStepRecurrence (SE))) {
358
+ #if LLVM_VERSION_MAJOR >= 12
359
+ auto bd = SE.getSymbolicMaxBackedgeTakenCount (arL->getLoop ());
360
+ #else
361
+ auto bd = SE.getBackedgeTakenCount (arL->getLoop ());
362
+ #endif
363
+ if (bd == SE.getCouldNotCompute ())
364
+ break ;
365
+ lim = arL->evaluateAtIteration (bd, SE);
366
+ continue ;
367
+ }
368
+ }
369
+ break ;
370
+ }
371
+ }
372
+ }
373
+
255
374
if (auto II = dyn_cast<IntrinsicInst>(inst2)) {
256
375
if (II->getIntrinsicID () == Intrinsic::nvvm_barrier0 ||
257
376
II->getIntrinsicID () == Intrinsic::amdgcn_s_barrier) {
@@ -1309,9 +1428,10 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
1309
1428
for (auto &I : *BB)
1310
1429
unnecessaryInstructionsTmp.insert (&I);
1311
1430
}
1312
- CacheAnalysis CA (gutils->OrigAA , gutils->oldFunc , gutils->OrigLI ,
1313
- gutils->OrigDT , TLI, unnecessaryInstructionsTmp,
1314
- _uncacheable_argsPP,
1431
+ CacheAnalysis CA (gutils->OrigAA , gutils->oldFunc ,
1432
+ PPC.FAM .getResult <ScalarEvolutionAnalysis>(*gutils->oldFunc ),
1433
+ gutils->OrigLI , gutils->OrigDT , TLI,
1434
+ unnecessaryInstructionsTmp, _uncacheable_argsPP,
1315
1435
/* topLevel*/ false );
1316
1436
const std::map<CallInst *, const std::map<Argument *, bool >>
1317
1437
uncacheable_args_map = CA.compute_uncacheable_args_for_callsites ();
@@ -2434,9 +2554,10 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
2434
2554
for (auto &I : *BB)
2435
2555
unnecessaryInstructionsTmp.insert (&I);
2436
2556
}
2437
- CacheAnalysis CA (gutils->OrigAA , gutils->oldFunc , gutils->OrigLI ,
2438
- gutils->OrigDT , TLI, unnecessaryInstructionsTmp,
2439
- _uncacheable_argsPP, topLevel);
2557
+ CacheAnalysis CA (gutils->OrigAA , gutils->oldFunc ,
2558
+ PPC.FAM .getResult <ScalarEvolutionAnalysis>(*gutils->oldFunc ),
2559
+ gutils->OrigLI , gutils->OrigDT , TLI,
2560
+ unnecessaryInstructionsTmp, _uncacheable_argsPP, topLevel);
2440
2561
const std::map<CallInst *, const std::map<Argument *, bool >>
2441
2562
uncacheable_args_map =
2442
2563
(augmenteddata) ? augmenteddata->uncacheable_args_map
0 commit comments