Skip to content

Commit 57e0931

Browse files
author
Liren Peng
committed
[ScalarEvolution] Infer loop max trip count from array accesses
Data references in a loop should not access elements over the statically allocated size. So we can infer a loop max trip count from this undefined behavior. Reviewed By: reames, mkazantsev, nikic Differential Revision: https://reviews.llvm.org/D109821
1 parent 8f10197 commit 57e0931

File tree

3 files changed

+340
-0
lines changed

3 files changed

+340
-0
lines changed

llvm/include/llvm/Analysis/ScalarEvolution.h

+7
Original file line numberDiff line numberDiff line change
@@ -793,6 +793,13 @@ class ScalarEvolution {
793793
/// Returns 0 if the trip count is unknown or not constant.
794794
unsigned getSmallConstantMaxTripCount(const Loop *L);
795795

796+
/// Returns the upper bound of the loop trip count infered from array size.
797+
/// Can not access bytes starting outside the statically allocated size
798+
/// without being immediate UB.
799+
/// Returns SCEVCouldNotCompute if the trip count could not inferred
800+
/// from array accesses.
801+
const SCEV *getConstantMaxTripCountFromArray(const Loop *L);
802+
796803
/// Returns the largest constant divisor of the trip count as a normal
797804
/// unsigned value, if possible. This means that the actual trip count is
798805
/// always a multiple of the returned value. Returns 1 if the trip count is

llvm/lib/Analysis/ScalarEvolution.cpp

+125
Original file line numberDiff line numberDiff line change
@@ -7269,6 +7269,131 @@ unsigned ScalarEvolution::getSmallConstantMaxTripCount(const Loop *L) {
72697269
return getConstantTripCount(MaxExitCount);
72707270
}
72717271

7272+
const SCEV *ScalarEvolution::getConstantMaxTripCountFromArray(const Loop *L) {
7273+
// We can't infer from Array in Irregular Loop.
7274+
// FIXME: It's hard to infer loop bound from array operated in Nested Loop.
7275+
if (!L->isLoopSimplifyForm() || !L->isInnermost())
7276+
return getCouldNotCompute();
7277+
7278+
// FIXME: To make the scene more typical, we only analysis loops that have
7279+
// one exiting block and that block must be the latch. To make it easier to
7280+
// capture loops that have memory access and memory access will be executed
7281+
// in each iteration.
7282+
const BasicBlock *LoopLatch = L->getLoopLatch();
7283+
assert(LoopLatch && "See defination of simplify form loop.");
7284+
if (L->getExitingBlock() != LoopLatch)
7285+
return getCouldNotCompute();
7286+
7287+
const DataLayout &DL = getDataLayout();
7288+
SmallVector<const SCEV *> InferCountColl;
7289+
for (auto *BB : L->getBlocks()) {
7290+
// Go here, we can know that Loop is a single exiting and simplified form
7291+
// loop. Make sure that infer from Memory Operation in those BBs must be
7292+
// executed in loop. First step, we can make sure that max execution time
7293+
// of MemAccessBB in loop represents latch max excution time.
7294+
// If MemAccessBB does not dom Latch, skip.
7295+
// Entry
7296+
// │
7297+
// ┌─────▼─────┐
7298+
// │Loop Header◄─────┐
7299+
// └──┬──────┬─┘ │
7300+
// │ │ │
7301+
// ┌────────▼──┐ ┌─▼─────┐ │
7302+
// │MemAccessBB│ │OtherBB│ │
7303+
// └────────┬──┘ └─┬─────┘ │
7304+
// │ │ │
7305+
// ┌─▼──────▼─┐ │
7306+
// │Loop Latch├─────┘
7307+
// └────┬─────┘
7308+
// ▼
7309+
// Exit
7310+
if (!DT.dominates(BB, LoopLatch))
7311+
continue;
7312+
7313+
for (Instruction &Inst : *BB) {
7314+
// Find Memory Operation Instruction.
7315+
auto *GEP = getLoadStorePointerOperand(&Inst);
7316+
if (!GEP)
7317+
continue;
7318+
7319+
auto *ElemSize = dyn_cast<SCEVConstant>(getElementSize(&Inst));
7320+
// Do not infer from scalar type, eg."ElemSize = sizeof()".
7321+
if (!ElemSize)
7322+
continue;
7323+
7324+
// Use a existing polynomial recurrence on the trip count.
7325+
auto *AddRec = dyn_cast<SCEVAddRecExpr>(getSCEV(GEP));
7326+
if (!AddRec)
7327+
continue;
7328+
auto *ArrBase = dyn_cast<SCEVUnknown>(getPointerBase(AddRec));
7329+
auto *Step = dyn_cast<SCEVConstant>(AddRec->getStepRecurrence(*this));
7330+
if (!ArrBase || !Step)
7331+
continue;
7332+
assert(isLoopInvariant(ArrBase, L) && "See addrec definition");
7333+
7334+
// Only handle { %array + step },
7335+
// FIXME: {(SCEVAddRecExpr) + step } could not be analysed here.
7336+
if (AddRec->getStart() != ArrBase)
7337+
continue;
7338+
7339+
// Memory operation pattern which have gaps.
7340+
// Or repeat memory opreation.
7341+
// And index of GEP wraps arround.
7342+
if (Step->getAPInt().getActiveBits() > 32 ||
7343+
Step->getAPInt().getZExtValue() !=
7344+
ElemSize->getAPInt().getZExtValue() ||
7345+
Step->isZero() || Step->getAPInt().isNegative())
7346+
continue;
7347+
7348+
// Only infer from stack array which has certain size.
7349+
// Make sure alloca instruction is not excuted in loop.
7350+
AllocaInst *AllocateInst = dyn_cast<AllocaInst>(ArrBase->getValue());
7351+
if (!AllocateInst || L->contains(AllocateInst->getParent()))
7352+
continue;
7353+
7354+
// Make sure only handle normal array.
7355+
auto *Ty = dyn_cast<ArrayType>(AllocateInst->getAllocatedType());
7356+
auto *ArrSize = dyn_cast<ConstantInt>(AllocateInst->getArraySize());
7357+
if (!Ty || !ArrSize || !ArrSize->isOne())
7358+
continue;
7359+
// Also make sure step was increased the same with sizeof allocated
7360+
// element type.
7361+
const PointerType *GEPT = dyn_cast<PointerType>(GEP->getType());
7362+
if (Ty->getElementType() != GEPT->getElementType())
7363+
continue;
7364+
7365+
// FIXME: Since gep indices are silently zext to the indexing type,
7366+
// we will have a narrow gep index which wraps around rather than
7367+
// increasing strictly, we shoule ensure that step is increasing
7368+
// strictly by the loop iteration.
7369+
// Now we can infer a max execution time by MemLength/StepLength.
7370+
const SCEV *MemSize =
7371+
getConstant(Step->getType(), DL.getTypeAllocSize(Ty));
7372+
auto *MaxExeCount =
7373+
dyn_cast<SCEVConstant>(getUDivCeilSCEV(MemSize, Step));
7374+
if (!MaxExeCount || MaxExeCount->getAPInt().getActiveBits() > 32)
7375+
continue;
7376+
7377+
// If the loop reaches the maximum number of executions, we can not
7378+
// access bytes starting outside the statically allocated size without
7379+
// being immediate UB. But it is allowed to enter loop header one more
7380+
// time.
7381+
auto *InferCount = dyn_cast<SCEVConstant>(
7382+
getAddExpr(MaxExeCount, getOne(MaxExeCount->getType())));
7383+
// Discard the maximum number of execution times under 32bits.
7384+
if (!InferCount || InferCount->getAPInt().getActiveBits() > 32)
7385+
continue;
7386+
7387+
InferCountColl.push_back(InferCount);
7388+
}
7389+
}
7390+
7391+
if (InferCountColl.size() == 0)
7392+
return getCouldNotCompute();
7393+
7394+
return getUMinFromMismatchedTypes(InferCountColl);
7395+
}
7396+
72727397
unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
72737398
SmallVector<BasicBlock *, 8> ExitingBlocks;
72747399
L->getExitingBlocks(ExitingBlocks);

llvm/unittests/Analysis/ScalarEvolutionTest.cpp

+208
Original file line numberDiff line numberDiff line change
@@ -1538,4 +1538,212 @@ TEST_F(ScalarEvolutionsTest, SCEVUDivFloorCeiling) {
15381538
});
15391539
}
15401540

1541+
TEST_F(ScalarEvolutionsTest, ComputeMaxTripCountFromArrayNormal) {
1542+
LLVMContext C;
1543+
SMDiagnostic Err;
1544+
std::unique_ptr<Module> M = parseAssemblyString(
1545+
"define void @foo(i32 signext %len) { "
1546+
"entry: "
1547+
" %a = alloca [7 x i32], align 4 "
1548+
" %cmp4 = icmp sgt i32 %len, 0 "
1549+
" br i1 %cmp4, label %for.body.preheader, label %for.cond.cleanup "
1550+
"for.body.preheader: "
1551+
" br label %for.body "
1552+
"for.cond.cleanup.loopexit: "
1553+
" br label %for.cond.cleanup "
1554+
"for.cond.cleanup: "
1555+
" ret void "
1556+
"for.body: "
1557+
" %iv = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ] "
1558+
" %idxprom = zext i32 %iv to i64 "
1559+
" %arrayidx = getelementptr inbounds [7 x i32], [7 x i32]* %a, i64 0, \
1560+
i64 %idxprom "
1561+
" store i32 0, i32* %arrayidx, align 4 "
1562+
" %inc = add nuw nsw i32 %iv, 1 "
1563+
" %cmp = icmp slt i32 %inc, %len "
1564+
" br i1 %cmp, label %for.body, label %for.cond.cleanup.loopexit "
1565+
"} ",
1566+
Err, C);
1567+
1568+
ASSERT_TRUE(M && "Could not parse module?");
1569+
ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
1570+
1571+
runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
1572+
auto *ScevIV = SE.getSCEV(getInstructionByName(F, "iv"));
1573+
const Loop *L = cast<SCEVAddRecExpr>(ScevIV)->getLoop();
1574+
1575+
const SCEV *ITC = SE.getConstantMaxTripCountFromArray(L);
1576+
EXPECT_FALSE(isa<SCEVCouldNotCompute>(ITC));
1577+
EXPECT_TRUE(isa<SCEVConstant>(ITC));
1578+
EXPECT_EQ(cast<SCEVConstant>(ITC)->getAPInt().getSExtValue(), 8);
1579+
});
1580+
}
1581+
1582+
TEST_F(ScalarEvolutionsTest, ComputeMaxTripCountFromZeroArray) {
1583+
LLVMContext C;
1584+
SMDiagnostic Err;
1585+
std::unique_ptr<Module> M = parseAssemblyString(
1586+
"define void @foo(i32 signext %len) { "
1587+
"entry: "
1588+
" %a = alloca [0 x i32], align 4 "
1589+
" %cmp4 = icmp sgt i32 %len, 0 "
1590+
" br i1 %cmp4, label %for.body.preheader, label %for.cond.cleanup "
1591+
"for.body.preheader: "
1592+
" br label %for.body "
1593+
"for.cond.cleanup.loopexit: "
1594+
" br label %for.cond.cleanup "
1595+
"for.cond.cleanup: "
1596+
" ret void "
1597+
"for.body: "
1598+
" %iv = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ] "
1599+
" %idxprom = zext i32 %iv to i64 "
1600+
" %arrayidx = getelementptr inbounds [0 x i32], [0 x i32]* %a, i64 0, \
1601+
i64 %idxprom "
1602+
" store i32 0, i32* %arrayidx, align 4 "
1603+
" %inc = add nuw nsw i32 %iv, 1 "
1604+
" %cmp = icmp slt i32 %inc, %len "
1605+
" br i1 %cmp, label %for.body, label %for.cond.cleanup.loopexit "
1606+
"} ",
1607+
Err, C);
1608+
1609+
ASSERT_TRUE(M && "Could not parse module?");
1610+
ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
1611+
1612+
runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
1613+
auto *ScevIV = SE.getSCEV(getInstructionByName(F, "iv"));
1614+
const Loop *L = cast<SCEVAddRecExpr>(ScevIV)->getLoop();
1615+
1616+
const SCEV *ITC = SE.getConstantMaxTripCountFromArray(L);
1617+
EXPECT_FALSE(isa<SCEVCouldNotCompute>(ITC));
1618+
EXPECT_TRUE(isa<SCEVConstant>(ITC));
1619+
EXPECT_EQ(cast<SCEVConstant>(ITC)->getAPInt().getSExtValue(), 1);
1620+
});
1621+
}
1622+
1623+
TEST_F(ScalarEvolutionsTest, ComputeMaxTripCountFromExtremArray) {
1624+
LLVMContext C;
1625+
SMDiagnostic Err;
1626+
std::unique_ptr<Module> M = parseAssemblyString(
1627+
"define void @foo(i32 signext %len) { "
1628+
"entry: "
1629+
" %a = alloca [4294967295 x i1], align 4 "
1630+
" %cmp4 = icmp sgt i32 %len, 0 "
1631+
" br i1 %cmp4, label %for.body.preheader, label %for.cond.cleanup "
1632+
"for.body.preheader: "
1633+
" br label %for.body "
1634+
"for.cond.cleanup.loopexit: "
1635+
" br label %for.cond.cleanup "
1636+
"for.cond.cleanup: "
1637+
" ret void "
1638+
"for.body: "
1639+
" %iv = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ] "
1640+
" %idxprom = zext i32 %iv to i64 "
1641+
" %arrayidx = getelementptr inbounds [4294967295 x i1], \
1642+
[4294967295 x i1]* %a, i64 0, i64 %idxprom "
1643+
" store i1 0, i1* %arrayidx, align 4 "
1644+
" %inc = add nuw nsw i32 %iv, 1 "
1645+
" %cmp = icmp slt i32 %inc, %len "
1646+
" br i1 %cmp, label %for.body, label %for.cond.cleanup.loopexit "
1647+
"} ",
1648+
Err, C);
1649+
1650+
ASSERT_TRUE(M && "Could not parse module?");
1651+
ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
1652+
1653+
runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
1654+
auto *ScevIV = SE.getSCEV(getInstructionByName(F, "iv"));
1655+
const Loop *L = cast<SCEVAddRecExpr>(ScevIV)->getLoop();
1656+
1657+
const SCEV *ITC = SE.getConstantMaxTripCountFromArray(L);
1658+
EXPECT_TRUE(isa<SCEVCouldNotCompute>(ITC));
1659+
});
1660+
}
1661+
1662+
TEST_F(ScalarEvolutionsTest, ComputeMaxTripCountFromArrayInBranch) {
1663+
LLVMContext C;
1664+
SMDiagnostic Err;
1665+
std::unique_ptr<Module> M = parseAssemblyString(
1666+
"define void @foo(i32 signext %len) { "
1667+
"entry: "
1668+
" %a = alloca [8 x i32], align 4 "
1669+
" br label %for.cond "
1670+
"for.cond: "
1671+
" %iv = phi i32 [ %inc, %for.inc ], [ 0, %entry ] "
1672+
" %cmp = icmp slt i32 %iv, %len "
1673+
" br i1 %cmp, label %for.body, label %for.cond.cleanup "
1674+
"for.cond.cleanup: "
1675+
" br label %for.end "
1676+
"for.body: "
1677+
" %cmp1 = icmp slt i32 %iv, 8 "
1678+
" br i1 %cmp1, label %if.then, label %if.end "
1679+
"if.then: "
1680+
" %idxprom = sext i32 %iv to i64 "
1681+
" %arrayidx = getelementptr inbounds [8 x i32], [8 x i32]* %a, i64 0, \
1682+
i64 %idxprom "
1683+
" store i32 0, i32* %arrayidx, align 4 "
1684+
" br label %if.end "
1685+
"if.end: "
1686+
" br label %for.inc "
1687+
"for.inc: "
1688+
" %inc = add nsw i32 %iv, 1 "
1689+
" br label %for.cond "
1690+
"for.end: "
1691+
" ret void "
1692+
"} ",
1693+
Err, C);
1694+
1695+
ASSERT_TRUE(M && "Could not parse module?");
1696+
ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
1697+
1698+
runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
1699+
auto *ScevIV = SE.getSCEV(getInstructionByName(F, "iv"));
1700+
const Loop *L = cast<SCEVAddRecExpr>(ScevIV)->getLoop();
1701+
1702+
const SCEV *ITC = SE.getConstantMaxTripCountFromArray(L);
1703+
EXPECT_TRUE(isa<SCEVCouldNotCompute>(ITC));
1704+
});
1705+
}
1706+
1707+
TEST_F(ScalarEvolutionsTest, ComputeMaxTripCountFromMultiDemArray) {
1708+
LLVMContext C;
1709+
SMDiagnostic Err;
1710+
std::unique_ptr<Module> M = parseAssemblyString(
1711+
"define void @foo(i32 signext %len) { "
1712+
"entry: "
1713+
" %a = alloca [3 x [5 x i32]], align 4 "
1714+
" br label %for.cond "
1715+
"for.cond: "
1716+
" %iv = phi i32 [ %inc, %for.inc ], [ 0, %entry ] "
1717+
" %cmp = icmp slt i32 %iv, %len "
1718+
" br i1 %cmp, label %for.body, label %for.cond.cleanup "
1719+
"for.cond.cleanup: "
1720+
" br label %for.end "
1721+
"for.body: "
1722+
" %arrayidx = getelementptr inbounds [3 x [5 x i32]], \
1723+
[3 x [5 x i32]]* %a, i64 0, i64 3 "
1724+
" %idxprom = sext i32 %iv to i64 "
1725+
" %arrayidx1 = getelementptr inbounds [5 x i32], [5 x i32]* %arrayidx, \
1726+
i64 0, i64 %idxprom "
1727+
" store i32 0, i32* %arrayidx1, align 4"
1728+
" br label %for.inc "
1729+
"for.inc: "
1730+
" %inc = add nsw i32 %iv, 1 "
1731+
" br label %for.cond "
1732+
"for.end: "
1733+
" ret void "
1734+
"} ",
1735+
Err, C);
1736+
1737+
ASSERT_TRUE(M && "Could not parse module?");
1738+
ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
1739+
1740+
runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
1741+
auto *ScevIV = SE.getSCEV(getInstructionByName(F, "iv"));
1742+
const Loop *L = cast<SCEVAddRecExpr>(ScevIV)->getLoop();
1743+
1744+
const SCEV *ITC = SE.getConstantMaxTripCountFromArray(L);
1745+
EXPECT_TRUE(isa<SCEVCouldNotCompute>(ITC));
1746+
});
1747+
}
1748+
15411749
} // end namespace llvm

0 commit comments

Comments
 (0)