@@ -4386,6 +4386,7 @@ class AdjointGenerator
4386
4386
cast<Function>(called), subretType, argsInverted,
4387
4387
TR.analyzer.interprocedural, /*return is used*/ false,
4388
4388
/*shadowReturnUsed*/ false, nextTypeInfo, uncacheable_args, false,
4389
+ gutils->getWidth(),
4389
4390
/*AtomicAdd*/ true,
4390
4391
/*OpenMP*/ true);
4391
4392
if (Mode == DerivativeMode::ReverseModePrimal) {
@@ -4981,7 +4982,7 @@ class AdjointGenerator
4981
4982
4982
4983
std::string extractBLAS(StringRef in, std::string &prefix,
4983
4984
std::string &suffix) {
4984
- std::string extractable[] = {"ddot", "sdot"};
4985
+ std::string extractable[] = {"ddot", "sdot", "dnrm2", "snrm2" };
4985
4986
std::string prefixes[] = {"", "cblas_", "cublas_"};
4986
4987
std::string suffixes[] = {"", "_", "_64_"};
4987
4988
for (auto ex : extractable) {
@@ -5001,19 +5002,175 @@ class AdjointGenerator
5001
5002
bool handleBLAS(llvm::CallInst &call, Function *called, StringRef funcName,
5002
5003
StringRef prefix, StringRef suffix,
5003
5004
const std::map<Argument *, bool> &uncacheable_args) {
5004
- // Forward Mode not handled yet
5005
- assert(Mode != DerivativeMode::ForwardMode &&
5006
- Mode != DerivativeMode::ForwardModeSplit);
5007
- // Vector Mode not handled yet
5008
- assert(gutils->getWidth() == 1);
5009
5005
CallInst *const newCall = cast<CallInst>(gutils->getNewFromOriginal(&call));
5010
5006
IRBuilder<> BuilderZ(newCall);
5011
5007
BuilderZ.setFastMathFlags(getFast());
5012
5008
IRBuilder<> allocationBuilder(gutils->inversionAllocs);
5013
5009
allocationBuilder.setFastMathFlags(getFast());
5014
5010
5011
+ if (funcName == "dnrm2" || funcName == "snrm2") {
5012
+ if (!gutils->isConstantInstruction(&call)) {
5013
+
5014
+ Type *innerType;
5015
+ std::string dfuncName;
5016
+ if (funcName == "dnrm2") {
5017
+ innerType = Type::getDoubleTy(call.getContext());
5018
+ dfuncName = (prefix + "ddot" + suffix).str();
5019
+ } else if (funcName == "snrm2") {
5020
+ innerType = Type::getFloatTy(call.getContext());
5021
+ dfuncName = (prefix + "sdot" + suffix).str();
5022
+ } else {
5023
+ assert(false && "Unreachable");
5024
+ }
5025
+
5026
+ IntegerType *intType =
5027
+ dyn_cast<IntegerType>(call.getOperand(0)->getType());
5028
+ bool byRef = false;
5029
+ if (!intType) {
5030
+ auto PT = cast<PointerType>(call.getOperand(0)->getType());
5031
+ if (suffix.contains("64"))
5032
+ intType = IntegerType::get(PT->getContext(), 64);
5033
+ else
5034
+ intType = IntegerType::get(PT->getContext(), 32);
5035
+ byRef = true;
5036
+ }
5037
+
5038
+ // Non-forward Mode not handled yet
5039
+ if (Mode != DerivativeMode::ForwardMode) {
5040
+ return false;
5041
+ } else {
5042
+ Type *castval;
5043
+ if (auto PT = dyn_cast<PointerType>(call.getArgOperand(1)->getType()))
5044
+ castval = PT;
5045
+ else
5046
+ castval = PointerType::getUnqual(innerType);
5047
+
5048
+ auto in_arg = call.getCalledFunction()->arg_begin();
5049
+ Argument *n = in_arg;
5050
+ in_arg++;
5051
+ Argument *x = in_arg;
5052
+ in_arg++;
5053
+ Argument *xinc = in_arg;
5054
+
5055
+ auto derivcall = gutils->oldFunc->getParent()->getOrInsertFunction(
5056
+ dfuncName, innerType, n->getType(), x->getType(), xinc->getType(),
5057
+ x->getType(), xinc->getType());
5058
+
5059
+ #if LLVM_VERSION_MAJOR >= 9
5060
+ if (auto F = dyn_cast<Function>(derivcall.getCallee()))
5061
+ #else
5062
+ if (auto F = dyn_cast<Function>(derivcall))
5063
+ #endif
5064
+ {
5065
+ F->addFnAttr(Attribute::ArgMemOnly);
5066
+ F->addFnAttr(Attribute::ReadOnly);
5067
+ if (byRef) {
5068
+ F->addParamAttr(0, Attribute::ReadOnly);
5069
+ F->addParamAttr(0, Attribute::NoCapture);
5070
+ F->addParamAttr(2, Attribute::ReadOnly);
5071
+ F->addParamAttr(2, Attribute::NoCapture);
5072
+ F->addParamAttr(4, Attribute::ReadOnly);
5073
+ F->addParamAttr(4, Attribute::NoCapture);
5074
+ }
5075
+ if (call.getArgOperand(1)->getType()->isPointerTy()) {
5076
+ F->addParamAttr(1, Attribute::ReadOnly);
5077
+ F->addParamAttr(1, Attribute::NoCapture);
5078
+ F->addParamAttr(3, Attribute::ReadOnly);
5079
+ F->addParamAttr(3, Attribute::NoCapture);
5080
+ }
5081
+ }
5082
+
5083
+ if (!gutils->isConstantValue(&call)) {
5084
+ if (gutils->isConstantValue(call.getOperand(1))) {
5085
+ setDiffe(
5086
+ &call,
5087
+ Constant::getNullValue(gutils->getShadowType(call.getType())),
5088
+ BuilderZ);
5089
+ } else {
5090
+ auto Defs = gutils->getInvertedBundles(
5091
+ &call,
5092
+ {ValueType::Primal, ValueType::Primal, ValueType::Primal},
5093
+ BuilderZ, /*lookup*/ false);
5094
+
5095
+ #if LLVM_VERSION_MAJOR >= 11
5096
+ auto callval = call.getCalledOperand();
5097
+ #else
5098
+ auto callval = call.getCalledValue();
5099
+ #endif
5100
+
5101
+ if (auto F = dyn_cast<Function>(callval)) {
5102
+ F->addFnAttr(Attribute::ArgMemOnly);
5103
+ F->addFnAttr(Attribute::ReadOnly);
5104
+ if (byRef) {
5105
+ F->addParamAttr(0, Attribute::ReadOnly);
5106
+ F->addParamAttr(0, Attribute::NoCapture);
5107
+ F->addParamAttr(2, Attribute::ReadOnly);
5108
+ F->addParamAttr(2, Attribute::NoCapture);
5109
+ }
5110
+ if (call.getArgOperand(1)->getType()->isPointerTy()) {
5111
+ F->addParamAttr(1, Attribute::ReadOnly);
5112
+ F->addParamAttr(1, Attribute::NoCapture);
5113
+ }
5114
+ }
5115
+
5116
+ Value *args[] = {gutils->getNewFromOriginal(call.getOperand(0)),
5117
+ gutils->getNewFromOriginal(call.getOperand(1)),
5118
+ gutils->getNewFromOriginal(call.getOperand(2))};
5119
+
5120
+ #if LLVM_VERSION_MAJOR > 7
5121
+ auto norm = BuilderZ.CreateCall(call.getFunctionType(), callval,
5122
+ args, Defs);
5123
+ #else
5124
+ auto norm = BuilderZ.CreateCall(callval, args, Defs);
5125
+ #endif
5126
+
5127
+ Value *dval = applyChainRule(
5128
+ call.getType(), BuilderZ,
5129
+ [&](Value *ip) {
5130
+ Value *args1[] = {
5131
+ gutils->getNewFromOriginal(call.getOperand(0)),
5132
+ gutils->getNewFromOriginal(call.getOperand(1)),
5133
+ gutils->getNewFromOriginal(call.getOperand(2)), ip,
5134
+ gutils->getNewFromOriginal(call.getOperand(2))};
5135
+ return BuilderZ.CreateFDiv(
5136
+ BuilderZ.CreateCall(
5137
+ derivcall, args1,
5138
+ gutils->getInvertedBundles(
5139
+ &call,
5140
+ {ValueType::Primal, ValueType::Both,
5141
+ ValueType::Primal},
5142
+ BuilderZ, /*lookup*/ false)),
5143
+ norm);
5144
+ },
5145
+ gutils->invertPointerM(call.getOperand(1), BuilderZ));
5146
+ setDiffe(&call, dval, BuilderZ);
5147
+ }
5148
+ }
5149
+ }
5150
+
5151
+ if (gutils->knownRecomputeHeuristic.find(&call) !=
5152
+ gutils->knownRecomputeHeuristic.end()) {
5153
+ if (!gutils->knownRecomputeHeuristic[&call]) {
5154
+ gutils->cacheForReverse(BuilderZ, newCall,
5155
+ getIndex(&call, CacheType::Self));
5156
+ }
5157
+ }
5158
+ }
5159
+
5160
+ if (Mode == DerivativeMode::ReverseModeGradient) {
5161
+ eraseIfUnused(call, /*erase*/ true, /*check*/ false);
5162
+ } else {
5163
+ eraseIfUnused(call);
5164
+ }
5165
+ return true;
5166
+ }
5015
5167
if (funcName == "ddot" || funcName == "sdot") {
5016
5168
if (!gutils->isConstantInstruction(&call)) {
5169
+ // Forward Mode not handled yet
5170
+ assert(Mode != DerivativeMode::ForwardMode &&
5171
+ Mode != DerivativeMode::ForwardModeSplit);
5172
+ // Vector Mode not handled yet
5173
+ assert(gutils->getWidth() == 1);
5017
5174
Type *innerType;
5018
5175
std::string dfuncName;
5019
5176
if (funcName == "ddot") {
@@ -8037,7 +8194,8 @@ class AdjointGenerator
8037
8194
if (ifound != gutils->invertedPointers.end()) {
8038
8195
auto placeholder = cast<PHINode>(&*ifound->second);
8039
8196
if (invertedReturn && invertedReturn != placeholder) {
8040
- if (invertedReturn->getType() != orig->getType()) {
8197
+ if (invertedReturn->getType() !=
8198
+ gutils->getShadowType(orig->getType())) {
8041
8199
llvm::errs() << " o: " << *orig << "\n";
8042
8200
llvm::errs() << " ot: " << *orig->getType() << "\n";
8043
8201
llvm::errs() << " ir: " << *invertedReturn << "\n";
@@ -8047,7 +8205,8 @@ class AdjointGenerator
8047
8205
llvm::errs() << " newCall: " << *newCall << "\n";
8048
8206
llvm::errs() << " newCallT: " << *newCall->getType() << "\n";
8049
8207
}
8050
- assert(invertedReturn->getType() == orig->getType());
8208
+ assert(invertedReturn->getType() ==
8209
+ gutils->getShadowType(orig->getType()));
8051
8210
placeholder->replaceAllUsesWith(invertedReturn);
8052
8211
gutils->erase(placeholder);
8053
8212
gutils->invertedPointers.insert(
@@ -8128,7 +8287,8 @@ class AdjointGenerator
8128
8287
gutils->erase(placeholder);
8129
8288
} else {
8130
8289
if (invertedReturn && invertedReturn != placeholder) {
8131
- if (invertedReturn->getType() != orig->getType()) {
8290
+ if (invertedReturn->getType() !=
8291
+ gutils->getShadowType(orig->getType())) {
8132
8292
llvm::errs() << " o: " << *orig << "\n";
8133
8293
llvm::errs() << " ot: " << *orig->getType() << "\n";
8134
8294
llvm::errs() << " ir: " << *invertedReturn << "\n";
@@ -8138,7 +8298,8 @@ class AdjointGenerator
8138
8298
llvm::errs() << " newCall: " << *newCall << "\n";
8139
8299
llvm::errs() << " newCallT: " << *newCall->getType() << "\n";
8140
8300
}
8141
- assert(invertedReturn->getType() == orig->getType());
8301
+ assert(invertedReturn->getType() ==
8302
+ gutils->getShadowType(orig->getType()));
8142
8303
placeholder->replaceAllUsesWith(invertedReturn);
8143
8304
gutils->erase(placeholder);
8144
8305
} else
@@ -9839,8 +10000,15 @@ class AdjointGenerator
9839
10000
forwardsShadow) ||
9840
10001
(Mode == DerivativeMode::ReverseModeGradient &&
9841
10002
backwardsShadow)) {
9842
- anti = shadowHandlers[called->getName().str()](bb, orig, args);
9843
-
10003
+ anti = applyChainRule(call.getType(), bb, [&]() {
10004
+ return shadowHandlers[called->getName().str()](bb, orig,
10005
+ args);
10006
+ });
10007
+ if (anti->getType() != placeholder->getType()) {
10008
+ llvm::errs() << "orig: " << *orig << "\n";
10009
+ llvm::errs() << "placeholder: " << *placeholder << "\n";
10010
+ llvm::errs() << "anti: " << *anti << "\n";
10011
+ }
9844
10012
gutils->invertedPointers.erase(found);
9845
10013
bb.SetInsertPoint(placeholder);
9846
10014
@@ -10263,7 +10431,11 @@ class AdjointGenerator
10263
10431
10264
10432
Value *ptrshadow =
10265
10433
gutils->invertPointerM(call.getArgOperand(0), BuilderZ);
10266
- Value *val = BuilderZ.CreateCall(called, {ptrshadow});
10434
+
10435
+ Value *val = applyChainRule(
10436
+ call.getType(), BuilderZ,
10437
+ [&](Value *v) -> Value * { return BuilderZ.CreateCall(called, {v}); },
10438
+ ptrshadow);
10267
10439
10268
10440
gutils->replaceAWithB(placeholder, val);
10269
10441
gutils->erase(placeholder);
@@ -10973,7 +11145,7 @@ class AdjointGenerator
10973
11145
cast<Function>(called), subretType, argsInverted,
10974
11146
TR.analyzer.interprocedural, /*return is used*/ subretused,
10975
11147
shadowReturnUsed, nextTypeInfo, uncacheable_args, false,
10976
- gutils->AtomicAdd);
11148
+ gutils->getWidth(), gutils-> AtomicAdd);
10977
11149
if (Mode == DerivativeMode::ReverseModePrimal) {
10978
11150
assert(augmentedReturn);
10979
11151
auto subaugmentations =
0 commit comments