Skip to content

Commit 3aef65b

Browse files
authored
Handle cbrt (rust-lang#235)
1 parent e5d4699 commit 3aef65b

File tree

3 files changed

+65
-5
lines changed

3 files changed

+65
-5
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3845,6 +3845,38 @@ class AdjointGenerator
38453845
return;
38463846
}
38473847

3848+
if (funcName == "cbrt") {
3849+
if (gutils->knownRecomputeHeuristic.find(orig) !=
3850+
gutils->knownRecomputeHeuristic.end()) {
3851+
if (!gutils->knownRecomputeHeuristic[orig]) {
3852+
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
3853+
getIndex(orig, CacheType::Self));
3854+
}
3855+
}
3856+
eraseIfUnused(*orig);
3857+
if (Mode == DerivativeMode::ReverseModePrimal ||
3858+
gutils->isConstantInstruction(orig))
3859+
return;
3860+
3861+
IRBuilder<> Builder2(call.getParent());
3862+
getReverseBuilder(Builder2);
3863+
Value *x = lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)),
3864+
Builder2);
3865+
Value *args[] = {x};
3866+
#if LLVM_VERSION_MAJOR >= 11
3867+
auto callval = orig->getCalledOperand();
3868+
#else
3869+
auto callval = orig->getCalledValue();
3870+
#endif
3871+
Value *dif0 = Builder2.CreateFDiv(
3872+
Builder2.CreateFMul(diffe(orig, Builder2), x),
3873+
Builder2.CreateFMul(
3874+
ConstantFP::get(x->getType(), 3),
3875+
Builder2.CreateCall(orig->getFunctionType(), callval, args)));
3876+
addToDiffe(orig->getArgOperand(0), dif0, Builder2, x->getType());
3877+
return;
3878+
}
3879+
38483880
if (funcName == "tanhf" || funcName == "tanh") {
38493881
if (gutils->knownRecomputeHeuristic.find(orig) !=
38503882
gutils->knownRecomputeHeuristic.end()) {
@@ -4201,9 +4233,9 @@ class AdjointGenerator
42014233
getIndex(orig, CacheType::Self));
42024234
}
42034235
}
4236+
eraseIfUnused(*orig);
42044237
if (Mode == DerivativeMode::ReverseModePrimal ||
42054238
gutils->isConstantInstruction(orig)) {
4206-
eraseIfUnused(*orig);
42074239
return;
42084240
}
42094241

@@ -4244,9 +4276,9 @@ class AdjointGenerator
42444276
getIndex(orig, CacheType::Self));
42454277
}
42464278
}
4279+
eraseIfUnused(*orig);
42474280
if (Mode == DerivativeMode::ReverseModePrimal ||
42484281
gutils->isConstantInstruction(orig)) {
4249-
eraseIfUnused(*orig);
42504282
return;
42514283
}
42524284

@@ -4288,9 +4320,9 @@ class AdjointGenerator
42884320
getIndex(orig, CacheType::Self));
42894321
}
42904322
}
4323+
eraseIfUnused(*orig);
42914324
if (Mode == DerivativeMode::ReverseModePrimal ||
42924325
gutils->isConstantInstruction(orig)) {
4293-
eraseIfUnused(*orig);
42944326
return;
42954327
}
42964328

enzyme/test/Enzyme/ReverseMode/cabs.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,11 @@ declare double @__enzyme_autodiff(double (double, double)*, ...)
2020

2121
; CHECK: define internal { double, double } @diffetester(double %x, double %y, double %differeturn) {
2222
; CHECK-NEXT: entry:
23-
; CHECK-NEXT: %call = call double @cabs(double %x, double %y)
2423
; CHECK-NEXT: %0 = call fast double @cabs(double %x, double %y)
2524
; CHECK-NEXT: %1 = fdiv fast double %differeturn, %0
2625
; CHECK-NEXT: %2 = fmul fast double %x, %1
2726
; CHECK-NEXT: %3 = fmul fast double %y, %1
2827
; CHECK-NEXT: %4 = insertvalue { double, double } undef, double %2, 0
2928
; CHECK-NEXT: %5 = insertvalue { double, double } %4, double %3, 1
3029
; CHECK-NEXT: ret { double, double } %5
31-
; CHECK-NEXT: }
30+
; CHECK-NEXT: }
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s
2+
3+
; Function Attrs: nounwind readnone uwtable
4+
define double @tester(double %x) {
5+
entry:
6+
%call = call double @cbrt(double %x)
7+
ret double %call
8+
}
9+
10+
define double @test_derivative(double %x) {
11+
entry:
12+
%0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x)
13+
ret double %0
14+
}
15+
16+
declare double @cbrt(double)
17+
18+
; Function Attrs: nounwind
19+
declare double @__enzyme_autodiff(double (double)*, ...)
20+
21+
; CHECK: define internal { double } @diffetester(double %x, double %differeturn) {
22+
; CHECK-NEXT: entry:
23+
; CHECK-NEXT: %0 = call fast double @cbrt(double %x)
24+
; CHECK-NEXT: %1 = fmul fast double 3.000000e+00, %0
25+
; CHECK-NEXT: %2 = fmul fast double %differeturn, %x
26+
; CHECK-NEXT: %3 = fdiv fast double %2, %1
27+
; CHECK-NEXT: %4 = insertvalue { double } undef, double %3, 0
28+
; CHECK-NEXT: ret { double } %4
29+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)