Skip to content

Commit a959be9

Browse files
wsmosesZuseZ4
andauthored
Enable PHI node improvements in strict aliasing (rust-lang#760)
Co-authored-by: Manuel Drehwald <[email protected]>
1 parent 05f1c8b commit a959be9

File tree

2 files changed

+161
-112
lines changed

2 files changed

+161
-112
lines changed

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp

Lines changed: 117 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -1554,135 +1554,140 @@ void TypeAnalyzer::visitPHINode(PHINode &phi) {
15541554
assert(phi.getNumIncomingValues() > 0);
15551555

15561556
// TODO generalize this (and for recursive, etc)
1557-
std::deque<Value *> vals;
1558-
std::set<Value *> seen{&phi};
1559-
for (auto &op : phi.incoming_values()) {
1560-
vals.push_back(op);
1561-
}
1562-
1563-
SmallVector<BinaryOperator *, 4> bos;
1564-
1565-
// Unique values that propagate into this phi
1566-
SmallVector<Value *, 4> UniqueValues;
15671557

1568-
while (vals.size()) {
1569-
Value *todo = vals.front();
1570-
vals.pop_front();
1558+
for (int i = 0; i < 2; i++) {
15711559

1572-
if (auto bo = dyn_cast<BinaryOperator>(todo)) {
1573-
if (bo->getOpcode() == BinaryOperator::Add) {
1574-
if (isa<Constant>(bo->getOperand(0))) {
1575-
bos.push_back(bo);
1576-
todo = bo->getOperand(1);
1577-
}
1578-
if (isa<Constant>(bo->getOperand(1))) {
1579-
bos.push_back(bo);
1580-
todo = bo->getOperand(0);
1581-
}
1582-
}
1560+
std::deque<Value *> vals;
1561+
std::set<Value *> seen{&phi};
1562+
for (auto &op : phi.incoming_values()) {
1563+
vals.push_back(op);
15831564
}
1565+
SmallVector<BinaryOperator *, 4> bos;
15841566

1585-
if (seen.count(todo))
1586-
continue;
1587-
seen.insert(todo);
1588-
1589-
if (auto nphi = dyn_cast<PHINode>(todo)) {
1590-
for (auto &op : nphi->incoming_values()) {
1591-
vals.push_back(op);
1592-
}
1593-
continue;
1594-
}
1595-
if (auto sel = dyn_cast<SelectInst>(todo)) {
1596-
vals.push_back(sel->getOperand(1));
1597-
vals.push_back(sel->getOperand(2));
1598-
continue;
1599-
}
1600-
UniqueValues.push_back(todo);
1601-
}
1567+
// Unique values that propagate into this phi
1568+
SmallVector<Value *, 4> UniqueValues;
16021569

1603-
TypeTree PhiTypes;
1604-
bool set = false;
1570+
while (vals.size()) {
1571+
Value *todo = vals.front();
1572+
vals.pop_front();
16051573

1606-
for (size_t i = 0, size = UniqueValues.size(); i < size; ++i) {
1607-
TypeTree newData = getAnalysis(UniqueValues[i]);
1608-
if (UniqueValues.size() == 2) {
1609-
if (auto BO = dyn_cast<BinaryOperator>(UniqueValues[i])) {
1610-
if (BO->getOpcode() == BinaryOperator::Add ||
1611-
BO->getOpcode() == BinaryOperator::Mul) {
1612-
TypeTree otherData = getAnalysis(UniqueValues[1 - i]);
1613-
// If we are adding/muling to a constant to derive this, we can assume
1614-
// it to be an integer rather than Anything
1615-
if (isa<Constant>(UniqueValues[1 - i])) {
1616-
otherData = TypeTree(BaseType::Integer).Only(-1);
1574+
if (auto bo = dyn_cast<BinaryOperator>(todo)) {
1575+
if (bo->getOpcode() == BinaryOperator::Add) {
1576+
if (isa<Constant>(bo->getOperand(0))) {
1577+
bos.push_back(bo);
1578+
todo = bo->getOperand(1);
16171579
}
1618-
if (BO->getOperand(0) == &phi) {
1619-
set = true;
1620-
PhiTypes = otherData;
1621-
PhiTypes.binopIn(getAnalysis(BO->getOperand(1)), BO->getOpcode());
1622-
break;
1623-
} else if (BO->getOperand(1) == &phi) {
1624-
set = true;
1625-
PhiTypes = getAnalysis(BO->getOperand(0));
1626-
PhiTypes.binopIn(otherData, BO->getOpcode());
1627-
break;
1580+
if (isa<Constant>(bo->getOperand(1))) {
1581+
bos.push_back(bo);
1582+
todo = bo->getOperand(0);
16281583
}
1629-
} else if (BO->getOpcode() == BinaryOperator::Sub) {
1630-
// Repeated subtraction from a type X yields the type X back
1631-
TypeTree otherData = getAnalysis(UniqueValues[1 - i]);
1632-
// If we are subtracting from a constant to derive this, we can assume
1633-
// it to be an integer rather than Anything
1634-
if (isa<Constant>(UniqueValues[1 - i])) {
1635-
otherData = TypeTree(BaseType::Integer).Only(-1);
1584+
}
1585+
}
1586+
1587+
if (seen.count(todo))
1588+
continue;
1589+
seen.insert(todo);
1590+
1591+
if (auto nphi = dyn_cast<PHINode>(todo)) {
1592+
if (i == 0) {
1593+
for (auto &op : nphi->incoming_values()) {
1594+
vals.push_back(op);
16361595
}
1637-
if (BO->getOperand(0) == &phi) {
1638-
set = true;
1639-
PhiTypes = otherData;
1640-
break;
1596+
continue;
1597+
}
1598+
}
1599+
if (auto sel = dyn_cast<SelectInst>(todo)) {
1600+
vals.push_back(sel->getOperand(1));
1601+
vals.push_back(sel->getOperand(2));
1602+
continue;
1603+
}
1604+
UniqueValues.push_back(todo);
1605+
}
1606+
1607+
TypeTree PhiTypes;
1608+
bool set = false;
1609+
1610+
for (size_t i = 0, size = UniqueValues.size(); i < size; ++i) {
1611+
TypeTree newData = getAnalysis(UniqueValues[i]);
1612+
if (UniqueValues.size() == 2) {
1613+
if (auto BO = dyn_cast<BinaryOperator>(UniqueValues[i])) {
1614+
if (BO->getOpcode() == BinaryOperator::Add ||
1615+
BO->getOpcode() == BinaryOperator::Mul) {
1616+
TypeTree otherData = getAnalysis(UniqueValues[1 - i]);
1617+
// If we are adding/muling to a constant to derive this, we can
1618+
// assume it to be an integer rather than Anything
1619+
if (isa<Constant>(UniqueValues[1 - i])) {
1620+
otherData = TypeTree(BaseType::Integer).Only(-1);
1621+
}
1622+
if (BO->getOperand(0) == &phi) {
1623+
set = true;
1624+
PhiTypes = otherData;
1625+
PhiTypes.binopIn(getAnalysis(BO->getOperand(1)), BO->getOpcode());
1626+
break;
1627+
} else if (BO->getOperand(1) == &phi) {
1628+
set = true;
1629+
PhiTypes = getAnalysis(BO->getOperand(0));
1630+
PhiTypes.binopIn(otherData, BO->getOpcode());
1631+
break;
1632+
}
1633+
} else if (BO->getOpcode() == BinaryOperator::Sub) {
1634+
// Repeated subtraction from a type X yields the type X back
1635+
TypeTree otherData = getAnalysis(UniqueValues[1 - i]);
1636+
// If we are subtracting from a constant to derive this, we can
1637+
// assume it to be an integer rather than Anything
1638+
if (isa<Constant>(UniqueValues[1 - i])) {
1639+
otherData = TypeTree(BaseType::Integer).Only(-1);
1640+
}
1641+
if (BO->getOperand(0) == &phi) {
1642+
set = true;
1643+
PhiTypes = otherData;
1644+
break;
1645+
}
16411646
}
16421647
}
16431648
}
1649+
if (set) {
1650+
PhiTypes &= newData;
1651+
// TODO consider the or of anything (see selectinst)
1652+
// however, this cannot be done yet for risk of turning
1653+
// phi's that add floats into anything
1654+
// PhiTypes |= newData.JustAnything();
1655+
} else {
1656+
set = true;
1657+
PhiTypes = newData;
1658+
}
16441659
}
1645-
if (set) {
1646-
PhiTypes &= newData;
1647-
// TODO consider the or of anything (see selectinst)
1648-
// however, this cannot be done yet for risk of turning
1649-
// phi's that add floats into anything
1650-
// PhiTypes |= newData.JustAnything();
1651-
} else {
1652-
set = true;
1653-
PhiTypes = newData;
1654-
}
1655-
}
16561660

1657-
assert(set);
1658-
// If we are only add / sub / etc to derive a value based off 0
1659-
// we can start by assuming the type of 0 is integer rather
1660-
// than assuming it could be anything (per null)
1661-
if (bos.size() > 0 && UniqueValues.size() == 1 &&
1662-
isa<ConstantInt>(UniqueValues[0]) &&
1663-
(cast<ConstantInt>(UniqueValues[0])->isZero() ||
1664-
cast<ConstantInt>(UniqueValues[0])->isOne())) {
1665-
PhiTypes = TypeTree(BaseType::Integer).Only(-1);
1666-
}
1667-
for (BinaryOperator *bo : bos) {
1668-
TypeTree vd1 = isa<Constant>(bo->getOperand(0))
1669-
? getAnalysis(bo->getOperand(0)).Data0()
1670-
: PhiTypes.Data0();
1671-
TypeTree vd2 = isa<Constant>(bo->getOperand(1))
1672-
? getAnalysis(bo->getOperand(1)).Data0()
1673-
: PhiTypes.Data0();
1674-
vd1.binopIn(vd2, bo->getOpcode());
1675-
PhiTypes &= vd1.Only(bo->getType()->isIntegerTy() ? -1 : 0);
1676-
}
1661+
assert(set);
1662+
// If we are only add / sub / etc to derive a value based off 0
1663+
// we can start by assuming the type of 0 is integer rather
1664+
// than assuming it could be anything (per null)
1665+
if (bos.size() > 0 && UniqueValues.size() == 1 &&
1666+
isa<ConstantInt>(UniqueValues[0]) &&
1667+
(cast<ConstantInt>(UniqueValues[0])->isZero() ||
1668+
cast<ConstantInt>(UniqueValues[0])->isOne())) {
1669+
PhiTypes = TypeTree(BaseType::Integer).Only(-1);
1670+
}
1671+
for (BinaryOperator *bo : bos) {
1672+
TypeTree vd1 = isa<Constant>(bo->getOperand(0))
1673+
? getAnalysis(bo->getOperand(0)).Data0()
1674+
: PhiTypes.Data0();
1675+
TypeTree vd2 = isa<Constant>(bo->getOperand(1))
1676+
? getAnalysis(bo->getOperand(1)).Data0()
1677+
: PhiTypes.Data0();
1678+
vd1.binopIn(vd2, bo->getOpcode());
1679+
PhiTypes &= vd1.Only(bo->getType()->isIntegerTy() ? -1 : 0);
1680+
}
16771681

1678-
if (direction & DOWN) {
1679-
if (phi.getType()->isIntOrIntVectorTy() &&
1680-
PhiTypes.Inner0() == BaseType::Anything) {
1681-
if (mustRemainInteger(&phi)) {
1682-
PhiTypes = TypeTree(BaseType::Integer).Only(-1);
1682+
if (direction & DOWN) {
1683+
if (phi.getType()->isIntOrIntVectorTy() &&
1684+
PhiTypes.Inner0() == BaseType::Anything) {
1685+
if (mustRemainInteger(&phi)) {
1686+
PhiTypes = TypeTree(BaseType::Integer).Only(-1);
1687+
}
16831688
}
1689+
updateAnalysis(&phi, PhiTypes, &phi);
16841690
}
1685-
updateAnalysis(&phi, PhiTypes, &phi);
16861691
}
16871692
}
16881693

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
; RUN: %opt < %s %loadEnzyme -print-type-analysis -type-analysis-func=f -enzyme-strict-aliasing=0 -o /dev/null | FileCheck %s
2+
3+
declare i8* @_Znwm(i64)
4+
5+
define void @f() {
6+
e:
7+
%i78 = call noalias nonnull i8* @_Znwm(i64 8)
8+
br label %bb155
9+
10+
bb155:
11+
%i159 = phi i8* [ %i78, %e ], [ %i220, %bb216 ]
12+
%l = load i8, i8* %i159, align 1
13+
br i1 true, label %bb179, label %bb216
14+
15+
bb179:
16+
%i192 = call noalias nonnull i8* @_Znwm(i64 8)
17+
br label %bb216
18+
19+
bb216:
20+
%i217 = phi i8* [ %i192, %bb179 ], [ %i159, %bb155 ]
21+
%i220 = getelementptr inbounds i8, i8* %i217, i64 1
22+
br i1 true, label %bb153, label %bb155
23+
24+
bb153: ; preds = %bb216
25+
ret void
26+
}
27+
28+
; CHECK: f - {} |
29+
; CHECK-NEXT: e
30+
; CHECK-NEXT: %i78 = call noalias nonnull i8* @_Znwm(i64 8): {[-1]:Pointer, [-1,0]:Integer}
31+
; CHECK-NEXT: br label %bb155: {}
32+
; CHECK-NEXT: bb155
33+
; CHECK-NEXT: %i159 = phi i8* [ %i78, %e ], [ %i220, %bb216 ]: {[-1]:Pointer, [-1,0]:Integer}
34+
; CHECK-NEXT: %l = load i8, i8* %i159, align 1: {[-1]:Integer}
35+
; CHECK-NEXT: br i1 true, label %bb179, label %bb216: {}
36+
; CHECK-NEXT: bb179
37+
; CHECK-NEXT: %i192 = call noalias nonnull i8* @_Znwm(i64 8): {[-1]:Pointer}
38+
; CHECK-NEXT: br label %bb216: {}
39+
; CHECK-NEXT: bb216
40+
; CHECK-NEXT: %i217 = phi i8* [ %i192, %bb179 ], [ %i159, %bb155 ]: {[-1]:Pointer}
41+
; CHECK-NEXT: %i220 = getelementptr inbounds i8, i8* %i217, i64 1: {[-1]:Pointer}
42+
; CHECK-NEXT: br i1 true, label %bb153, label %bb155: {}
43+
; CHECK-NEXT: bb153
44+
; CHECK-NEXT: ret void: {}

0 commit comments

Comments
 (0)