@@ -92,6 +92,7 @@ static inline bool is_use_directly_needed_in_reverse(
92
92
}
93
93
if (MTI->getArgOperand (2 ) != val)
94
94
return false ;
95
+ return !gutils->isConstantInstruction (MTI);
95
96
}
96
97
97
98
// Preserve the length of memsets of backward creation shadows
@@ -264,8 +265,12 @@ static inline bool is_value_needed_in_reverse(
264
265
const Instruction *user = dyn_cast<Instruction>(use);
265
266
266
267
// A shadow value is only needed in reverse if it or one of its descendants
267
- // is used in an active instruction
268
- if (VT == ValueType::Shadow) {
268
+ // is used in an active instruction.
269
+ // If inst is a constant value, the primal may be used in its place and
270
+ // thus required.
271
+ if (VT == ValueType::Shadow ||
272
+ (gutils->isConstantValue (const_cast <Value *>(inst)) &&
273
+ !TR.query (const_cast <Value *>(inst))[{-1 }].isFloat ())) {
269
274
if (!user)
270
275
return seen[idx] = true ;
271
276
@@ -286,25 +291,25 @@ static inline bool is_value_needed_in_reverse(
286
291
break ;
287
292
}
288
293
if (!rematerialized)
289
- continue ;
294
+ goto endShadow ;
290
295
}
291
296
292
297
if (!gutils->isConstantValue (
293
298
const_cast <Value *>(SI->getPointerOperand ())))
294
299
return seen[idx] = true ;
295
300
else
296
- continue ;
301
+ goto endShadow ;
297
302
}
298
303
299
304
if (auto MTI = dyn_cast<MemTransferInst>(user)) {
300
305
if (MTI->getArgOperand (0 ) != inst && MTI->getArgOperand (1 ) != inst)
301
- continue ;
306
+ goto endShadow ;
302
307
303
308
if (!gutils->isConstantValue (
304
309
const_cast <Value *>(MTI->getArgOperand (0 ))))
305
310
return seen[idx] = true ;
306
311
else
307
- continue ;
312
+ goto endShadow ;
308
313
}
309
314
310
315
if (auto CI = dyn_cast<CallInst>(user)) {
@@ -327,50 +332,50 @@ static inline bool is_value_needed_in_reverse(
327
332
// Only need shadow request for reverse
328
333
if (funcName == " MPI_Irecv" || funcName == " PMPI_Irecv" ) {
329
334
if (gutils->isConstantInstruction (const_cast <Instruction *>(user)))
330
- continue ;
335
+ goto endShadow ;
331
336
// Need shadow request
332
337
if (inst == CI->getArgOperand (6 ))
333
338
return seen[idx] = true ;
334
339
// Need shadow buffer in forward pass
335
340
if (mode != DerivativeMode::ReverseModeGradient)
336
341
if (inst == CI->getArgOperand (0 ))
337
342
return seen[idx] = true ;
338
- continue ;
343
+ goto endShadow ;
339
344
}
340
345
if (funcName == " MPI_Isend" || funcName == " PMPI_Isend" ) {
341
346
if (gutils->isConstantInstruction (const_cast <Instruction *>(user)))
342
- continue ;
347
+ goto endShadow ;
343
348
// Need shadow request
344
349
if (inst == CI->getArgOperand (6 ))
345
350
return seen[idx] = true ;
346
351
// Need shadow buffer in reverse pass or forward mode
347
352
if (inst == CI->getArgOperand (0 ))
348
353
return seen[idx] = true ;
349
- continue ;
354
+ goto endShadow ;
350
355
}
351
356
352
357
// Don't need shadow of anything (all via cache for reverse),
353
358
// but need shadow of request for primal.
354
359
if (funcName == " MPI_Wait" || funcName == " PMPI_Wait" ) {
355
360
if (gutils->isConstantInstruction (const_cast <Instruction *>(user)))
356
- continue ;
361
+ goto endShadow ;
357
362
// Need shadow request in forward pass only
358
363
if (mode != DerivativeMode::ReverseModeGradient)
359
364
if (inst == CI->getArgOperand (0 ))
360
365
return seen[idx] = true ;
361
- continue ;
366
+ goto endShadow ;
362
367
}
363
368
364
369
// Don't need shadow of anything (all via cache for reverse),
365
370
// but need shadow of request for primal.
366
371
if (funcName == " MPI_Waitall" || funcName == " PMPI_Waitall" ) {
367
372
if (gutils->isConstantInstruction (const_cast <Instruction *>(user)))
368
- continue ;
373
+ goto endShadow ;
369
374
// Need shadow request in forward pass
370
375
if (mode != DerivativeMode::ReverseModeGradient)
371
376
if (inst == CI->getArgOperand (1 ))
372
377
return seen[idx] = true ;
373
- continue ;
378
+ goto endShadow ;
374
379
}
375
380
376
381
// Use in a write barrier requires the shadow in the forward, even
@@ -398,7 +403,7 @@ static inline bool is_value_needed_in_reverse(
398
403
gutils->ATA ->ActiveReturns == DIFFE_TYPE::DUP_NONEED)
399
404
return seen[idx] = true ;
400
405
else
401
- continue ;
406
+ goto endShadow ;
402
407
}
403
408
404
409
// Assume active instructions require the operand.
@@ -411,18 +416,20 @@ static inline bool is_value_needed_in_reverse(
411
416
// in the forward pass, for example double* x = load double** y
412
417
// is a constant instruction, but needed in the forward
413
418
if (user->getType ()->isVoidTy ())
414
- continue ;
419
+ goto endShadow ;
415
420
416
421
if (!TR.query (const_cast <Instruction *>(user))
417
422
.Inner0 ()
418
423
.isPossiblePointer ())
419
- continue ;
424
+ goto endShadow ;
420
425
421
426
if (!OneLevel && is_value_needed_in_reverse<ValueType::Shadow>(
422
427
gutils, user, mode, seen, oldUnreachable)) {
423
428
return seen[idx] = true ;
424
429
}
425
- continue ;
430
+ endShadow:
431
+ if (VT != ValueType::Primal)
432
+ continue ;
426
433
}
427
434
428
435
assert (VT == ValueType::Primal);
0 commit comments