Skip to content

Commit dfadccf

Browse files
authored
Primal activity analysis fix (rust-lang#678)
* Primal activity analysis fix * bugfix * Fix tests
1 parent 85524a3 commit dfadccf

File tree

1 file changed

+25
-18
lines changed

1 file changed

+25
-18
lines changed

enzyme/Enzyme/DifferentialUseAnalysis.h

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ static inline bool is_use_directly_needed_in_reverse(
9292
}
9393
if (MTI->getArgOperand(2) != val)
9494
return false;
95+
return !gutils->isConstantInstruction(MTI);
9596
}
9697

9798
// Preserve the length of memsets of backward creation shadows
@@ -264,8 +265,12 @@ static inline bool is_value_needed_in_reverse(
264265
const Instruction *user = dyn_cast<Instruction>(use);
265266

266267
// A shadow value is only needed in reverse if it or one of its descendants
267-
// is used in an active instruction
268-
if (VT == ValueType::Shadow) {
268+
// is used in an active instruction.
269+
// If inst is a constant value, the primal may be used in its place and
270+
// thus required.
271+
if (VT == ValueType::Shadow ||
272+
(gutils->isConstantValue(const_cast<Value *>(inst)) &&
273+
!TR.query(const_cast<Value *>(inst))[{-1}].isFloat())) {
269274
if (!user)
270275
return seen[idx] = true;
271276

@@ -286,25 +291,25 @@ static inline bool is_value_needed_in_reverse(
286291
break;
287292
}
288293
if (!rematerialized)
289-
continue;
294+
goto endShadow;
290295
}
291296

292297
if (!gutils->isConstantValue(
293298
const_cast<Value *>(SI->getPointerOperand())))
294299
return seen[idx] = true;
295300
else
296-
continue;
301+
goto endShadow;
297302
}
298303

299304
if (auto MTI = dyn_cast<MemTransferInst>(user)) {
300305
if (MTI->getArgOperand(0) != inst && MTI->getArgOperand(1) != inst)
301-
continue;
306+
goto endShadow;
302307

303308
if (!gutils->isConstantValue(
304309
const_cast<Value *>(MTI->getArgOperand(0))))
305310
return seen[idx] = true;
306311
else
307-
continue;
312+
goto endShadow;
308313
}
309314

310315
if (auto CI = dyn_cast<CallInst>(user)) {
@@ -327,50 +332,50 @@ static inline bool is_value_needed_in_reverse(
327332
// Only need shadow request for reverse
328333
if (funcName == "MPI_Irecv" || funcName == "PMPI_Irecv") {
329334
if (gutils->isConstantInstruction(const_cast<Instruction *>(user)))
330-
continue;
335+
goto endShadow;
331336
// Need shadow request
332337
if (inst == CI->getArgOperand(6))
333338
return seen[idx] = true;
334339
// Need shadow buffer in forward pass
335340
if (mode != DerivativeMode::ReverseModeGradient)
336341
if (inst == CI->getArgOperand(0))
337342
return seen[idx] = true;
338-
continue;
343+
goto endShadow;
339344
}
340345
if (funcName == "MPI_Isend" || funcName == "PMPI_Isend") {
341346
if (gutils->isConstantInstruction(const_cast<Instruction *>(user)))
342-
continue;
347+
goto endShadow;
343348
// Need shadow request
344349
if (inst == CI->getArgOperand(6))
345350
return seen[idx] = true;
346351
// Need shadow buffer in reverse pass or forward mode
347352
if (inst == CI->getArgOperand(0))
348353
return seen[idx] = true;
349-
continue;
354+
goto endShadow;
350355
}
351356

352357
// Don't need shadow of anything (all via cache for reverse),
353358
// but need shadow of request for primal.
354359
if (funcName == "MPI_Wait" || funcName == "PMPI_Wait") {
355360
if (gutils->isConstantInstruction(const_cast<Instruction *>(user)))
356-
continue;
361+
goto endShadow;
357362
// Need shadow request in forward pass only
358363
if (mode != DerivativeMode::ReverseModeGradient)
359364
if (inst == CI->getArgOperand(0))
360365
return seen[idx] = true;
361-
continue;
366+
goto endShadow;
362367
}
363368

364369
// Don't need shadow of anything (all via cache for reverse),
365370
// but need shadow of request for primal.
366371
if (funcName == "MPI_Waitall" || funcName == "PMPI_Waitall") {
367372
if (gutils->isConstantInstruction(const_cast<Instruction *>(user)))
368-
continue;
373+
goto endShadow;
369374
// Need shadow request in forward pass
370375
if (mode != DerivativeMode::ReverseModeGradient)
371376
if (inst == CI->getArgOperand(1))
372377
return seen[idx] = true;
373-
continue;
378+
goto endShadow;
374379
}
375380

376381
// Use in a write barrier requires the shadow in the forward, even
@@ -398,7 +403,7 @@ static inline bool is_value_needed_in_reverse(
398403
gutils->ATA->ActiveReturns == DIFFE_TYPE::DUP_NONEED)
399404
return seen[idx] = true;
400405
else
401-
continue;
406+
goto endShadow;
402407
}
403408

404409
// Assume active instructions require the operand.
@@ -411,18 +416,20 @@ static inline bool is_value_needed_in_reverse(
411416
// in the forward pass, for example double* x = load double** y
412417
// is a constant instruction, but needed in the forward
413418
if (user->getType()->isVoidTy())
414-
continue;
419+
goto endShadow;
415420

416421
if (!TR.query(const_cast<Instruction *>(user))
417422
.Inner0()
418423
.isPossiblePointer())
419-
continue;
424+
goto endShadow;
420425

421426
if (!OneLevel && is_value_needed_in_reverse<ValueType::Shadow>(
422427
gutils, user, mode, seen, oldUnreachable)) {
423428
return seen[idx] = true;
424429
}
425-
continue;
430+
endShadow:
431+
if (VT != ValueType::Primal)
432+
continue;
426433
}
427434

428435
assert(VT == ValueType::Primal);

0 commit comments

Comments
 (0)