Skip to content

Commit 435ad32

Browse files
authored
Julia vector fixes (rust-lang#620)
* Julia vector fixes * Add dnrm * Fix vector addToPtrDiffe * Add width argument to augmented fwd * Fix augmented callingconv * API shadow
1 parent e2afba2 commit 435ad32

File tree

9 files changed

+238
-50
lines changed

9 files changed

+238
-50
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 186 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4386,6 +4386,7 @@ class AdjointGenerator
43864386
cast<Function>(called), subretType, argsInverted,
43874387
TR.analyzer.interprocedural, /*return is used*/ false,
43884388
/*shadowReturnUsed*/ false, nextTypeInfo, uncacheable_args, false,
4389+
gutils->getWidth(),
43894390
/*AtomicAdd*/ true,
43904391
/*OpenMP*/ true);
43914392
if (Mode == DerivativeMode::ReverseModePrimal) {
@@ -4981,7 +4982,7 @@ class AdjointGenerator
49814982

49824983
std::string extractBLAS(StringRef in, std::string &prefix,
49834984
std::string &suffix) {
4984-
std::string extractable[] = {"ddot", "sdot"};
4985+
std::string extractable[] = {"ddot", "sdot", "dnrm2", "snrm2"};
49854986
std::string prefixes[] = {"", "cblas_", "cublas_"};
49864987
std::string suffixes[] = {"", "_", "_64_"};
49874988
for (auto ex : extractable) {
@@ -5001,19 +5002,175 @@ class AdjointGenerator
50015002
bool handleBLAS(llvm::CallInst &call, Function *called, StringRef funcName,
50025003
StringRef prefix, StringRef suffix,
50035004
const std::map<Argument *, bool> &uncacheable_args) {
5004-
// Forward Mode not handled yet
5005-
assert(Mode != DerivativeMode::ForwardMode &&
5006-
Mode != DerivativeMode::ForwardModeSplit);
5007-
// Vector Mode not handled yet
5008-
assert(gutils->getWidth() == 1);
50095005
CallInst *const newCall = cast<CallInst>(gutils->getNewFromOriginal(&call));
50105006
IRBuilder<> BuilderZ(newCall);
50115007
BuilderZ.setFastMathFlags(getFast());
50125008
IRBuilder<> allocationBuilder(gutils->inversionAllocs);
50135009
allocationBuilder.setFastMathFlags(getFast());
50145010

5011+
if (funcName == "dnrm2" || funcName == "snrm2") {
5012+
if (!gutils->isConstantInstruction(&call)) {
5013+
5014+
Type *innerType;
5015+
std::string dfuncName;
5016+
if (funcName == "dnrm2") {
5017+
innerType = Type::getDoubleTy(call.getContext());
5018+
dfuncName = (prefix + "ddot" + suffix).str();
5019+
} else if (funcName == "snrm2") {
5020+
innerType = Type::getFloatTy(call.getContext());
5021+
dfuncName = (prefix + "sdot" + suffix).str();
5022+
} else {
5023+
assert(false && "Unreachable");
5024+
}
5025+
5026+
IntegerType *intType =
5027+
dyn_cast<IntegerType>(call.getOperand(0)->getType());
5028+
bool byRef = false;
5029+
if (!intType) {
5030+
auto PT = cast<PointerType>(call.getOperand(0)->getType());
5031+
if (suffix.contains("64"))
5032+
intType = IntegerType::get(PT->getContext(), 64);
5033+
else
5034+
intType = IntegerType::get(PT->getContext(), 32);
5035+
byRef = true;
5036+
}
5037+
5038+
// Non-forward Mode not handled yet
5039+
if (Mode != DerivativeMode::ForwardMode) {
5040+
return false;
5041+
} else {
5042+
Type *castval;
5043+
if (auto PT = dyn_cast<PointerType>(call.getArgOperand(1)->getType()))
5044+
castval = PT;
5045+
else
5046+
castval = PointerType::getUnqual(innerType);
5047+
5048+
auto in_arg = call.getCalledFunction()->arg_begin();
5049+
Argument *n = in_arg;
5050+
in_arg++;
5051+
Argument *x = in_arg;
5052+
in_arg++;
5053+
Argument *xinc = in_arg;
5054+
5055+
auto derivcall = gutils->oldFunc->getParent()->getOrInsertFunction(
5056+
dfuncName, innerType, n->getType(), x->getType(), xinc->getType(),
5057+
x->getType(), xinc->getType());
5058+
5059+
#if LLVM_VERSION_MAJOR >= 9
5060+
if (auto F = dyn_cast<Function>(derivcall.getCallee()))
5061+
#else
5062+
if (auto F = dyn_cast<Function>(derivcall))
5063+
#endif
5064+
{
5065+
F->addFnAttr(Attribute::ArgMemOnly);
5066+
F->addFnAttr(Attribute::ReadOnly);
5067+
if (byRef) {
5068+
F->addParamAttr(0, Attribute::ReadOnly);
5069+
F->addParamAttr(0, Attribute::NoCapture);
5070+
F->addParamAttr(2, Attribute::ReadOnly);
5071+
F->addParamAttr(2, Attribute::NoCapture);
5072+
F->addParamAttr(4, Attribute::ReadOnly);
5073+
F->addParamAttr(4, Attribute::NoCapture);
5074+
}
5075+
if (call.getArgOperand(1)->getType()->isPointerTy()) {
5076+
F->addParamAttr(1, Attribute::ReadOnly);
5077+
F->addParamAttr(1, Attribute::NoCapture);
5078+
F->addParamAttr(3, Attribute::ReadOnly);
5079+
F->addParamAttr(3, Attribute::NoCapture);
5080+
}
5081+
}
5082+
5083+
if (!gutils->isConstantValue(&call)) {
5084+
if (gutils->isConstantValue(call.getOperand(1))) {
5085+
setDiffe(
5086+
&call,
5087+
Constant::getNullValue(gutils->getShadowType(call.getType())),
5088+
BuilderZ);
5089+
} else {
5090+
auto Defs = gutils->getInvertedBundles(
5091+
&call,
5092+
{ValueType::Primal, ValueType::Primal, ValueType::Primal},
5093+
BuilderZ, /*lookup*/ false);
5094+
5095+
#if LLVM_VERSION_MAJOR >= 11
5096+
auto callval = call.getCalledOperand();
5097+
#else
5098+
auto callval = call.getCalledValue();
5099+
#endif
5100+
5101+
if (auto F = dyn_cast<Function>(callval)) {
5102+
F->addFnAttr(Attribute::ArgMemOnly);
5103+
F->addFnAttr(Attribute::ReadOnly);
5104+
if (byRef) {
5105+
F->addParamAttr(0, Attribute::ReadOnly);
5106+
F->addParamAttr(0, Attribute::NoCapture);
5107+
F->addParamAttr(2, Attribute::ReadOnly);
5108+
F->addParamAttr(2, Attribute::NoCapture);
5109+
}
5110+
if (call.getArgOperand(1)->getType()->isPointerTy()) {
5111+
F->addParamAttr(1, Attribute::ReadOnly);
5112+
F->addParamAttr(1, Attribute::NoCapture);
5113+
}
5114+
}
5115+
5116+
Value *args[] = {gutils->getNewFromOriginal(call.getOperand(0)),
5117+
gutils->getNewFromOriginal(call.getOperand(1)),
5118+
gutils->getNewFromOriginal(call.getOperand(2))};
5119+
5120+
#if LLVM_VERSION_MAJOR > 7
5121+
auto norm = BuilderZ.CreateCall(call.getFunctionType(), callval,
5122+
args, Defs);
5123+
#else
5124+
auto norm = BuilderZ.CreateCall(callval, args, Defs);
5125+
#endif
5126+
5127+
Value *dval = applyChainRule(
5128+
call.getType(), BuilderZ,
5129+
[&](Value *ip) {
5130+
Value *args1[] = {
5131+
gutils->getNewFromOriginal(call.getOperand(0)),
5132+
gutils->getNewFromOriginal(call.getOperand(1)),
5133+
gutils->getNewFromOriginal(call.getOperand(2)), ip,
5134+
gutils->getNewFromOriginal(call.getOperand(2))};
5135+
return BuilderZ.CreateFDiv(
5136+
BuilderZ.CreateCall(
5137+
derivcall, args1,
5138+
gutils->getInvertedBundles(
5139+
&call,
5140+
{ValueType::Primal, ValueType::Both,
5141+
ValueType::Primal},
5142+
BuilderZ, /*lookup*/ false)),
5143+
norm);
5144+
},
5145+
gutils->invertPointerM(call.getOperand(1), BuilderZ));
5146+
setDiffe(&call, dval, BuilderZ);
5147+
}
5148+
}
5149+
}
5150+
5151+
if (gutils->knownRecomputeHeuristic.find(&call) !=
5152+
gutils->knownRecomputeHeuristic.end()) {
5153+
if (!gutils->knownRecomputeHeuristic[&call]) {
5154+
gutils->cacheForReverse(BuilderZ, newCall,
5155+
getIndex(&call, CacheType::Self));
5156+
}
5157+
}
5158+
}
5159+
5160+
if (Mode == DerivativeMode::ReverseModeGradient) {
5161+
eraseIfUnused(call, /*erase*/ true, /*check*/ false);
5162+
} else {
5163+
eraseIfUnused(call);
5164+
}
5165+
return true;
5166+
}
50155167
if (funcName == "ddot" || funcName == "sdot") {
50165168
if (!gutils->isConstantInstruction(&call)) {
5169+
// Forward Mode not handled yet
5170+
assert(Mode != DerivativeMode::ForwardMode &&
5171+
Mode != DerivativeMode::ForwardModeSplit);
5172+
// Vector Mode not handled yet
5173+
assert(gutils->getWidth() == 1);
50175174
Type *innerType;
50185175
std::string dfuncName;
50195176
if (funcName == "ddot") {
@@ -8037,7 +8194,8 @@ class AdjointGenerator
80378194
if (ifound != gutils->invertedPointers.end()) {
80388195
auto placeholder = cast<PHINode>(&*ifound->second);
80398196
if (invertedReturn && invertedReturn != placeholder) {
8040-
if (invertedReturn->getType() != orig->getType()) {
8197+
if (invertedReturn->getType() !=
8198+
gutils->getShadowType(orig->getType())) {
80418199
llvm::errs() << " o: " << *orig << "\n";
80428200
llvm::errs() << " ot: " << *orig->getType() << "\n";
80438201
llvm::errs() << " ir: " << *invertedReturn << "\n";
@@ -8047,7 +8205,8 @@ class AdjointGenerator
80478205
llvm::errs() << " newCall: " << *newCall << "\n";
80488206
llvm::errs() << " newCallT: " << *newCall->getType() << "\n";
80498207
}
8050-
assert(invertedReturn->getType() == orig->getType());
8208+
assert(invertedReturn->getType() ==
8209+
gutils->getShadowType(orig->getType()));
80518210
placeholder->replaceAllUsesWith(invertedReturn);
80528211
gutils->erase(placeholder);
80538212
gutils->invertedPointers.insert(
@@ -8128,7 +8287,8 @@ class AdjointGenerator
81288287
gutils->erase(placeholder);
81298288
} else {
81308289
if (invertedReturn && invertedReturn != placeholder) {
8131-
if (invertedReturn->getType() != orig->getType()) {
8290+
if (invertedReturn->getType() !=
8291+
gutils->getShadowType(orig->getType())) {
81328292
llvm::errs() << " o: " << *orig << "\n";
81338293
llvm::errs() << " ot: " << *orig->getType() << "\n";
81348294
llvm::errs() << " ir: " << *invertedReturn << "\n";
@@ -8138,7 +8298,8 @@ class AdjointGenerator
81388298
llvm::errs() << " newCall: " << *newCall << "\n";
81398299
llvm::errs() << " newCallT: " << *newCall->getType() << "\n";
81408300
}
8141-
assert(invertedReturn->getType() == orig->getType());
8301+
assert(invertedReturn->getType() ==
8302+
gutils->getShadowType(orig->getType()));
81428303
placeholder->replaceAllUsesWith(invertedReturn);
81438304
gutils->erase(placeholder);
81448305
} else
@@ -9839,8 +10000,15 @@ class AdjointGenerator
983910000
forwardsShadow) ||
984010001
(Mode == DerivativeMode::ReverseModeGradient &&
984110002
backwardsShadow)) {
9842-
anti = shadowHandlers[called->getName().str()](bb, orig, args);
9843-
10003+
anti = applyChainRule(call.getType(), bb, [&]() {
10004+
return shadowHandlers[called->getName().str()](bb, orig,
10005+
args);
10006+
});
10007+
if (anti->getType() != placeholder->getType()) {
10008+
llvm::errs() << "orig: " << *orig << "\n";
10009+
llvm::errs() << "placeholder: " << *placeholder << "\n";
10010+
llvm::errs() << "anti: " << *anti << "\n";
10011+
}
984410012
gutils->invertedPointers.erase(found);
984510013
bb.SetInsertPoint(placeholder);
984610014

@@ -10263,7 +10431,11 @@ class AdjointGenerator
1026310431

1026410432
Value *ptrshadow =
1026510433
gutils->invertPointerM(call.getArgOperand(0), BuilderZ);
10266-
Value *val = BuilderZ.CreateCall(called, {ptrshadow});
10434+
10435+
Value *val = applyChainRule(
10436+
call.getType(), BuilderZ,
10437+
[&](Value *v) -> Value * { return BuilderZ.CreateCall(called, {v}); },
10438+
ptrshadow);
1026710439

1026810440
gutils->replaceAWithB(placeholder, val);
1026910441
gutils->erase(placeholder);
@@ -10973,7 +11145,7 @@ class AdjointGenerator
1097311145
cast<Function>(called), subretType, argsInverted,
1097411146
TR.analyzer.interprocedural, /*return is used*/ subretused,
1097511147
shadowReturnUsed, nextTypeInfo, uncacheable_args, false,
10976-
gutils->AtomicAdd);
11148+
gutils->getWidth(), gutils->AtomicAdd);
1097711149
if (Mode == DerivativeMode::ReverseModePrimal) {
1097811150
assert(augmentedReturn);
1097911151
auto subaugmentations =

enzyme/Enzyme/CApi.cpp

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -452,14 +452,13 @@ LLVMValueRef EnzymeCreatePrimalAndGradient(
452452
},
453453
eunwrap(TA), eunwrap(augmented)));
454454
}
455-
EnzymeAugmentedReturnPtr
456-
EnzymeCreateAugmentedPrimal(EnzymeLogicRef Logic, LLVMValueRef todiff,
457-
CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args,
458-
size_t constant_args_size, EnzymeTypeAnalysisRef TA,
459-
uint8_t returnUsed, uint8_t shadowReturnUsed,
460-
CFnTypeInfo typeInfo, uint8_t *_uncacheable_args,
461-
size_t uncacheable_args_size,
462-
uint8_t forceAnonymousTape, uint8_t AtomicAdd) {
455+
EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal(
456+
EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType,
457+
CDIFFE_TYPE *constant_args, size_t constant_args_size,
458+
EnzymeTypeAnalysisRef TA, uint8_t returnUsed, uint8_t shadowReturnUsed,
459+
CFnTypeInfo typeInfo, uint8_t *_uncacheable_args,
460+
size_t uncacheable_args_size, uint8_t forceAnonymousTape, unsigned width,
461+
uint8_t AtomicAdd) {
463462

464463
std::vector<DIFFE_TYPE> nconstant_args((DIFFE_TYPE *)constant_args,
465464
(DIFFE_TYPE *)constant_args +
@@ -475,7 +474,7 @@ EnzymeCreateAugmentedPrimal(EnzymeLogicRef Logic, LLVMValueRef todiff,
475474
cast<Function>(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args,
476475
eunwrap(TA), returnUsed, shadowReturnUsed,
477476
eunwrap(typeInfo, cast<Function>(unwrap(todiff))), uncacheable_args,
478-
forceAnonymousTape, AtomicAdd));
477+
forceAnonymousTape, width, AtomicAdd));
479478
}
480479

481480
LLVMValueRef

enzyme/Enzyme/CApi.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal(
139139
CDIFFE_TYPE *constant_args, size_t constant_args_size,
140140
EnzymeTypeAnalysisRef TA, uint8_t returnUsed, uint8_t shadowReturnUsed,
141141
struct CFnTypeInfo typeInfo, uint8_t *_uncacheable_args,
142-
size_t uncacheable_args_size, uint8_t forceAnonymousTape,
142+
size_t uncacheable_args_size, uint8_t forceAnonymousTape, unsigned width,
143143
uint8_t AtomicAdd);
144144

145145
typedef uint8_t (*CustomRuleType)(int /*direction*/, CTypeTreeRef /*return*/,

enzyme/Enzyme/Enzyme.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -878,7 +878,7 @@ class Enzyme : public ModulePass {
878878
aug = &Logic.CreateAugmentedPrimal(
879879
cast<Function>(fn), retType, constants, TA,
880880
/*returnUsed*/ false, /*shadowReturnUsed*/ false, type_args,
881-
volatile_args, forceAnonymousTape, /*atomicAdd*/ AtomicAdd);
881+
volatile_args, forceAnonymousTape, width, /*atomicAdd*/ AtomicAdd);
882882
auto &DL = cast<Function>(fn)->getParent()->getDataLayout();
883883
if (!forceAnonymousTape) {
884884
assert(!aug->tapeType);
@@ -941,7 +941,7 @@ class Enzyme : public ModulePass {
941941
retType == DIFFE_TYPE::DUP_NONEED);
942942
aug = &Logic.CreateAugmentedPrimal(
943943
cast<Function>(fn), retType, constants, TA, returnUsed,
944-
shadowReturnUsed, type_args, volatile_args, forceAnonymousTape,
944+
shadowReturnUsed, type_args, volatile_args, forceAnonymousTape, width,
945945
/*atomicAdd*/ AtomicAdd);
946946
auto &DL = cast<Function>(fn)->getParent()->getDataLayout();
947947
if (!forceAnonymousTape) {

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1640,7 +1640,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
16401640
const std::vector<DIFFE_TYPE> &constant_args, TypeAnalysis &TA,
16411641
bool returnUsed, bool shadowReturnUsed, const FnTypeInfo &oldTypeInfo_,
16421642
const std::map<Argument *, bool> _uncacheable_args, bool forceAnonymousTape,
1643-
bool AtomicAdd, bool omp) {
1643+
unsigned width, bool AtomicAdd, bool omp) {
16441644
if (returnUsed)
16451645
assert(!todiff->getReturnType()->isEmptyTy() &&
16461646
!todiff->getReturnType()->isVoidTy());
@@ -1656,7 +1656,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
16561656
std::map<Argument *, bool>(_uncacheable_args.begin(),
16571657
_uncacheable_args.end()),
16581658
returnUsed, shadowReturnUsed, oldTypeInfo,
1659-
forceAnonymousTape, AtomicAdd, omp);
1659+
forceAnonymousTape, AtomicAdd, omp, width);
16601660
auto found = AugmentedCachedFunctions.find(tup);
16611661
if (found != AugmentedCachedFunctions.end()) {
16621662
return found->second;
@@ -1745,7 +1745,8 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
17451745
}
17461746
auto &aug = CreateAugmentedPrimal(
17471747
todiff, retType, next_constant_args, TA, returnUsed, shadowReturnUsed,
1748-
oldTypeInfo_, _uncacheable_args, forceAnonymousTape, AtomicAdd, omp);
1748+
oldTypeInfo_, _uncacheable_args, forceAnonymousTape, width, AtomicAdd,
1749+
omp);
17491750
auto cal = bb.CreateCall(aug.fn, fwdargs);
17501751
cal->setCallingConv(aug.fn->getCallingConv());
17511752

@@ -1835,7 +1836,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
18351836
std::map<AugmentedStruct, int> returnMapping;
18361837

18371838
GradientUtils *gutils = GradientUtils::CreateFromClone(
1838-
*this, todiff, TLI, TA, retType, constant_args,
1839+
*this, width, todiff, TLI, TA, retType, constant_args,
18391840
/*returnUsed*/ returnUsed, /*shadowReturnUsed*/ shadowReturnUsed,
18401841
returnMapping, omp);
18411842
gutils->AtomicAdd = AtomicAdd;
@@ -3052,7 +3053,7 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
30523053
auto &aug = CreateAugmentedPrimal(
30533054
key.todiff, key.retType, key.constant_args, TA, key.returnUsed,
30543055
key.shadowReturnUsed, key.typeInfo, key.uncacheable_args,
3055-
/*forceAnonymousTape*/ false, key.AtomicAdd, omp);
3056+
/*forceAnonymousTape*/ false, key.width, key.AtomicAdd, omp);
30563057

30573058
SmallVector<Value *, 4> fwdargs;
30583059
for (auto &a : NewF->args())

0 commit comments

Comments
 (0)