@@ -3261,28 +3261,17 @@ class AdjointGenerator
3261
3261
lookup (op0, Builder2),
3262
3262
Builder2.CreateSub (lookup (op1, Builder2),
3263
3263
ConstantInt::get (op1->getType (), 1 ))};
3264
- Type *tys[] = {
3265
- orig_ops[0 ]->getType ()
3266
- #if LLVM_VERSION_MAJOR >= 13
3267
- ,
3268
- orig_ops[1 ]->getType ()
3269
- #endif
3270
- };
3271
3264
auto &CI = cast<CallInst>(I);
3272
3265
#if LLVM_VERSION_MAJOR >= 11
3273
3266
auto *PowF = CI.getCalledOperand ();
3274
3267
#else
3275
3268
auto *PowF = CI.getCalledValue ();
3276
3269
#endif
3277
- if (!PowF)
3278
- PowF = Intrinsic::getDeclaration (M, Intrinsic::powi, tys);
3279
-
3280
- auto FT = FunctionType::get (
3281
- I.getType (), {orig_ops[0 ]->getType (), orig_ops[1 ]->getType ()},
3282
- false );
3270
+ assert (PowF);
3271
+ auto FT = cast<FunctionType>(
3272
+ cast<PointerType>(PowF->getType ())->getElementType ());
3283
3273
auto cal = cast<CallInst>(Builder2.CreateCall (FT, PowF, args));
3284
- if (auto F = dyn_cast<Function>(PowF))
3285
- cal->setCallingConv (F->getCallingConv ());
3274
+ cal->setCallingConv (CI.getCallingConv ());
3286
3275
3287
3276
cal->setDebugLoc (gutils->getNewFromOriginal (I.getDebugLoc ()));
3288
3277
Value *dif0 = Builder2.CreateFMul (
@@ -3301,12 +3290,9 @@ class AdjointGenerator
3301
3290
#else
3302
3291
auto *PowF = CI.getCalledValue ();
3303
3292
#endif
3304
- if (!PowF)
3305
- PowF = Intrinsic::getDeclaration (M, Intrinsic::pow, tys);
3306
-
3307
- auto FT = FunctionType::get (
3308
- I.getType (), {orig_ops[0 ]->getType (), orig_ops[1 ]->getType ()},
3309
- false );
3293
+ assert (PowF);
3294
+ auto FT = cast<FunctionType>(
3295
+ cast<PointerType>(PowF->getType ())->getElementType ());
3310
3296
3311
3297
if (vdiff && !gutils->isConstantValue (orig_ops[0 ])) {
3312
3298
@@ -3324,9 +3310,7 @@ class AdjointGenerator
3324
3310
Builder2.CreateFSub (lookup (op1, Builder2),
3325
3311
ConstantFP::get (I.getType (), 1.0 ))};
3326
3312
auto cal = cast<CallInst>(Builder2.CreateCall (FT, PowF, args));
3327
- if (auto F = dyn_cast<Function>(PowF))
3328
- cal->setCallingConv (F->getCallingConv ());
3329
-
3313
+ cal->setCallingConv (CI.getCallingConv ());
3330
3314
cal->setDebugLoc (gutils->getNewFromOriginal (I.getDebugLoc ()));
3331
3315
3332
3316
Value *dif0 = Builder2.CreateFMul (Builder2.CreateFMul (vdiff, cal),
@@ -3343,8 +3327,7 @@ class AdjointGenerator
3343
3327
lookup (gutils->getNewFromOriginal (orig_ops[1 ]), Builder2)};
3344
3328
3345
3329
cal = cast<CallInst>(Builder2.CreateCall (FT, PowF, args));
3346
- if (auto F = dyn_cast<Function>(PowF))
3347
- cal->setCallingConv (F->getCallingConv ());
3330
+ cal->setCallingConv (CI.getCallingConv ());
3348
3331
3349
3332
cal->setDebugLoc (gutils->getNewFromOriginal (I.getDebugLoc ()));
3350
3333
}
@@ -3736,16 +3719,17 @@ class AdjointGenerator
3736
3719
SmallVector<Value *, 2 > args = {
3737
3720
op0,
3738
3721
Builder2.CreateSub (op1, ConstantInt::get (op1->getType (), 1 ))};
3739
- Type *tys[] = {
3740
- orig_ops[ 0 ]-> getType ()
3741
- # if LLVM_VERSION_MAJOR >= 13
3742
- ,
3743
- orig_ops[ 1 ]-> getType ()
3722
+ auto &CI = cast<CallInst>(I);
3723
+ # if LLVM_VERSION_MAJOR >= 11
3724
+ auto *PowF = CI. getCalledOperand ();
3725
+ # else
3726
+ auto *PowF = CI. getCalledValue ();
3744
3727
#endif
3745
- };
3746
- Function *PowF = Intrinsic::getDeclaration (M, Intrinsic::powi, tys);
3747
- auto *cal = cast<CallInst>(Builder2.CreateCall (PowF, args));
3748
- cal->setCallingConv (PowF->getCallingConv ());
3728
+ assert (PowF);
3729
+ auto FT = cast<FunctionType>(
3730
+ cast<PointerType>(PowF->getType ())->getElementType ());
3731
+ auto cal = cast<CallInst>(Builder2.CreateCall (FT, PowF, args));
3732
+ cal->setCallingConv (CI.getCallingConv ());
3749
3733
cal->setDebugLoc (gutils->getNewFromOriginal (I.getDebugLoc ()));
3750
3734
3751
3735
Value *cast =
@@ -3766,7 +3750,15 @@ class AdjointGenerator
3766
3750
return ;
3767
3751
3768
3752
Type *tys[] = {orig_ops[0 ]->getType ()};
3769
- Function *PowF = Intrinsic::getDeclaration (M, Intrinsic::pow, tys);
3753
+ auto &CI = cast<CallInst>(I);
3754
+ #if LLVM_VERSION_MAJOR >= 11
3755
+ auto *PowF = CI.getCalledOperand ();
3756
+ #else
3757
+ auto *PowF = CI.getCalledValue ();
3758
+ #endif
3759
+ assert (PowF);
3760
+ auto FT = cast<FunctionType>(
3761
+ cast<PointerType>(PowF->getType ())->getElementType ());
3770
3762
3771
3763
Value *op0 = gutils->getNewFromOriginal (orig_ops[0 ]);
3772
3764
Value *op1 = gutils->getNewFromOriginal (orig_ops[1 ]);
@@ -3777,8 +3769,8 @@ class AdjointGenerator
3777
3769
if (!gutils->isConstantValue (orig_ops[0 ])) {
3778
3770
Value *args[2 ] = {
3779
3771
op0, Builder2.CreateFSub (op1, ConstantFP::get (I.getType (), 1.0 ))};
3780
- CallInst * powcall1 = cast<CallInst>(Builder2.CreateCall (PowF, args));
3781
- powcall1->setCallingConv (PowF-> getCallingConv ());
3772
+ auto powcall1 = cast<CallInst>(Builder2.CreateCall (FT, PowF, args));
3773
+ powcall1->setCallingConv (CI. getCallingConv ());
3782
3774
powcall1->setDebugLoc (gutils->getNewFromOriginal (I.getDebugLoc ()));
3783
3775
3784
3776
Value *mul = Builder2.CreateFMul (op1, powcall1);
@@ -3792,9 +3784,9 @@ class AdjointGenerator
3792
3784
res = applyChainRule (I.getType (), Builder2, rule, op, res);
3793
3785
}
3794
3786
if (!gutils->isConstantValue (orig_ops[1 ])) {
3795
- CallInst * powcall =
3796
- cast<CallInst>(Builder2.CreateCall (PowF, {op0, op1}));
3797
- powcall->setCallingConv (PowF-> getCallingConv ());
3787
+ auto powcall =
3788
+ cast<CallInst>(Builder2.CreateCall (FT, PowF, {op0, op1}));
3789
+ powcall->setCallingConv (CI. getCallingConv ());
3798
3790
powcall->setDebugLoc (gutils->getNewFromOriginal (I.getDebugLoc ()));
3799
3791
3800
3792
CallInst *logcall = Builder2.CreateCall (
0 commit comments