Skip to content

Commit 034b9b4

Browse files
authored
Fix sinc related bugs (rust-lang#814)
1 parent 2c667ea commit 034b9b4

File tree

5 files changed

+71
-2
lines changed

5 files changed

+71
-2
lines changed

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2801,8 +2801,9 @@ void createTerminator(DiffeGradientUtils *gutils, BasicBlock *oBB,
28012801
toret =
28022802
nBuilder.CreateInsertValue(toret, gutils->diffe(ret, nBuilder), 1);
28032803
} else {
2804-
toret = nBuilder.CreateInsertValue(
2805-
toret, Constant::getNullValue(ret->getType()), 1);
2804+
Type *retTy = gutils->getShadowType(ret->getType());
2805+
toret =
2806+
nBuilder.CreateInsertValue(toret, Constant::getNullValue(retTy), 1);
28062807
}
28072808
break;
28082809
}

enzyme/Enzyme/InstructionDerivatives.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,7 @@ def : CallPattern<(Op $x),
142142
["expm1"],
143143
[(FMul (Intrinsic<"exp", [(TypeOf<""> $x)]> $x), (DiffeRet<"">))]
144144
>;
145+
146+
def : CallPattern<(Op $x),
147+
["sinc", "sincf", "sincl"],
148+
[(FMul (DiffeRet<"">), (FDiv (FSub (Intrinsic<"cos", [(TypeOf<""> $x)]> $x), (Call<(SameFunc), [ReadNone,NoUnwind]> $x)), $x))]>;

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ llvm::cl::opt<bool> EnzymeStrictAliasing(
7373
}
7474

7575
const std::map<std::string, llvm::Intrinsic::ID> LIBM_FUNCTIONS = {
76+
{"sinc", Intrinsic::not_intrinsic},
7677
{"cos", Intrinsic::cos},
7778
{"sin", Intrinsic::sin},
7879
{"tan", Intrinsic::not_intrinsic},
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -instsimplify -S | FileCheck %s
2+
3+
%struct.Gradients = type { double, double }
4+
5+
define double @f(double %x, i1 %c) {
6+
entry:
7+
%v = select i1 %c, double 0.000000e+00, double 1.000000e+00
8+
ret double %v
9+
}
10+
11+
define double @tester(double %x, double %y) {
12+
entry:
13+
%c = call double @f(double %x, i1 true)
14+
%mul = fmul double %c, %y
15+
ret double %mul
16+
}
17+
18+
define %struct.Gradients @test_derivative(double %x, double %y){
19+
entry:
20+
%call = call %struct.Gradients (double (double, double)*, ...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, metadata !"enzyme_width", i64 2, double %x, double 1.000000e+00, double 0.000000e+00, double %y, double 0.000000e+00, double 1.000000e+00)
21+
ret %struct.Gradients %call
22+
}
23+
24+
declare %struct.Gradients @__enzyme_fwddiff(double (double, double)*, ...)
25+
26+
; CHECK: define internal { double, [2 x double] } @fwddiffe2f(double %x, [2 x double] %"x'", i1 %c)
27+
; CHECK-NEXT: entry:
28+
; CHECK-NEXT: %v = select i1 %c, double 0.000000e+00, double 1.000000e+00
29+
; CHECK-NEXT: %0 = insertvalue { double, [2 x double] } undef, double %v, 0
30+
; CHECK-NEXT: %1 = insertvalue { double, [2 x double] } %0, [2 x double] zeroinitializer, 1
31+
; CHECK-NEXT: ret { double, [2 x double] } %1
32+
; CHECK-NEXT: }
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s
2+
3+
; Function Attrs: nounwind readnone uwtable
4+
define double @tester(double %x) {
5+
entry:
6+
%0 = tail call fast double @sinc(double %x)
7+
ret double %0
8+
}
9+
10+
define double @test_derivative(double %x) {
11+
entry:
12+
%0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x)
13+
ret double %0
14+
}
15+
16+
; Function Attrs: nounwind readnone speculatable
17+
declare double @sinc(double)
18+
19+
; Function Attrs: nounwind
20+
declare double @__enzyme_autodiff(double (double)*, ...)
21+
22+
; CHECK: define internal { double } @diffetester(double %x, double %differeturn)
23+
; CHECK-NEXT: entry:
24+
; CHECK-NEXT: %0 = call fast double @llvm.cos.f64(double %x)
25+
; CHECK-NEXT: %1 = call fast double @sinc(double %x)
26+
; CHECK-NEXT: %2 = fsub fast double %0, %1
27+
; CHECK-NEXT: %3 = fdiv fast double %2, %x
28+
; CHECK-NEXT: %4 = fmul fast double %differeturn, %3
29+
; CHECK-NEXT: %5 = insertvalue { double } undef, double %4, 0
30+
; CHECK-NEXT: ret { double } %5
31+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)