Skip to content

Commit 2184096

Browse files
authored
Use provided sqrt (rust-lang#533)
1 parent 26d31e0 commit 2184096

File tree

3 files changed

+24 additions, -17 deletions (lines changed)

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 22 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -3129,15 +3129,19 @@ class AdjointGenerator
31293129
if (vdiff && !gutils->isConstantValue(orig_ops[0])) {
31303130
SmallVector<Value *, 2> args = {
31313131
lookup(gutils->getNewFromOriginal(orig_ops[0]), Builder2)};
3132-
Type *tys[] = {orig_ops[0]->getType()};
3133-
Function *SqrtF;
3134-
if (ID == Intrinsic::sqrt)
3135-
SqrtF = Intrinsic::getDeclaration(M, ID, tys);
3136-
else
3137-
SqrtF = Intrinsic::getDeclaration(M, ID);
31383132

3139-
auto cal = cast<CallInst>(Builder2.CreateCall(SqrtF, args));
3140-
cal->setCallingConv(SqrtF->getCallingConv());
3133+
auto &CI = cast<CallInst>(I);
3134+
#if LLVM_VERSION_MAJOR >= 11
3135+
auto *SqrtF = CI.getCalledOperand();
3136+
#else
3137+
auto *SqrtF = CI.getCalledValue();
3138+
#endif
3139+
assert(SqrtF);
3140+
auto FT =
3141+
cast<FunctionType>(SqrtF->getType()->getPointerElementType());
3142+
3143+
auto cal = cast<CallInst>(Builder2.CreateCall(FT, SqrtF, args));
3144+
cal->setCallingConv(CI.getCallingConv());
31413145
cal->setDebugLoc(gutils->getNewFromOriginal(I.getDebugLoc()));
31423146

31433147
Value *dif0 = Builder2.CreateBinOp(
@@ -3529,15 +3533,18 @@ class AdjointGenerator
35293533
Value *args[1] = {gutils->getNewFromOriginal(orig_ops[0])};
35303534
Type *tys[] = {orig_ops[0]->getType()};
35313535

3532-
Function *SqrtF;
3533-
if (ID == Intrinsic::sqrt)
3534-
SqrtF = Intrinsic::getDeclaration(M, ID, tys);
3535-
else
3536-
SqrtF = Intrinsic::getDeclaration(M, ID);
3536+
auto &CI = cast<CallInst>(I);
3537+
#if LLVM_VERSION_MAJOR >= 11
3538+
auto *SqrtF = CI.getCalledOperand();
3539+
#else
3540+
auto *SqrtF = CI.getCalledValue();
3541+
#endif
3542+
assert(SqrtF);
3543+
auto FT = cast<FunctionType>(SqrtF->getType()->getPointerElementType());
35373544

35383545
auto rule = [&](Value *op) {
3539-
CallInst *cal = cast<CallInst>(Builder2.CreateCall(SqrtF, args));
3540-
cal->setCallingConv(SqrtF->getCallingConv());
3546+
CallInst *cal = cast<CallInst>(Builder2.CreateCall(FT, SqrtF, args));
3547+
cal->setCallingConv(CI.getCallingConv());
35413548
cal->setDebugLoc(gutils->getNewFromOriginal(I.getDebugLoc()));
35423549

35433550
Value *half = ConstantFP::get(orig_ops[0]->getType(), 0.5);

enzyme/test/Enzyme/ReverseMode/ompsqloop.ll

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -210,7 +210,7 @@ attributes #1 = { argmemonly }
210210
; CHECK-NEXT: %[[i9:.+]] = add nuw nsw i64 %"iv'ac.0", %_unwrap2
211211
; CHECK-NEXT: %[[i10:.+]] = getelementptr inbounds double, double* %truetape, i64 %[[i9]]
212212
; CHECK-NEXT: %[[i11:.+]] = load double, double* %[[i10]], align 8, !tbaa !9, !invariant.group !
213-
; CHECK-NEXT: %[[i12:.+]] = call fast double @llvm.sqrt.f64(double %[[i11]])
213+
; CHECK-NEXT: %[[i12:.+]] = call fast double @sqrt(double %[[i11]])
214214
; CHECK-NEXT: %[[i13:.+]] = fmul fast double 5.000000e-01, %[[i8]]
215215
; CHECK-NEXT: %[[i14:.+]] = fdiv fast double %[[i13]], %[[i12]]
216216
; CHECK-NEXT: %[[i15:.+]] = fcmp fast oeq double %[[i11]], 0.000000e+00

enzyme/test/Enzyme/ReverseMode/ompsqloopoutofplace.ll

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -155,7 +155,7 @@ attributes #1 = { argmemonly }
155155
; CHECK-NEXT: store double 0.000000e+00, double* %"outidx'ipg_unwrap", align 8
156156
; CHECK-NEXT: %arrayidx_unwrap = getelementptr inbounds double, double* %tmp, i64 %_unwrap3
157157
; CHECK-NEXT: %_unwrap4 = load double, double* %arrayidx_unwrap, align 8, !tbaa !9, !invariant.group !16
158-
; CHECK-NEXT: %2 = call fast double @llvm.sqrt.f64(double %_unwrap4)
158+
; CHECK-NEXT: %2 = call fast double @sqrt(double %_unwrap4)
159159
; CHECK-NEXT: %3 = fmul fast double 5.000000e-01, %1
160160
; CHECK-NEXT: %4 = fdiv fast double %3, %2
161161
; CHECK-NEXT: %5 = fcmp fast oeq double %_unwrap4, 0.000000e+00

0 commit comments

Comments
 (0)