Skip to content

Commit 7d90c8f

Browse files
wsmoses and tgymnich authored
Custom Forward Pass (rust-lang#403)
* Fix allocs with more than one parameter (rust-lang#397)
* fix allocation functions with more than one argument
* add test
* fix test
* Handle custom forward
* Handle integer memcpy in forward

Co-authored-by: Tim Gymnich <[email protected]>
Co-authored-by: Tim Gymnich <[email protected]>
1 parent 3872f89 commit 7d90c8f

File tree

7 files changed

+183
-91
lines changed

7 files changed

+183
-91
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 149 additions & 85 deletions
Original file line number | Diff line number | Diff line change
@@ -2231,7 +2231,13 @@ class AdjointGenerator
22312231
IRBuilder<> Builder2(&MTI);
22322232
getForwardBuilder(Builder2);
22332233
auto ddst = gutils->invertPointerM(orig_dst, Builder2);
2234+
if (ddst->getType()->isIntegerTy())
2235+
ddst = Builder2.CreateIntToPtr(ddst,
2236+
Type::getInt8PtrTy(ddst->getContext()));
22342237
auto dsrc = gutils->invertPointerM(orig_src, Builder2);
2238+
if (dsrc->getType()->isIntegerTy())
2239+
dsrc = Builder2.CreateIntToPtr(dsrc,
2240+
Type::getInt8PtrTy(dsrc->getContext()));
22352241

22362242
auto call =
22372243
Builder2.CreateMemCpy(ddst, dstAlign, dsrc, srcAlign, new_size);
@@ -6059,59 +6065,21 @@ class AdjointGenerator
60596065
subretType = DIFFE_TYPE::OUT_DIFF;
60606066
}
60616067

6062-
auto found = customCallHandlers.find(funcName.str());
6063-
if (found != customCallHandlers.end()) {
6064-
IRBuilder<> Builder2(call.getParent());
6065-
if (Mode == DerivativeMode::ReverseModeGradient ||
6066-
Mode == DerivativeMode::ReverseModeCombined)
6067-
getReverseBuilder(Builder2);
6068-
6069-
Value *invertedReturn = nullptr;
6070-
bool hasNonReturnUse = false;
6071-
auto ifound = gutils->invertedPointers.find(orig);
6072-
if (ifound != gutils->invertedPointers.end()) {
6073-
//! We only need the shadow pointer for non-forward Mode if it is used
6074-
//! in a non return setting
6075-
hasNonReturnUse = subretType == DIFFE_TYPE::DUP_ARG;
6076-
if (hasNonReturnUse)
6068+
if (Mode == DerivativeMode::ForwardMode) {
6069+
auto found = customFwdCallHandlers.find(funcName.str());
6070+
if (found != customFwdCallHandlers.end()) {
6071+
Value *invertedReturn = nullptr;
6072+
auto ifound = gutils->invertedPointers.find(orig);
6073+
if (ifound != gutils->invertedPointers.end()) {
60776074
invertedReturn = cast<PHINode>(&*ifound->second);
6078-
}
6075+
}
60796076

6080-
Value *normalReturn = subretused ? newCall : nullptr;
6077+
Value *normalReturn = subretused ? newCall : nullptr;
60816078

6082-
Value *tape = nullptr;
6079+
found->second(BuilderZ, orig, *gutils, normalReturn, invertedReturn);
60836080

6084-
if (Mode == DerivativeMode::ReverseModePrimal ||
6085-
Mode == DerivativeMode::ReverseModeCombined) {
6086-
found->second.first(BuilderZ, orig, *gutils, normalReturn,
6087-
invertedReturn, tape);
6088-
if (tape)
6089-
gutils->cacheForReverse(BuilderZ, tape,
6090-
getIndex(orig, CacheType::Tape));
6091-
}
6092-
6093-
if (Mode == DerivativeMode::ReverseModeGradient ||
6094-
Mode == DerivativeMode::ReverseModeCombined) {
6095-
if (Mode == DerivativeMode::ReverseModeGradient &&
6096-
augmentedReturn->tapeIndices.find(std::make_pair(
6097-
orig, CacheType::Tape)) != augmentedReturn->tapeIndices.end()) {
6098-
tape = BuilderZ.CreatePHI(Type::getInt32Ty(orig->getContext()), 0);
6099-
tape = gutils->cacheForReverse(BuilderZ, tape,
6100-
getIndex(orig, CacheType::Tape),
6101-
/*ignoreType*/ true);
6102-
}
6103-
if (tape)
6104-
tape = gutils->lookupM(tape, Builder2);
6105-
found->second.second(Builder2, orig, *(DiffeGradientUtils *)gutils,
6106-
tape);
6107-
}
6108-
6109-
if (ifound != gutils->invertedPointers.end()) {
6110-
auto placeholder = cast<PHINode>(&*ifound->second);
6111-
if (!hasNonReturnUse) {
6112-
gutils->invertedPointers.erase(ifound);
6113-
gutils->erase(placeholder);
6114-
} else {
6081+
if (ifound != gutils->invertedPointers.end()) {
6082+
auto placeholder = cast<PHINode>(&*ifound->second);
61156083
if (invertedReturn && invertedReturn != placeholder) {
61166084
if (invertedReturn->getType() != orig->getType()) {
61176085
llvm::errs() << " o: " << *orig << "\n";
@@ -6126,50 +6094,143 @@ class AdjointGenerator
61266094
assert(invertedReturn->getType() == orig->getType());
61276095
placeholder->replaceAllUsesWith(invertedReturn);
61286096
gutils->erase(placeholder);
6129-
} else
6130-
invertedReturn = placeholder;
6131-
6132-
invertedReturn = gutils->cacheForReverse(
6133-
BuilderZ, invertedReturn, getIndex(orig, CacheType::Shadow));
6134-
6135-
gutils->invertedPointers.insert(std::make_pair(
6136-
(const Value *)orig, InvertedPointerVH(gutils, invertedReturn)));
6097+
gutils->invertedPointers.insert(
6098+
std::make_pair((const Value *)orig,
6099+
InvertedPointerVH(gutils, invertedReturn)));
6100+
} else {
6101+
gutils->invertedPointers.erase(orig);
6102+
gutils->erase(placeholder);
6103+
}
61376104
}
6138-
}
6139-
6140-
bool primalNeededInReverse;
61416105

6142-
if (gutils->knownRecomputeHeuristic.count(orig)) {
6143-
primalNeededInReverse = !gutils->knownRecomputeHeuristic[orig];
6144-
} else {
6145-
std::map<UsageKey, bool> Seen;
6146-
for (auto pair : gutils->knownRecomputeHeuristic)
6147-
if (!pair.second)
6148-
Seen[UsageKey(pair.first, ValueType::Primal)] = false;
6149-
primalNeededInReverse = is_value_needed_in_reverse<ValueType::Primal>(
6150-
TR, gutils, orig, Mode, Seen, oldUnreachable);
6151-
}
6152-
if (subretused && primalNeededInReverse) {
6153-
if (normalReturn != newCall) {
6154-
assert(normalReturn->getType() == newCall->getType());
6155-
gutils->replaceAWithB(newCall, normalReturn);
6156-
BuilderZ.SetInsertPoint(newCall->getNextNode());
6157-
gutils->erase(newCall);
6158-
}
6159-
normalReturn = gutils->cacheForReverse(BuilderZ, normalReturn,
6160-
getIndex(orig, CacheType::Self));
6161-
} else {
61626106
if (normalReturn && normalReturn != newCall) {
61636107
assert(normalReturn->getType() == newCall->getType());
61646108
assert(Mode != DerivativeMode::ReverseModeGradient);
61656109
gutils->replaceAWithB(newCall, normalReturn);
6166-
BuilderZ.SetInsertPoint(newCall->getNextNode());
61676110
gutils->erase(newCall);
6168-
} else if (!orig->mayWriteToMemory() ||
6169-
Mode == DerivativeMode::ReverseModeGradient)
6170-
eraseIfUnused(*orig, /*erase*/ true, /*check*/ false);
6111+
}
6112+
eraseIfUnused(*orig);
6113+
return;
6114+
}
6115+
}
6116+
6117+
if (Mode == DerivativeMode::ReverseModePrimal ||
6118+
Mode == DerivativeMode::ReverseModeCombined ||
6119+
Mode == DerivativeMode::ReverseModeGradient) {
6120+
auto found = customCallHandlers.find(funcName.str());
6121+
if (found != customCallHandlers.end()) {
6122+
IRBuilder<> Builder2(call.getParent());
6123+
if (Mode == DerivativeMode::ReverseModeGradient ||
6124+
Mode == DerivativeMode::ReverseModeCombined)
6125+
getReverseBuilder(Builder2);
6126+
6127+
Value *invertedReturn = nullptr;
6128+
bool hasNonReturnUse = false;
6129+
auto ifound = gutils->invertedPointers.find(orig);
6130+
if (ifound != gutils->invertedPointers.end()) {
6131+
//! We only need the shadow pointer for non-forward Mode if it is used
6132+
//! in a non return setting
6133+
hasNonReturnUse = subretType == DIFFE_TYPE::DUP_ARG;
6134+
if (hasNonReturnUse)
6135+
invertedReturn = cast<PHINode>(&*ifound->second);
6136+
}
6137+
6138+
Value *normalReturn = subretused ? newCall : nullptr;
6139+
6140+
Value *tape = nullptr;
6141+
6142+
if (Mode == DerivativeMode::ReverseModePrimal ||
6143+
Mode == DerivativeMode::ReverseModeCombined) {
6144+
found->second.first(BuilderZ, orig, *gutils, normalReturn,
6145+
invertedReturn, tape);
6146+
if (tape)
6147+
gutils->cacheForReverse(BuilderZ, tape,
6148+
getIndex(orig, CacheType::Tape));
6149+
}
6150+
6151+
if (Mode == DerivativeMode::ReverseModeGradient ||
6152+
Mode == DerivativeMode::ReverseModeCombined) {
6153+
if (Mode == DerivativeMode::ReverseModeGradient &&
6154+
augmentedReturn->tapeIndices.find(
6155+
std::make_pair(orig, CacheType::Tape)) !=
6156+
augmentedReturn->tapeIndices.end()) {
6157+
tape = BuilderZ.CreatePHI(Type::getInt32Ty(orig->getContext()), 0);
6158+
tape = gutils->cacheForReverse(BuilderZ, tape,
6159+
getIndex(orig, CacheType::Tape),
6160+
/*ignoreType*/ true);
6161+
}
6162+
if (tape)
6163+
tape = gutils->lookupM(tape, Builder2);
6164+
found->second.second(Builder2, orig, *(DiffeGradientUtils *)gutils,
6165+
tape);
6166+
}
6167+
6168+
if (ifound != gutils->invertedPointers.end()) {
6169+
auto placeholder = cast<PHINode>(&*ifound->second);
6170+
if (!hasNonReturnUse) {
6171+
gutils->invertedPointers.erase(ifound);
6172+
gutils->erase(placeholder);
6173+
} else {
6174+
if (invertedReturn && invertedReturn != placeholder) {
6175+
if (invertedReturn->getType() != orig->getType()) {
6176+
llvm::errs() << " o: " << *orig << "\n";
6177+
llvm::errs() << " ot: " << *orig->getType() << "\n";
6178+
llvm::errs() << " ir: " << *invertedReturn << "\n";
6179+
llvm::errs() << " irt: " << *invertedReturn->getType() << "\n";
6180+
llvm::errs() << " p: " << *placeholder << "\n";
6181+
llvm::errs() << " PT: " << *placeholder->getType() << "\n";
6182+
llvm::errs() << " newCall: " << *newCall << "\n";
6183+
llvm::errs() << " newCallT: " << *newCall->getType() << "\n";
6184+
}
6185+
assert(invertedReturn->getType() == orig->getType());
6186+
placeholder->replaceAllUsesWith(invertedReturn);
6187+
gutils->erase(placeholder);
6188+
} else
6189+
invertedReturn = placeholder;
6190+
6191+
invertedReturn = gutils->cacheForReverse(
6192+
BuilderZ, invertedReturn, getIndex(orig, CacheType::Shadow));
6193+
6194+
gutils->invertedPointers.insert(
6195+
std::make_pair((const Value *)orig,
6196+
InvertedPointerVH(gutils, invertedReturn)));
6197+
}
6198+
}
6199+
6200+
bool primalNeededInReverse;
6201+
6202+
if (gutils->knownRecomputeHeuristic.count(orig)) {
6203+
primalNeededInReverse = !gutils->knownRecomputeHeuristic[orig];
6204+
} else {
6205+
std::map<UsageKey, bool> Seen;
6206+
for (auto pair : gutils->knownRecomputeHeuristic)
6207+
if (!pair.second)
6208+
Seen[UsageKey(pair.first, ValueType::Primal)] = false;
6209+
primalNeededInReverse = is_value_needed_in_reverse<ValueType::Primal>(
6210+
TR, gutils, orig, Mode, Seen, oldUnreachable);
6211+
}
6212+
if (subretused && primalNeededInReverse) {
6213+
if (normalReturn != newCall) {
6214+
assert(normalReturn->getType() == newCall->getType());
6215+
gutils->replaceAWithB(newCall, normalReturn);
6216+
BuilderZ.SetInsertPoint(newCall->getNextNode());
6217+
gutils->erase(newCall);
6218+
}
6219+
normalReturn = gutils->cacheForReverse(
6220+
BuilderZ, normalReturn, getIndex(orig, CacheType::Self));
6221+
} else {
6222+
if (normalReturn && normalReturn != newCall) {
6223+
assert(normalReturn->getType() == newCall->getType());
6224+
assert(Mode != DerivativeMode::ReverseModeGradient);
6225+
gutils->replaceAWithB(newCall, normalReturn);
6226+
BuilderZ.SetInsertPoint(newCall->getNextNode());
6227+
gutils->erase(newCall);
6228+
} else if (!orig->mayWriteToMemory() ||
6229+
Mode == DerivativeMode::ReverseModeGradient)
6230+
eraseIfUnused(*orig, /*erase*/ true, /*check*/ false);
6231+
}
6232+
return;
61716233
}
6172-
return;
61736234
}
61746235

61756236
if (Mode != DerivativeMode::ReverseModePrimal && called) {
@@ -7875,6 +7936,9 @@ class AdjointGenerator
78757936
argsInverted.push_back(DIFFE_TYPE::DUP_ARG);
78767937
}
78777938
}
7939+
if (!called)
7940+
llvm::errs() << *called << "\n";
7941+
assert(called);
78787942

78797943
auto newcalled = gutils->Logic.CreateForwardDiff(
78807944
cast<Function>(called), subretType, argsInverted, gutils->TLI,

enzyme/Enzyme/CApi.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,8 @@ void EnzymeRegisterFunctionHandler(char *Name, CustomShadowAlloc AHandle,
258258
};
259259
}
260260

261-
void EnzymeRegisterCallHandler(char *Name, CustomFunctionForward FwdHandle,
261+
void EnzymeRegisterCallHandler(char *Name,
262+
CustomAugmentedFunctionForward FwdHandle,
262263
CustomFunctionReverse RevHandle) {
263264
auto &pair = customCallHandlers[std::string(Name)];
264265
pair.first = [=](IRBuilder<> &B, CallInst *CI, GradientUtils &gutils,
@@ -277,6 +278,18 @@ void EnzymeRegisterCallHandler(char *Name, CustomFunctionForward FwdHandle,
277278
};
278279
}
279280

281+
// Registers FwdHandle as the custom forward-mode handler for calls to the
// function named Name by storing an adapter closure in customFwdCallHandlers.
// Because the map slot is obtained with operator[] and then assigned, any
// handler previously registered under the same name is replaced.
//
// The adapter bridges the C++ handler signature to the C API callback: it
// wraps the builder, the call instruction and both return values into C
// handles, invokes FwdHandle, and then writes whatever the callback left in
// the normal/shadow handles back through the reference parameters.
void EnzymeRegisterFwdCallHandler(char *Name, CustomFunctionForward FwdHandle) {
282+
// NOTE(review): despite its name, `pair` refers to a single std::function
// map entry here, not a std::pair.
auto &pair = customFwdCallHandlers[std::string(Name)];
283+
pair = [=](IRBuilder<> &B, CallInst *CI, GradientUtils &gutils,
284+
Value *&normalReturn, Value *&shadowReturn) {
285+
// Convert both returns to C handles so the callback can read and replace them.
LLVMValueRef normalR = wrap(normalReturn);
286+
LLVMValueRef shadowR = wrap(shadowReturn);
287+
FwdHandle(wrap(&B), wrap(CI), &gutils, &normalR, &shadowR);
288+
// Propagate the (possibly updated) handles back to the caller's references.
normalReturn = unwrap(normalR);
289+
shadowReturn = unwrap(shadowR);
290+
};
291+
}
292+
280293
LLVMValueRef EnzymeGradientUtilsNewFromOriginal(GradientUtils *gutils,
281294
LLVMValueRef val) {
282295
return wrap(gutils->getNewFromOriginal(unwrap(val)));

enzyme/Enzyme/CApi.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,11 @@ class DiffeGradientUtils;
175175

176176
typedef void (*CustomFunctionForward)(LLVMBuilderRef, LLVMValueRef,
177177
GradientUtils *, LLVMValueRef *,
178-
LLVMValueRef *, LLVMValueRef *);
178+
LLVMValueRef *);
179+
180+
typedef void (*CustomAugmentedFunctionForward)(LLVMBuilderRef, LLVMValueRef,
181+
GradientUtils *, LLVMValueRef *,
182+
LLVMValueRef *, LLVMValueRef *);
179183

180184
typedef void (*CustomFunctionReverse)(LLVMBuilderRef, LLVMValueRef,
181185
DiffeGradientUtils *, LLVMValueRef);

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3757,8 +3757,6 @@ Function *EnzymeLogic::CreateForwardDiff(
37573757
}
37583758
}
37593759

3760-
assert(!todiff->empty());
3761-
37623760
if (hasMetadata(todiff, "enzyme_derivative") && !hasconstant) {
37633761
auto md = todiff->getMetadata("enzyme_derivative");
37643762
if (!isa<MDTuple>(md)) {
@@ -3774,6 +3772,9 @@ Function *EnzymeLogic::CreateForwardDiff(
37743772

37753773
return foundcalled;
37763774
}
3775+
if (todiff->empty())
3776+
llvm::errs() << *todiff << "\n";
3777+
assert(!todiff->empty());
37773778

37783779
bool retActive = retType != DIFFE_TYPE::CONSTANT;
37793780

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ std::map<
5959
DiffeGradientUtils &, Value *)>>>
6060
customCallHandlers;
6161

62+
std::map<std::string, std::function<void(IRBuilder<> &, CallInst *,
63+
GradientUtils &, Value *&, Value *&)>>
64+
customFwdCallHandlers;
65+
6266
extern "C" {
6367
llvm::cl::opt<bool>
6468
EnzymeNewCache("enzyme-new-cache", cl::init(true), cl::Hidden,

enzyme/Enzyme/GradientUtils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,12 @@ extern std::map<
8888
DiffeGradientUtils &, llvm::Value *)>>>
8989
customCallHandlers;
9090

91+
extern std::map<
92+
std::string,
93+
std::function<void(llvm::IRBuilder<> &, llvm::CallInst *, GradientUtils &,
94+
llvm::Value *&, llvm::Value *&)>>
95+
customFwdCallHandlers;
96+
9197
extern "C" {
9298
extern llvm::cl::opt<bool> EnzymeInactiveDynamic;
9399
extern llvm::cl::opt<bool> EnzymeFreeInternalAllocations;

enzyme/test/Enzyme/ForwardMode/calloc.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -simplifycfg -S | FileCheck %s
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s
22

33

44
@enzyme_dupnoneed = dso_local global i32 0, align 4
@@ -37,4 +37,4 @@ declare dso_local double @__enzyme_fwddiff(i8*, ...)
3737
; CHECK-NEXT: store double %"x'", double* %"'ipc", align 8
3838
; CHECK-NEXT: %2 = load double, double* %"'ipc", align 8
3939
; CHECK-NEXT: ret double %2
40-
; CHECK-NEXT: }
40+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)