Skip to content

Commit 950150c

Browse files
committed
Custom Enzyme inactive marker
1 parent 0447c03 commit 950150c

File tree

3 files changed

+65
-0
lines changed

3 files changed

+65
-0
lines changed

Diff for: enzyme/Enzyme/ActivityAnalysis.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,11 @@ const char *KnownInactiveFunctions[] = {"__assert_fail",
143143
/// This tool can only be used when in DOWN mode
144144
bool ActivityAnalyzer::isFunctionArgumentConstant(CallInst *CI, Value *val) {
145145
assert(directions & DOWN);
146+
147+
if (CI->hasFnAttr("enzyme_inactive")) {
148+
return true;
149+
}
150+
146151
Function *F = CI->getCalledFunction();
147152

148153
// Indirect function calls may actively use the argument
@@ -1133,11 +1138,15 @@ bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults &TR,
11331138

11341139
// Calls to print/assert/cxa guard are definitionally inactive
11351140
if (auto op = dyn_cast<CallInst>(inst)) {
1141+
if (op->hasFnAttr("enzyme_inactive")) {
1142+
return true;
1143+
}
11361144
if (auto called = op->getCalledFunction()) {
11371145
if (called->getName() == "free" || called->getName() == "_ZdlPv" ||
11381146
called->getName() == "_ZdlPvm" || called->getName() == "munmap") {
11391147
return true;
11401148
}
1149+
11411150
for (auto FuncName : KnownInactiveFunctionsStartingWith) {
11421151
if (called->getName().startswith(FuncName)) {
11431152
return true;

Diff for: enzyme/test/Enzyme/enzyme_inactive.ll

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s
2+
3+
; Function Attrs: noinline nounwind readnone uwtable
4+
define double @tester(double %x) {
5+
entry:
6+
tail call void @myprint(double %x)
7+
ret double %x
8+
}
9+
10+
define double @test_derivative(double %x) {
11+
entry:
12+
%0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x)
13+
ret double %0
14+
}
15+
16+
declare void @myprint(double %x) #0
17+
18+
; Function Attrs: nounwind
19+
declare double @__enzyme_autodiff(double (double)*, ...)
20+
21+
attributes #0 = { "enzyme_inactive" }
22+
23+
; CHECK: define internal {{(dso_local )?}}{ double } @diffetester(double %x, double %[[differet:.+]])
24+
; CHECK-NEXT: entry:
25+
; CHECK-NEXT: tail call void @myprint(double %x)
26+
; CHECK-NEXT: %0 = insertvalue { double } undef, double %differeturn, 0
27+
; CHECK-NEXT: ret { double } %0
28+
; CHECK-NEXT: }

Diff for: enzyme/test/Enzyme/enzyme_inactive2.ll

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s
2+
3+
; Function Attrs: noinline nounwind readnone uwtable
4+
define double @tester(double %x) {
5+
entry:
6+
tail call void @myprint(double %x) #0
7+
ret double %x
8+
}
9+
10+
define double @test_derivative(double %x) {
11+
entry:
12+
%0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x)
13+
ret double %0
14+
}
15+
16+
declare void @myprint(double %x)
17+
18+
; Function Attrs: nounwind
19+
declare double @__enzyme_autodiff(double (double)*, ...)
20+
21+
attributes #0 = { "enzyme_inactive" }
22+
23+
; CHECK: define internal {{(dso_local )?}}{ double } @diffetester(double %x, double %[[differet:.+]])
24+
; CHECK-NEXT: entry:
25+
; CHECK-NEXT: tail call void @myprint(double %x)
26+
; CHECK-NEXT: %0 = insertvalue { double } undef, double %differeturn, 0
27+
; CHECK-NEXT: ret { double } %0
28+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)