Skip to content

Commit aad739c

Browse files
authored
Improve Activity Analysis (rust-lang#614)
* Improve activity analysis * Fix hypothesis * Fix doubleload * Fix tests
1 parent 875b2ae commit aad739c

File tree

11 files changed

+294
-43
lines changed

11 files changed

+294
-43
lines changed

enzyme/Enzyme/ActivityAnalysis.cpp

Lines changed: 129 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,8 +1082,14 @@ bool ActivityAnalyzer::isConstantValue(TypeResults &TR, Value *Val) {
10821082
if (!arg->hasByValAttr()) {
10831083
bool res = isConstantValue(TR, TmpOrig);
10841084
if (res) {
1085+
if (EnzymePrintActivity)
1086+
llvm::errs() << " arg const from orig val=" << *Val
1087+
<< " orig=" << *TmpOrig << "\n";
10851088
InsertConstantValue(TR, Val);
10861089
} else {
1090+
if (EnzymePrintActivity)
1091+
llvm::errs() << " arg active from orig val=" << *Val
1092+
<< " orig=" << *TmpOrig << "\n";
10871093
ActiveValues.insert(Val);
10881094
}
10891095
return res;
@@ -1096,7 +1102,29 @@ bool ActivityAnalyzer::isConstantValue(TypeResults &TR, Value *Val) {
10961102

10971103
// If our origin is a load of a known inactive (say inactive argument), we
10981104
// are also inactive
1099-
if (auto LI = dyn_cast<LoadInst>(TmpOrig)) {
1105+
if (auto PN = dyn_cast<PHINode>(TmpOrig)) {
1106+
// Not taking fast path incase phi is recursive.
1107+
Value *active = nullptr;
1108+
for (auto &V : PN->incoming_values()) {
1109+
if (!UpHypothesis->isConstantValue(TR, V.get())) {
1110+
active = V.get();
1111+
break;
1112+
}
1113+
}
1114+
if (!active) {
1115+
InsertConstantValue(TR, Val);
1116+
if (TmpOrig != Val) {
1117+
InsertConstantValue(TR, TmpOrig);
1118+
}
1119+
insertConstantsFrom(TR, *UpHypothesis);
1120+
return true;
1121+
} else {
1122+
ReEvaluateValueIfInactiveValue[active].insert(Val);
1123+
if (TmpOrig != Val) {
1124+
ReEvaluateValueIfInactiveValue[active].insert(TmpOrig);
1125+
}
1126+
}
1127+
} else if (auto LI = dyn_cast<LoadInst>(TmpOrig)) {
11001128

11011129
if (directions == UP) {
11021130
if (isConstantValue(TR, LI->getPointerOperand())) {
@@ -1277,24 +1305,31 @@ bool ActivityAnalyzer::isConstantValue(TypeResults &TR, Value *Val) {
12771305
// A pointer value is active if two things hold:
12781306
// an potentially active value is stored into the memory
12791307
// memory loaded from the value is used in an active way
1308+
bool potentiallyActiveStore = false;
12801309
bool potentialStore = false;
12811310
bool potentiallyActiveLoad = false;
12821311

1283-
if (isa<Instruction>(Val) || isa<Argument>(Val)) {
1284-
// These are handled by iterating through all
1285-
} else {
1286-
llvm::errs() << "unknown pointer value type: " << *Val << "\n";
1287-
assert(0 && "unknown pointer value type");
1288-
llvm_unreachable("unknown pointer value type");
1289-
}
1290-
12911312
// Assume the value (not instruction) is itself active
12921313
// In spite of that can we show that there are either no active stores
12931314
// or no active loads
12941315
std::shared_ptr<ActivityAnalyzer> Hypothesis =
12951316
std::shared_ptr<ActivityAnalyzer>(
12961317
new ActivityAnalyzer(*this, directions));
12971318
Hypothesis->ActiveValues.insert(Val);
1319+
if (auto VI = dyn_cast<Instruction>(Val)) {
1320+
for (auto V : DeducingPointers) {
1321+
UpHypothesis->InsertConstantValue(TR, V);
1322+
}
1323+
if (UpHypothesis->isInstructionInactiveFromOrigin(TR, VI)) {
1324+
Hypothesis->DeducingPointers.insert(Val);
1325+
if (EnzymePrintActivity)
1326+
llvm::errs() << " constant instruction hypothesis: " << *VI << "\n";
1327+
} else {
1328+
if (EnzymePrintActivity)
1329+
llvm::errs() << " cannot show constant instruction hypothesis: "
1330+
<< *VI << "\n";
1331+
}
1332+
}
12981333

12991334
auto checkActivity = [&](Instruction *I) {
13001335
if (notForAnalysis.count(I->getParent()))
@@ -1469,7 +1504,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults &TR, Value *Val) {
14691504
llvm::errs()
14701505
<< "potential active store via pointer in load: " << *I
14711506
<< " of " << *Val << "\n";
1472-
potentialStore = true;
1507+
potentiallyActiveStore = true;
14731508
}
14741509
}
14751510
} else if (auto MTI = dyn_cast<MemTransferInst>(I)) {
@@ -1480,7 +1515,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults &TR, Value *Val) {
14801515
llvm::errs()
14811516
<< "potential active store via pointer in memcpy: " << *I
14821517
<< " of " << *Val << "\n";
1483-
potentialStore = true;
1518+
potentiallyActiveStore = true;
14841519
}
14851520
}
14861521
} else {
@@ -1492,27 +1527,41 @@ bool ActivityAnalyzer::isConstantValue(TypeResults &TR, Value *Val) {
14921527
// A load that has as result an active pointer is not an active
14931528
// instruction, but does have an active value
14941529
if (!Hypothesis->isConstantInstruction(TR, I) ||
1495-
!Hypothesis->isConstantValue(TR, I)) {
1530+
(I != Val && !Hypothesis->isConstantValue(TR, I))) {
14961531
potentiallyActiveLoad = true;
14971532
// If this a potential pointer of pointer AND
1533+
// double** Val;
1534+
//
14981535
if (TR.query(Val)[{-1, -1}].isPossiblePointer()) {
1499-
// If this instruction either can store into the inner pointer,
1500-
// or could return an active loaded pointer(thus into a
1501-
// potential pointer of pointer
1502-
if (I->mayWriteToMemory() ||
1503-
(!Hypothesis->isConstantValue(TR, I) &&
1536+
// If this instruction either:
1537+
// 1) can actively store into the inner pointer, even
1538+
// if it doesn't store into the outer pointer. Actively
1539+
// storing into the outer pointer is handled by the isMod
1540+
// case.
1541+
// I(double** readonly Val, double activeX) {
1542+
// double* V0 = Val[0]
1543+
// V0 = activeX;
1544+
// }
1545+
// 2) may return an active pointer loaded from Val
1546+
// double* I = *Val;
1547+
// I[0] = active;
1548+
//
1549+
if ((I->mayWriteToMemory() &&
1550+
!Hypothesis->isConstantInstruction(TR, I)) ||
1551+
(!Hypothesis->DeducingPointers.count(I) &&
1552+
!Hypothesis->isConstantValue(TR, I) &&
15041553
TR.query(I)[{-1}].isPossiblePointer())) {
15051554
if (EnzymePrintActivity)
15061555
llvm::errs() << "potential active store via pointer in "
15071556
"unknown inst: "
15081557
<< *I << " of " << *Val << "\n";
1509-
potentialStore = true;
1558+
potentiallyActiveStore = true;
15101559
}
15111560
}
15121561
}
15131562
}
15141563
}
1515-
if (!potentialStore && isModSet(AARes)) {
1564+
if ((!potentiallyActiveStore || !potentialStore) && isModSet(AARes)) {
15161565
if (EnzymePrintActivity)
15171566
llvm::errs() << "potential active store: " << *I << " Val=" << *Val
15181567
<< "\n";
@@ -1522,39 +1571,80 @@ bool ActivityAnalyzer::isConstantValue(TypeResults &TR, Value *Val) {
15221571
llvm::errs() << " -- store potential activity: " << (int)cop
15231572
<< " - " << *SI << " of "
15241573
<< " Val=" << *Val << "\n";
1525-
potentialStore |= cop;
1574+
potentialStore = true;
1575+
if (cop)
1576+
potentiallyActiveStore = true;
15261577
} else if (auto MTI = dyn_cast<MemTransferInst>(I)) {
1527-
potentialStore |=
1528-
!Hypothesis->isConstantValue(TR, MTI->getArgOperand(1));
1578+
bool cop = !Hypothesis->isConstantValue(TR, MTI->getArgOperand(1));
1579+
potentialStore = true;
1580+
if (cop)
1581+
potentiallyActiveStore = true;
15291582
} else {
15301583
// Otherwise fallback and check if the instruction is active
15311584
// TODO: note that this can be optimized (especially for function
15321585
// calls)
1533-
potentialStore |= !Hypothesis->isConstantInstruction(TR, I);
1586+
auto cop = !Hypothesis->isConstantInstruction(TR, I);
1587+
if (EnzymePrintActivity)
1588+
llvm::errs() << " -- unknown store potential activity: " << (int)cop
1589+
<< " - " << *I << " of "
1590+
<< " Val=" << *Val << "\n";
1591+
potentialStore = true;
1592+
if (cop)
1593+
potentiallyActiveStore = true;
15341594
}
15351595
}
1536-
if (potentialStore && potentiallyActiveLoad)
1596+
if (potentiallyActiveStore && potentiallyActiveLoad)
15371597
return true;
15381598
return false;
15391599
};
15401600

15411601
// Search through all the instructions in this function
1542-
// for potential loads / stores of this value
1543-
for (BasicBlock &BB : *TR.getFunction()) {
1544-
if (notForAnalysis.count(&BB))
1545-
continue;
1546-
for (Instruction &I : BB) {
1547-
if (checkActivity(&I))
1548-
goto activeLoadAndStore;
1602+
// for potential loads / stores of this value.
1603+
//
1604+
// We can choose to only look at potential follower instructions
1605+
// if the value is created by the instruction (alloca, noalias)
1606+
// since no potentially active store to the same location can occur
1607+
// prior to its creation. Otherwise, check all instructions in the
1608+
// function as a store to an aliasing location may have occured
1609+
// prior to the instruction generating the value.
1610+
1611+
if (auto VI = dyn_cast<AllocaInst>(Val)) {
1612+
allFollowersOf(VI, checkActivity);
1613+
} else if (auto VI = dyn_cast<CallInst>(Val)) {
1614+
if (VI->hasRetAttr(Attribute::NoAlias))
1615+
allFollowersOf(VI, checkActivity);
1616+
else {
1617+
for (BasicBlock &BB : *TR.getFunction()) {
1618+
if (notForAnalysis.count(&BB))
1619+
continue;
1620+
for (Instruction &I : BB) {
1621+
if (checkActivity(&I))
1622+
goto activeLoadAndStore;
1623+
}
1624+
}
15491625
}
1626+
} else if (isa<Argument>(Val) || isa<Instruction>(Val)) {
1627+
for (BasicBlock &BB : *TR.getFunction()) {
1628+
if (notForAnalysis.count(&BB))
1629+
continue;
1630+
for (Instruction &I : BB) {
1631+
if (checkActivity(&I))
1632+
goto activeLoadAndStore;
1633+
}
1634+
}
1635+
} else {
1636+
llvm::errs() << "unknown pointer value type: " << *Val << "\n";
1637+
assert(0 && "unknown pointer value type");
1638+
llvm_unreachable("unknown pointer value type");
15501639
}
15511640

15521641
activeLoadAndStore:;
15531642
if (EnzymePrintActivity)
15541643
llvm::errs() << " </MEMSEARCH" << (int)directions << ">" << *Val
15551644
<< " potentiallyActiveLoad=" << potentiallyActiveLoad
1645+
<< " potentiallyActiveStore=" << potentiallyActiveStore
15561646
<< " potentialStore=" << potentialStore << "\n";
1557-
if (potentiallyActiveLoad && potentialStore) {
1647+
if (potentiallyActiveLoad && potentiallyActiveStore) {
15581648
insertAllFrom(TR, *Hypothesis, Val);
15591649
// TODO have insertall dependence on this
15601650
if (TmpOrig != Val)
@@ -1579,7 +1669,11 @@ bool ActivityAnalyzer::isConstantValue(TypeResults &TR, Value *Val) {
15791669

15801670
assert(UpHypothesis);
15811671
// UpHypothesis.ConstantValues.insert(val);
1582-
UpHypothesis->insertConstantsFrom(TR, *Hypothesis);
1672+
if (DeducingPointers.size() == 0)
1673+
UpHypothesis->insertConstantsFrom(TR, *Hypothesis);
1674+
for (auto V : DeducingPointers) {
1675+
UpHypothesis->InsertConstantValue(TR, V);
1676+
}
15831677
assert(directions & UP);
15841678
bool ActiveUp = !isa<Argument>(Val) &&
15851679
!UpHypothesis->isInstructionInactiveFromOrigin(TR, Val);
@@ -1671,7 +1765,8 @@ bool ActivityAnalyzer::isConstantValue(TypeResults &TR, Value *Val) {
16711765
} else {
16721766
InsertConstantValue(TR, Val);
16731767
insertConstantsFrom(TR, *Hypothesis);
1674-
insertConstantsFrom(TR, *UpHypothesis);
1768+
if (DeducingPointers.size() == 0)
1769+
insertConstantsFrom(TR, *UpHypothesis);
16751770
insertConstantsFrom(TR, *DownHypothesis);
16761771
return true;
16771772
}

enzyme/Enzyme/ActivityAnalysis.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,11 @@ class ActivityAnalyzer {
9595
/// Values that may contain derivative information
9696
llvm::SmallPtrSet<llvm::Value *, 2> ActiveValues;
9797

98+
/// Intermediate pointers which are created by inactive instructions
99+
/// but are marked as active values to inductively determine their
100+
/// activity.
101+
llvm::SmallPtrSet<llvm::Value *, 1> DeducingPointers;
102+
98103
public:
99104
/// Construct the analyzer from the a previous set of constant and active
100105
/// values and whether returns are active. The all arguments of the functions
@@ -141,7 +146,8 @@ class ActivityAnalyzer {
141146
directions(directions),
142147
ConstantInstructions(Other.ConstantInstructions),
143148
ActiveInstructions(Other.ActiveInstructions),
144-
ConstantValues(Other.ConstantValues), ActiveValues(Other.ActiveValues) {
149+
ConstantValues(Other.ConstantValues), ActiveValues(Other.ActiveValues),
150+
DeducingPointers(Other.DeducingPointers) {
145151
assert(directions != 0);
146152
assert((directions & Other.directions) == directions);
147153
assert((directions & Other.directions) != 0);

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -672,11 +672,33 @@ class AdjointGenerator
672672
auto prediff = diffe(&I, Builder2);
673673
setDiffe(&I, Constant::getNullValue(type), Builder2);
674674

675+
if (mask && (!gutils->isConstantValue(I.getOperand(0)) ||
676+
!gutils->isConstantValue(orig_maskInit)))
677+
mask = lookup(mask, Builder2);
678+
675679
if (!gutils->isConstantValue(I.getOperand(0))) {
680+
BasicBlock *merge = nullptr;
681+
if (EnzymeRuntimeActivityCheck) {
682+
Value *shadow = Builder2.CreateICmpNE(
683+
lookup(gutils->getNewFromOriginal(I.getOperand(0)), Builder2),
684+
lookup(gutils->invertPointerM(I.getOperand(0), Builder2),
685+
Builder2));
686+
687+
BasicBlock *current = Builder2.GetInsertBlock();
688+
BasicBlock *conditional = gutils->addReverseBlock(
689+
current, current->getName() + "_active");
690+
merge = gutils->addReverseBlock(conditional,
691+
current->getName() + "_amerge");
692+
Builder2.CreateCondBr(shadow, conditional, merge);
693+
Builder2.SetInsertPoint(conditional);
694+
}
676695
((DiffeGradientUtils *)gutils)
677-
->addToInvertedPtrDiffe(
678-
I.getOperand(0), prediff, Builder2, alignment, OrigOffset,
679-
mask ? lookup(mask, Builder2) : nullptr);
696+
->addToInvertedPtrDiffe(I.getOperand(0), prediff, Builder2,
697+
alignment, OrigOffset, mask);
698+
if (merge) {
699+
Builder2.CreateBr(merge);
700+
Builder2.SetInsertPoint(merge);
701+
}
680702
}
681703
if (mask && !gutils->isConstantValue(orig_maskInit)) {
682704
addToDiffe(orig_maskInit, prediff, Builder2, isfloat,

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,11 @@ llvm::cl::opt<bool> EnzymeInactiveDynamic(
8181
"enzyme-inactive-dynamic", cl::init(true), cl::Hidden,
8282
cl::desc("Force wholy inactive dynamic loops to have 0 iter reverse pass"));
8383

84+
llvm::cl::opt<bool>
85+
EnzymeRuntimeActivityCheck("enzyme-runtime-activity", cl::init(false),
86+
cl::Hidden,
87+
cl::desc("Perform runtime activity checks"));
88+
8489
llvm::cl::opt<bool>
8590
EnzymeSharedForward("enzyme-shared-forward", cl::init(false), cl::Hidden,
8691
cl::desc("Forward Shared Memory from definitions"));

enzyme/Enzyme/GradientUtils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ extern std::map<
9696
customFwdCallHandlers;
9797

9898
extern "C" {
99+
extern llvm::cl::opt<bool> EnzymeRuntimeActivityCheck;
99100
extern llvm::cl::opt<bool> EnzymeInactiveDynamic;
100101
extern llvm::cl::opt<bool> EnzymeFreeInternalAllocations;
101102
extern llvm::cl::opt<bool> EnzymeRematerialize;
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
; RUN: %opt < %s %loadEnzyme -print-activity-analysis -activity-analysis-func=matvec -activity-analysis-inactive-args -o /dev/null | FileCheck %s
2+
3+
source_filename = "text"
4+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128-ni:10:11:12:13"
5+
target triple = "x86_64-pc-linux-gnu"
6+
7+
declare float** @jl_array_copy()
8+
9+
define float @matvec({} addrspace(10)* nocapture nonnull readonly align 16 dereferenceable(40) %arg, {} addrspace(10)* nonnull align 16 dereferenceable(40) %arg1, i8 zeroext %arg2) {
10+
entry:
11+
%i10 = call noalias float** @jl_array_copy()
12+
%i11 = load float*, float** %i10, align 8
13+
%i12 = load float, float* %i11, align 4;, !tbaa !21
14+
ret float %i12
15+
}
16+
17+
; CHECK: {} addrspace(10)* %arg: icv:1
18+
; CHECK: {} addrspace(10)* %arg1: icv:1
19+
; CHECK: i8 %arg2: icv:1
20+
; CHECK: entry
21+
; CHECK-NEXT: %i10 = call noalias float** @jl_array_copy(): icv:1 ici:1
22+
; CHECK-NEXT: %i11 = load float*, float** %i10, align 8: icv:1 ici:1
23+
; CHECK-NEXT: %i12 = load float, float* %i11, align 4: icv:1 ici:1
24+
; CHECK-NEXT: ret float %i12: icv:1 ici:1

0 commit comments

Comments
 (0)