@@ -198,66 +198,28 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
}
}

- if (this->mode == DerivativeMode::ReverseModeGradient)
+ if (this->mode == DerivativeMode::ReverseModeGradient ||
+     this->mode == DerivativeMode::ForwardModeSplit ||
+     this->mode == DerivativeMode::ReverseModeCombined)
if (auto inst = dyn_cast<Instruction>(val)) {
- if (unwrapMode == UnwrapMode::LegalFullUnwrap) {
-   // TODO this isOriginal is a bottleneck, the new mapping of
-   // knownRecompute should be precomputed and maintained to lookup instead
-   Instruction *orig = isOriginal(inst);
-   // If a given value has been chosen to be cached, do not compute the
-   // operands to unwrap it, instead simply emit a placeholder to be
-   // replaced by the cache load later. This placeholder should only be
-   // returned when the original value would be recomputed (e.g. this
-   // function would not return null). Since this case assumes everything
-   // can be recomputed, simply return the placeholder.
-   if (orig && knownRecomputeHeuristic.find(orig) !=
-                   knownRecomputeHeuristic.end()) {
-     if (!knownRecomputeHeuristic[orig]) {
-       assert(inst->getParent()->getParent() == newFunc);
-       auto placeholder = BuilderM.CreatePHI(
-           val->getType(), 0, val->getName() + "_krcLFUreplacement");
-       unwrappedLoads[placeholder] = inst;
-       SmallVector<Metadata *, 1> avail;
-       for (auto pair : available)
-         if (pair.second)
-           avail.push_back(MDNode::get(
-               placeholder->getContext(),
-               {ValueAsMetadata::get(const_cast<Value *>(pair.first)),
-                ValueAsMetadata::get(pair.second)}));
-       placeholder->setMetadata(
-           "enzyme_available",
-           MDNode::get(placeholder->getContext(), avail));
-       if (!permitCache)
-         return placeholder;
-       return unwrap_cache[BuilderM.GetInsertBlock()][idx.first]
-                          [idx.second] = placeholder;
-     }
-   }
- } else if (unwrapMode == UnwrapMode::AttemptFullUnwrapWithLookup) {
-   // TODO this isOriginal is a bottleneck, the new mapping of
-   // knownRecompute should be precomputed and maintained to lookup instead
-   Instruction *orig = isOriginal(inst);
-   // If a given value has been chosen to be cached, do not compute the
-   // operands to unwrap it, instead simply emit a placeholder to be
-   // replaced by the cache load later. This placeholder should only be
-   // returned when the original value would be recomputed (e.g. this
-   // function would not return null). See note below about the condition
-   // as applied to this case.
-   if (orig && knownRecomputeHeuristic.find(orig) !=
-                   knownRecomputeHeuristic.end()) {
-     if (!knownRecomputeHeuristic[orig]) {
-       // Note that this logic (original load must dominate or
-       // alternatively be in the reverse block) is only valid iff when
-       // applicable (here if in split mode), an uncacheable load cannot be
-       // hoisted outside of a loop to be used as a loop limit. This
-       // optimization is currently done in the combined mode (e.g. if a
-       // load isn't modified between a prior insertion point and the
-       // actual load, it is legal to recompute).
-       if (!isOriginalBlock(*BuilderM.GetInsertBlock()) ||
-           DT.dominates(inst, &*BuilderM.GetInsertPoint())) {
+ if (inst->getParent()->getParent() == newFunc) {
+   if (unwrapMode == UnwrapMode::LegalFullUnwrap) {
+     // TODO this isOriginal is a bottleneck, the new mapping of
+     // knownRecompute should be precomputed and maintained to lookup
+     // instead
+     Instruction *orig = isOriginal(inst);
+     // If a given value has been chosen to be cached, do not compute the
+     // operands to unwrap it, instead simply emit a placeholder to be
+     // replaced by the cache load later. This placeholder should only be
+     // returned when the original value would be recomputed (e.g. this
+     // function would not return null). Since this case assumes everything
+     // can be recomputed, simply return the placeholder.
+     if (orig && knownRecomputeHeuristic.find(orig) !=
+                     knownRecomputeHeuristic.end()) {
+       if (!knownRecomputeHeuristic[orig]) {
assert(inst->getParent()->getParent() == newFunc);
auto placeholder = BuilderM.CreatePHI(
-     val->getType(), 0, val->getName() + "_krcAFUWLreplacement");
+     val->getType(), 0, val->getName() + "_krcLFUreplacement");
unwrappedLoads[placeholder] = inst;
SmallVector<Metadata *, 1> avail;
for (auto pair : available)
@@ -275,24 +237,85 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
                   [idx.second] = placeholder;
}
}
- }
- } else if (unwrapMode != UnwrapMode::LegalFullUnwrapNoTapeReplace) {
-   // TODO this isOriginal is a bottleneck, the new mapping of
-   // knownRecompute should be precomputed and maintained to lookup instead
-
-   // If a given value has been chosen to be cached, do not compute the
-   // operands to unwrap it if it is not legal to do so. This prevents the
-   // creation of unused versions of the instruction's operand, which may
-   // be assumed to never be used and thus cause an error when they are
-   // inadvertantly cached.
-   Value *orig = isOriginal(val);
-   if (orig && knownRecomputeHeuristic.find(orig) !=
-                   knownRecomputeHeuristic.end()) {
-     if (!knownRecomputeHeuristic[orig]) {
-       if (!legalRecompute(orig, available, &BuilderM))
-         return nullptr;
-
-       assert(isa<LoadInst>(orig) == isa<LoadInst>(val));
+   } else if (unwrapMode == UnwrapMode::AttemptFullUnwrapWithLookup) {
+     // TODO this isOriginal is a bottleneck, the new mapping of
+     // knownRecompute should be precomputed and maintained to lookup
+     // instead
+     Instruction *orig = isOriginal(inst);
+     // If a given value has been chosen to be cached, do not compute the
+     // operands to unwrap it, instead simply emit a placeholder to be
+     // replaced by the cache load later. This placeholder should only be
+     // returned when the original value would be recomputed (e.g. this
+     // function would not return null). See note below about the condition
+     // as applied to this case.
+     if (orig && knownRecomputeHeuristic.find(orig) !=
+                     knownRecomputeHeuristic.end()) {
+       if (!knownRecomputeHeuristic[orig]) {
+         if (mode == DerivativeMode::ReverseModeCombined) {
+           // Don't unnecessarily cache a value if the caching
+           // heuristic says we should preserve this precise (and not
+           // an lcssa wrapped) value
+           if (!isOriginalBlock(*BuilderM.GetInsertBlock())) {
+             Value *nval = inst;
+             if (scope)
+               nval = fixLCSSA(inst, scope);
+             if (nval == inst)
+               goto endCheck;
+           }
+         } else {
+           // Note that this logic (original load must dominate or
+           // alternatively be in the reverse block) is only valid iff when
+           // applicable (here if in split mode), an uncacheable load
+           // cannot be hoisted outside of a loop to be used as a loop
+           // limit. This optimization is currently done in the combined
+           // mode (e.g. if a load isn't modified between a prior insertion
+           // point and the actual load, it is legal to recompute).
+           if (!isOriginalBlock(*BuilderM.GetInsertBlock()) ||
+               DT.dominates(inst, &*BuilderM.GetInsertPoint())) {
+             assert(inst->getParent()->getParent() == newFunc);
+             auto placeholder = BuilderM.CreatePHI(
+                 val->getType(), 0,
+                 val->getName() + "_krcAFUWLreplacement");
+             unwrappedLoads[placeholder] = inst;
+             SmallVector<Metadata *, 1> avail;
+             for (auto pair : available)
+               if (pair.second)
+                 avail.push_back(
+                     MDNode::get(placeholder->getContext(),
+                                 {ValueAsMetadata::get(
+                                      const_cast<Value *>(pair.first)),
+                                  ValueAsMetadata::get(pair.second)}));
+             placeholder->setMetadata(
+                 "enzyme_available",
+                 MDNode::get(placeholder->getContext(), avail));
+             if (!permitCache)
+               return placeholder;
+             return unwrap_cache[BuilderM.GetInsertBlock()][idx.first]
+                                [idx.second] = placeholder;
+           }
+         }
+       }
+     }
+   } else if (unwrapMode != UnwrapMode::LegalFullUnwrapNoTapeReplace &&
+              mode != DerivativeMode::ReverseModeCombined) {
+     // TODO this isOriginal is a bottleneck, the new mapping of
+     // knownRecompute should be precomputed and maintained to lookup
+     // instead
+
+     // If a given value has been chosen to be cached, do not compute the
+     // operands to unwrap it if it is not legal to do so. This prevents
+     // the creation of unused versions of the instruction's operand, which
+     // may be assumed to never be used and thus cause an error when they
+     // are inadvertantly cached.
+     Value *orig = isOriginal(val);
+     if (orig && knownRecomputeHeuristic.find(orig) !=
+                     knownRecomputeHeuristic.end()) {
+       if (!knownRecomputeHeuristic[orig]) {
+         if (!legalRecompute(orig, available, &BuilderM))
+           return nullptr;
+
+         assert(isa<LoadInst>(orig) == isa<LoadInst>(val));
+       }
}
}
}
@@ -925,6 +948,10 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
if (L->contains(PH))
  prevIteration.insert(PH);
}
+ if (prevIteration.size() && !legalRecompute(phi, available, &BuilderM)) {
+   assert(unwrapMode != UnwrapMode::LegalFullUnwrap);
+   goto endCheck;
+ }
}
for (auto &val : phi->incoming_values()) {
if (isPotentialLastLoopValue(val, parent, LLI)) {
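For context, here is a minimal standalone sketch of the cache-placeholder pattern the diff above relies on: when the caching heuristic decides a value must come from the tape, the unwrapper emits an empty PHI as a stand-in and records the currently-available values as `enzyme_available` metadata so the later cache-load rewrite can reconstruct them. This is an illustration only, not Enzyme's actual helper; the function name `createCachePlaceholder` and the shape of the `available` parameter are assumptions.

```cpp
// Sketch only: mirrors the placeholder construction in the diff above,
// outside of Enzyme's GradientUtils. Names are illustrative assumptions.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Metadata.h"

using namespace llvm;

// `available` maps an original value to the value that currently represents
// it at the insertion point (pairs with a null replacement are skipped).
static PHINode *
createCachePlaceholder(IRBuilder<> &B, Value *val,
                       ArrayRef<std::pair<Value *, Value *>> available) {
  // A PHI with zero reserved incoming values computes nothing; it only
  // reserves a slot of the right type and name to be replaced by the
  // cache (tape) load later.
  PHINode *placeholder =
      B.CreatePHI(val->getType(), 0, val->getName() + "_cacheplaceholder");

  // Record each (original, replacement) pair as a small MDNode so the
  // availability information survives until the placeholder is rewritten.
  SmallVector<Metadata *, 1> avail;
  for (auto &pair : available)
    if (pair.second)
      avail.push_back(MDNode::get(placeholder->getContext(),
                                  {ValueAsMetadata::get(pair.first),
                                   ValueAsMetadata::get(pair.second)}));
  placeholder->setMetadata("enzyme_available",
                           MDNode::get(placeholder->getContext(), avail));
  return placeholder;
}
```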