
Commit 902f67e

Fix recurrence (rust-lang#575)
* Forbid illegal phi recompute and use caching heuristic
* Fix recur
1 parent 577b291 commit 902f67e

2 files changed: +296 additions, -75 deletions


enzyme/Enzyme/GradientUtils.cpp

Lines changed: 102 additions & 75 deletions
@@ -198,66 +198,28 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
     }
   }
 
-  if (this->mode == DerivativeMode::ReverseModeGradient)
+  if (this->mode == DerivativeMode::ReverseModeGradient ||
+      this->mode == DerivativeMode::ForwardModeSplit ||
+      this->mode == DerivativeMode::ReverseModeCombined)
     if (auto inst = dyn_cast<Instruction>(val)) {
-      if (unwrapMode == UnwrapMode::LegalFullUnwrap) {
-        // TODO this isOriginal is a bottleneck, the new mapping of
-        // knownRecompute should be precomputed and maintained to lookup instead
-        Instruction *orig = isOriginal(inst);
-        // If a given value has been chosen to be cached, do not compute the
-        // operands to unwrap it, instead simply emit a placeholder to be
-        // replaced by the cache load later. This placeholder should only be
-        // returned when the original value would be recomputed (e.g. this
-        // function would not return null). Since this case assumes everything
-        // can be recomputed, simply return the placeholder.
-        if (orig && knownRecomputeHeuristic.find(orig) !=
-                        knownRecomputeHeuristic.end()) {
-          if (!knownRecomputeHeuristic[orig]) {
-            assert(inst->getParent()->getParent() == newFunc);
-            auto placeholder = BuilderM.CreatePHI(
-                val->getType(), 0, val->getName() + "_krcLFUreplacement");
-            unwrappedLoads[placeholder] = inst;
-            SmallVector<Metadata *, 1> avail;
-            for (auto pair : available)
-              if (pair.second)
-                avail.push_back(MDNode::get(
-                    placeholder->getContext(),
-                    {ValueAsMetadata::get(const_cast<Value *>(pair.first)),
-                     ValueAsMetadata::get(pair.second)}));
-            placeholder->setMetadata(
-                "enzyme_available",
-                MDNode::get(placeholder->getContext(), avail));
-            if (!permitCache)
-              return placeholder;
-            return unwrap_cache[BuilderM.GetInsertBlock()][idx.first]
-                               [idx.second] = placeholder;
-          }
-        }
-      } else if (unwrapMode == UnwrapMode::AttemptFullUnwrapWithLookup) {
-        // TODO this isOriginal is a bottleneck, the new mapping of
-        // knownRecompute should be precomputed and maintained to lookup instead
-        Instruction *orig = isOriginal(inst);
-        // If a given value has been chosen to be cached, do not compute the
-        // operands to unwrap it, instead simply emit a placeholder to be
-        // replaced by the cache load later. This placeholder should only be
-        // returned when the original value would be recomputed (e.g. this
-        // function would not return null). See note below about the condition
-        // as applied to this case.
-        if (orig && knownRecomputeHeuristic.find(orig) !=
-                        knownRecomputeHeuristic.end()) {
-          if (!knownRecomputeHeuristic[orig]) {
-            // Note that this logic (original load must dominate or
-            // alternatively be in the reverse block) is only valid iff when
-            // applicable (here if in split mode), an uncacheable load cannot be
-            // hoisted outside of a loop to be used as a loop limit. This
-            // optimization is currently done in the combined mode (e.g. if a
-            // load isn't modified between a prior insertion point and the
-            // actual load, it is legal to recompute).
-            if (!isOriginalBlock(*BuilderM.GetInsertBlock()) ||
-                DT.dominates(inst, &*BuilderM.GetInsertPoint())) {
+      if (inst->getParent()->getParent() == newFunc) {
+        if (unwrapMode == UnwrapMode::LegalFullUnwrap) {
+          // TODO this isOriginal is a bottleneck, the new mapping of
+          // knownRecompute should be precomputed and maintained to lookup
+          // instead
+          Instruction *orig = isOriginal(inst);
+          // If a given value has been chosen to be cached, do not compute the
+          // operands to unwrap it, instead simply emit a placeholder to be
+          // replaced by the cache load later. This placeholder should only be
+          // returned when the original value would be recomputed (e.g. this
+          // function would not return null). Since this case assumes everything
+          // can be recomputed, simply return the placeholder.
+          if (orig && knownRecomputeHeuristic.find(orig) !=
+                          knownRecomputeHeuristic.end()) {
+            if (!knownRecomputeHeuristic[orig]) {
               assert(inst->getParent()->getParent() == newFunc);
               auto placeholder = BuilderM.CreatePHI(
-                  val->getType(), 0, val->getName() + "_krcAFUWLreplacement");
+                  val->getType(), 0, val->getName() + "_krcLFUreplacement");
               unwrappedLoads[placeholder] = inst;
               SmallVector<Metadata *, 1> avail;
               for (auto pair : available)
@@ -275,24 +237,85 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
                                  [idx.second] = placeholder;
             }
           }
-        }
-      } else if (unwrapMode != UnwrapMode::LegalFullUnwrapNoTapeReplace) {
-        // TODO this isOriginal is a bottleneck, the new mapping of
-        // knownRecompute should be precomputed and maintained to lookup instead
-
-        // If a given value has been chosen to be cached, do not compute the
-        // operands to unwrap it if it is not legal to do so. This prevents the
-        // creation of unused versions of the instruction's operand, which may
-        // be assumed to never be used and thus cause an error when they are
-        // inadvertantly cached.
-        Value *orig = isOriginal(val);
-        if (orig && knownRecomputeHeuristic.find(orig) !=
-                        knownRecomputeHeuristic.end()) {
-          if (!knownRecomputeHeuristic[orig]) {
-            if (!legalRecompute(orig, available, &BuilderM))
-              return nullptr;
-
-            assert(isa<LoadInst>(orig) == isa<LoadInst>(val));
+        } else if (unwrapMode == UnwrapMode::AttemptFullUnwrapWithLookup) {
+          // TODO this isOriginal is a bottleneck, the new mapping of
+          // knownRecompute should be precomputed and maintained to lookup
+          // instead
+          Instruction *orig = isOriginal(inst);
+          // If a given value has been chosen to be cached, do not compute the
+          // operands to unwrap it, instead simply emit a placeholder to be
+          // replaced by the cache load later. This placeholder should only be
+          // returned when the original value would be recomputed (e.g. this
+          // function would not return null). See note below about the condition
+          // as applied to this case.
+          if (orig && knownRecomputeHeuristic.find(orig) !=
+                          knownRecomputeHeuristic.end()) {
+            if (!knownRecomputeHeuristic[orig]) {
+              if (mode == DerivativeMode::ReverseModeCombined) {
+                // Don't unnecessarily cache a value if the caching
+                // heuristic says we should preserve this precise (and not
+                // an lcssa wrapped) value
+                if (!isOriginalBlock(*BuilderM.GetInsertBlock())) {
+                  Value *nval = inst;
+                  if (scope)
+                    nval = fixLCSSA(inst, scope);
+                  if (nval == inst)
+                    goto endCheck;
+                }
+              } else {
+                // Note that this logic (original load must dominate or
+                // alternatively be in the reverse block) is only valid iff when
+                // applicable (here if in split mode), an uncacheable load
+                // cannot be hoisted outside of a loop to be used as a loop
+                // limit. This optimization is currently done in the combined
+                // mode (e.g. if a load isn't modified between a prior insertion
+                // point and the actual load, it is legal to recompute).
+                if (!isOriginalBlock(*BuilderM.GetInsertBlock()) ||
+                    DT.dominates(inst, &*BuilderM.GetInsertPoint())) {
+                  assert(inst->getParent()->getParent() == newFunc);
+                  auto placeholder = BuilderM.CreatePHI(
+                      val->getType(), 0,
+                      val->getName() + "_krcAFUWLreplacement");
+                  unwrappedLoads[placeholder] = inst;
+                  SmallVector<Metadata *, 1> avail;
+                  for (auto pair : available)
+                    if (pair.second)
+                      avail.push_back(
+                          MDNode::get(placeholder->getContext(),
+                                      {ValueAsMetadata::get(
+                                           const_cast<Value *>(pair.first)),
+                                       ValueAsMetadata::get(pair.second)}));
+                  placeholder->setMetadata(
+                      "enzyme_available",
+                      MDNode::get(placeholder->getContext(), avail));
+                  if (!permitCache)
+                    return placeholder;
+                  return unwrap_cache[BuilderM.GetInsertBlock()][idx.first]
+                                     [idx.second] = placeholder;
+                }
+              }
+            }
+          }
+        } else if (unwrapMode != UnwrapMode::LegalFullUnwrapNoTapeReplace &&
+                   mode != DerivativeMode::ReverseModeCombined) {
+          // TODO this isOriginal is a bottleneck, the new mapping of
+          // knownRecompute should be precomputed and maintained to lookup
+          // instead
+
+          // If a given value has been chosen to be cached, do not compute the
+          // operands to unwrap it if it is not legal to do so. This prevents
+          // the creation of unused versions of the instruction's operand, which
+          // may be assumed to never be used and thus cause an error when they
+          // are inadvertantly cached.
+          Value *orig = isOriginal(val);
+          if (orig && knownRecomputeHeuristic.find(orig) !=
+                          knownRecomputeHeuristic.end()) {
+            if (!knownRecomputeHeuristic[orig]) {
+              if (!legalRecompute(orig, available, &BuilderM))
+                return nullptr;
+
+              assert(isa<LoadInst>(orig) == isa<LoadInst>(val));
+            }
           }
         }
       }
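Both placeholder branches above follow the same pattern: when knownRecomputeHeuristic marks a value as cached rather than recomputed, unwrapM emits an empty PHI and attaches the currently available (original, replacement) value pairs as "enzyme_available" metadata, so a later pass can swap the PHI for a cache load instead of recomputing the value's operands. Below is a minimal standalone sketch of that pattern, assuming a plain std::map for the available values; the helper name emitCachePlaceholder is illustrative and not part of Enzyme's API.

// Illustrative sketch (not Enzyme's actual helper): emit a PHI placeholder for
// a value the caching heuristic chose to cache, recording the known value
// replacements as "enzyme_available" metadata for a later rewrite pass.
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Metadata.h"
#include <map>

using namespace llvm;

static PHINode *emitCachePlaceholder(IRBuilder<> &BuilderM, Value *val,
                                     const std::map<Value *, Value *> &available) {
  // Empty PHI of the cached value's type, named like the placeholders above.
  PHINode *placeholder =
      BuilderM.CreatePHI(val->getType(), 0, val->getName() + "_krcreplacement");

  // Record each (original value, known replacement) pair available here.
  SmallVector<Metadata *, 1> avail;
  for (const auto &pair : available)
    if (pair.second)
      avail.push_back(MDNode::get(placeholder->getContext(),
                                  {ValueAsMetadata::get(pair.first),
                                   ValueAsMetadata::get(pair.second)}));
  placeholder->setMetadata("enzyme_available",
                           MDNode::get(placeholder->getContext(), avail));
  return placeholder;
}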
@@ -925,6 +948,10 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
         if (L->contains(PH))
           prevIteration.insert(PH);
       }
+      if (prevIteration.size() && !legalRecompute(phi, available, &BuilderM)) {
+        assert(unwrapMode != UnwrapMode::LegalFullUnwrap);
+        goto endCheck;
+      }
     }
     for (auto &val : phi->incoming_values()) {
       if (isPotentialLastLoopValue(val, parent, LLI)) {
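The last hunk is the "forbid illegal phi recompute" part of the commit message: if the phi has an incoming value from a previous loop iteration and legalRecompute rejects recomputation, unwrapM now bails out to endCheck instead of rebuilding the phi. A hypothetical source function (not from this commit) with the kind of loop-carried recurrence this guards against:

// Hypothetical example, not part of the commit: `prod` is a loop-carried
// recurrence, so its phi in iteration i depends on iteration i-1 and cannot be
// rebuilt at a later point from x[i] alone; with the new guard, such a phi is
// left to the cache rather than being illegally recomputed.
double prod_all(const double *x, int n) {
  double prod = 1.0;  // becomes a phi node in the loop header
  for (int i = 0; i < n; ++i)
    prod *= x[i];     // each iteration reads the previous iteration's value
  return prod;
}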

0 commit comments
