Skip to content

Commit f47e913

Browse files
authored
Correctly use original function in gradient for pow(i) (rust-lang#483)
1 parent 83628bf commit f47e913

File tree

1 file changed

+33
-41
lines changed

1 file changed

+33
-41
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 33 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3261,28 +3261,17 @@ class AdjointGenerator
32613261
lookup(op0, Builder2),
32623262
Builder2.CreateSub(lookup(op1, Builder2),
32633263
ConstantInt::get(op1->getType(), 1))};
3264-
Type *tys[] = {
3265-
orig_ops[0]->getType()
3266-
#if LLVM_VERSION_MAJOR >= 13
3267-
,
3268-
orig_ops[1]->getType()
3269-
#endif
3270-
};
32713264
auto &CI = cast<CallInst>(I);
32723265
#if LLVM_VERSION_MAJOR >= 11
32733266
auto *PowF = CI.getCalledOperand();
32743267
#else
32753268
auto *PowF = CI.getCalledValue();
32763269
#endif
3277-
if (!PowF)
3278-
PowF = Intrinsic::getDeclaration(M, Intrinsic::powi, tys);
3279-
3280-
auto FT = FunctionType::get(
3281-
I.getType(), {orig_ops[0]->getType(), orig_ops[1]->getType()},
3282-
false);
3270+
assert(PowF);
3271+
auto FT = cast<FunctionType>(
3272+
cast<PointerType>(PowF->getType())->getElementType());
32833273
auto cal = cast<CallInst>(Builder2.CreateCall(FT, PowF, args));
3284-
if (auto F = dyn_cast<Function>(PowF))
3285-
cal->setCallingConv(F->getCallingConv());
3274+
cal->setCallingConv(CI.getCallingConv());
32863275

32873276
cal->setDebugLoc(gutils->getNewFromOriginal(I.getDebugLoc()));
32883277
Value *dif0 = Builder2.CreateFMul(
@@ -3301,12 +3290,9 @@ class AdjointGenerator
33013290
#else
33023291
auto *PowF = CI.getCalledValue();
33033292
#endif
3304-
if (!PowF)
3305-
PowF = Intrinsic::getDeclaration(M, Intrinsic::pow, tys);
3306-
3307-
auto FT = FunctionType::get(
3308-
I.getType(), {orig_ops[0]->getType(), orig_ops[1]->getType()},
3309-
false);
3293+
assert(PowF);
3294+
auto FT = cast<FunctionType>(
3295+
cast<PointerType>(PowF->getType())->getElementType());
33103296

33113297
if (vdiff && !gutils->isConstantValue(orig_ops[0])) {
33123298

@@ -3324,9 +3310,7 @@ class AdjointGenerator
33243310
Builder2.CreateFSub(lookup(op1, Builder2),
33253311
ConstantFP::get(I.getType(), 1.0))};
33263312
auto cal = cast<CallInst>(Builder2.CreateCall(FT, PowF, args));
3327-
if (auto F = dyn_cast<Function>(PowF))
3328-
cal->setCallingConv(F->getCallingConv());
3329-
3313+
cal->setCallingConv(CI.getCallingConv());
33303314
cal->setDebugLoc(gutils->getNewFromOriginal(I.getDebugLoc()));
33313315

33323316
Value *dif0 = Builder2.CreateFMul(Builder2.CreateFMul(vdiff, cal),
@@ -3343,8 +3327,7 @@ class AdjointGenerator
33433327
lookup(gutils->getNewFromOriginal(orig_ops[1]), Builder2)};
33443328

33453329
cal = cast<CallInst>(Builder2.CreateCall(FT, PowF, args));
3346-
if (auto F = dyn_cast<Function>(PowF))
3347-
cal->setCallingConv(F->getCallingConv());
3330+
cal->setCallingConv(CI.getCallingConv());
33483331

33493332
cal->setDebugLoc(gutils->getNewFromOriginal(I.getDebugLoc()));
33503333
}
@@ -3736,16 +3719,17 @@ class AdjointGenerator
37363719
SmallVector<Value *, 2> args = {
37373720
op0,
37383721
Builder2.CreateSub(op1, ConstantInt::get(op1->getType(), 1))};
3739-
Type *tys[] = {
3740-
orig_ops[0]->getType()
3741-
#if LLVM_VERSION_MAJOR >= 13
3742-
,
3743-
orig_ops[1]->getType()
3722+
auto &CI = cast<CallInst>(I);
3723+
#if LLVM_VERSION_MAJOR >= 11
3724+
auto *PowF = CI.getCalledOperand();
3725+
#else
3726+
auto *PowF = CI.getCalledValue();
37443727
#endif
3745-
};
3746-
Function *PowF = Intrinsic::getDeclaration(M, Intrinsic::powi, tys);
3747-
auto *cal = cast<CallInst>(Builder2.CreateCall(PowF, args));
3748-
cal->setCallingConv(PowF->getCallingConv());
3728+
assert(PowF);
3729+
auto FT = cast<FunctionType>(
3730+
cast<PointerType>(PowF->getType())->getElementType());
3731+
auto cal = cast<CallInst>(Builder2.CreateCall(FT, PowF, args));
3732+
cal->setCallingConv(CI.getCallingConv());
37493733
cal->setDebugLoc(gutils->getNewFromOriginal(I.getDebugLoc()));
37503734

37513735
Value *cast =
@@ -3766,7 +3750,15 @@ class AdjointGenerator
37663750
return;
37673751

37683752
Type *tys[] = {orig_ops[0]->getType()};
3769-
Function *PowF = Intrinsic::getDeclaration(M, Intrinsic::pow, tys);
3753+
auto &CI = cast<CallInst>(I);
3754+
#if LLVM_VERSION_MAJOR >= 11
3755+
auto *PowF = CI.getCalledOperand();
3756+
#else
3757+
auto *PowF = CI.getCalledValue();
3758+
#endif
3759+
assert(PowF);
3760+
auto FT = cast<FunctionType>(
3761+
cast<PointerType>(PowF->getType())->getElementType());
37703762

37713763
Value *op0 = gutils->getNewFromOriginal(orig_ops[0]);
37723764
Value *op1 = gutils->getNewFromOriginal(orig_ops[1]);
@@ -3777,8 +3769,8 @@ class AdjointGenerator
37773769
if (!gutils->isConstantValue(orig_ops[0])) {
37783770
Value *args[2] = {
37793771
op0, Builder2.CreateFSub(op1, ConstantFP::get(I.getType(), 1.0))};
3780-
CallInst *powcall1 = cast<CallInst>(Builder2.CreateCall(PowF, args));
3781-
powcall1->setCallingConv(PowF->getCallingConv());
3772+
auto powcall1 = cast<CallInst>(Builder2.CreateCall(FT, PowF, args));
3773+
powcall1->setCallingConv(CI.getCallingConv());
37823774
powcall1->setDebugLoc(gutils->getNewFromOriginal(I.getDebugLoc()));
37833775

37843776
Value *mul = Builder2.CreateFMul(op1, powcall1);
@@ -3792,9 +3784,9 @@ class AdjointGenerator
37923784
res = applyChainRule(I.getType(), Builder2, rule, op, res);
37933785
}
37943786
if (!gutils->isConstantValue(orig_ops[1])) {
3795-
CallInst *powcall =
3796-
cast<CallInst>(Builder2.CreateCall(PowF, {op0, op1}));
3797-
powcall->setCallingConv(PowF->getCallingConv());
3787+
auto powcall =
3788+
cast<CallInst>(Builder2.CreateCall(FT, PowF, {op0, op1}));
3789+
powcall->setCallingConv(CI.getCallingConv());
37983790
powcall->setDebugLoc(gutils->getNewFromOriginal(I.getDebugLoc()));
37993791

38003792
CallInst *logcall = Builder2.CreateCall(

0 commit comments

Comments
 (0)