Skip to content

Commit 25bf341

Browse files
committed
Generalize call cache behavior
1 parent 035b343 commit 25bf341

File tree

3 files changed

+9
-19
lines changed

3 files changed

+9
-19
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6487,7 +6487,8 @@ class AdjointGenerator
64876487
// may load uncacheable data)
64886488
// Store and reload it
64896489
if (Mode != DerivativeMode::ReverseModeCombined && subretused &&
6490-
!orig->doesNotAccessMemory()) {
6490+
(orig->mayWriteToMemory() ||
6491+
!gutils->legalRecompute(orig, ValueToValueMapTy(), nullptr))) {
64916492
if (!gutils->unnecessaryIntermediates.count(orig)) {
64926493
gutils->cacheForReverse(BuilderZ, newCall,
64936494
getIndex(orig, CacheType::Self));

enzyme/Enzyme/Enzyme.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -816,10 +816,13 @@ class Enzyme : public ModulePass {
816816
if (Fn->getName().contains("__enzyme_call_inactive")) {
817817
InactiveCalls.insert(CI);
818818
}
819-
if (F.getName() == "omp_get_max_threads" ||
820-
F.getName() == "omp_get_thread_num") {
821-
F.addFnAttr(Attribute::ReadOnly);
822-
F.addFnAttr(Attribute::InaccessibleMemOnly);
819+
if (Fn->getName() == "omp_get_max_threads" ||
820+
Fn->getName() == "omp_get_thread_num") {
821+
Fn->addFnAttr(Attribute::ReadOnly);
822+
CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadOnly);
823+
Fn->addFnAttr(Attribute::InaccessibleMemOnly);
824+
CI->addAttribute(AttributeList::FunctionIndex,
825+
Attribute::InaccessibleMemOnly);
823826
}
824827
if ((Fn->getName() == "cblas_ddot" || Fn->getName() == "cblas_sdot") &&
825828
Fn->isDeclaration()) {

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -195,20 +195,6 @@ struct CacheAnalysis {
195195
// Pointer operands originating from call instructions that are not
196196
// malloc/free are conservatively considered uncacheable.
197197
if (auto obj_op = dyn_cast<CallInst>(obj)) {
198-
Function *called = obj_op->getCalledFunction();
199-
#if LLVM_VERSION_MAJOR >= 11
200-
if (auto castinst = dyn_cast<ConstantExpr>(obj_op->getCalledOperand()))
201-
#else
202-
if (auto castinst = dyn_cast<ConstantExpr>(obj_op->getCalledValue()))
203-
#endif
204-
{
205-
if (castinst->isCast()) {
206-
if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
207-
called = fn;
208-
}
209-
}
210-
}
211-
212198
// If this is a known allocation which is not captured or returned,
213199
// a caller function cannot overwrite this (since it cannot access).
214200
// Since we don't currently perform this check, we can instead check

0 commit comments

Comments
 (0)