@@ -2231,7 +2231,13 @@ class AdjointGenerator
2231
2231
IRBuilder<> Builder2(&MTI);
2232
2232
getForwardBuilder(Builder2);
2233
2233
auto ddst = gutils->invertPointerM(orig_dst, Builder2);
2234
+ if (ddst->getType()->isIntegerTy())
2235
+ ddst = Builder2.CreateIntToPtr(ddst,
2236
+ Type::getInt8PtrTy(ddst->getContext()));
2234
2237
auto dsrc = gutils->invertPointerM(orig_src, Builder2);
2238
+ if (dsrc->getType()->isIntegerTy())
2239
+ dsrc = Builder2.CreateIntToPtr(dsrc,
2240
+ Type::getInt8PtrTy(dsrc->getContext()));
2235
2241
2236
2242
auto call =
2237
2243
Builder2.CreateMemCpy(ddst, dstAlign, dsrc, srcAlign, new_size);
@@ -6059,59 +6065,21 @@ class AdjointGenerator
6059
6065
subretType = DIFFE_TYPE::OUT_DIFF;
6060
6066
}
6061
6067
6062
- auto found = customCallHandlers.find(funcName.str());
6063
- if (found != customCallHandlers.end()) {
6064
- IRBuilder<> Builder2(call.getParent());
6065
- if (Mode == DerivativeMode::ReverseModeGradient ||
6066
- Mode == DerivativeMode::ReverseModeCombined)
6067
- getReverseBuilder(Builder2);
6068
-
6069
- Value *invertedReturn = nullptr;
6070
- bool hasNonReturnUse = false;
6071
- auto ifound = gutils->invertedPointers.find(orig);
6072
- if (ifound != gutils->invertedPointers.end()) {
6073
- //! We only need the shadow pointer for non-forward Mode if it is used
6074
- //! in a non return setting
6075
- hasNonReturnUse = subretType == DIFFE_TYPE::DUP_ARG;
6076
- if (hasNonReturnUse)
6068
+ if (Mode == DerivativeMode::ForwardMode) {
6069
+ auto found = customFwdCallHandlers.find(funcName.str());
6070
+ if (found != customFwdCallHandlers.end()) {
6071
+ Value *invertedReturn = nullptr;
6072
+ auto ifound = gutils->invertedPointers.find(orig);
6073
+ if (ifound != gutils->invertedPointers.end()) {
6077
6074
invertedReturn = cast<PHINode>(&*ifound->second);
6078
- }
6075
+ }
6079
6076
6080
- Value *normalReturn = subretused ? newCall : nullptr;
6077
+ Value *normalReturn = subretused ? newCall : nullptr;
6081
6078
6082
- Value *tape = nullptr ;
6079
+ found->second(BuilderZ, orig, *gutils, normalReturn, invertedReturn) ;
6083
6080
6084
- if (Mode == DerivativeMode::ReverseModePrimal ||
6085
- Mode == DerivativeMode::ReverseModeCombined) {
6086
- found->second.first(BuilderZ, orig, *gutils, normalReturn,
6087
- invertedReturn, tape);
6088
- if (tape)
6089
- gutils->cacheForReverse(BuilderZ, tape,
6090
- getIndex(orig, CacheType::Tape));
6091
- }
6092
-
6093
- if (Mode == DerivativeMode::ReverseModeGradient ||
6094
- Mode == DerivativeMode::ReverseModeCombined) {
6095
- if (Mode == DerivativeMode::ReverseModeGradient &&
6096
- augmentedReturn->tapeIndices.find(std::make_pair(
6097
- orig, CacheType::Tape)) != augmentedReturn->tapeIndices.end()) {
6098
- tape = BuilderZ.CreatePHI(Type::getInt32Ty(orig->getContext()), 0);
6099
- tape = gutils->cacheForReverse(BuilderZ, tape,
6100
- getIndex(orig, CacheType::Tape),
6101
- /*ignoreType*/ true);
6102
- }
6103
- if (tape)
6104
- tape = gutils->lookupM(tape, Builder2);
6105
- found->second.second(Builder2, orig, *(DiffeGradientUtils *)gutils,
6106
- tape);
6107
- }
6108
-
6109
- if (ifound != gutils->invertedPointers.end()) {
6110
- auto placeholder = cast<PHINode>(&*ifound->second);
6111
- if (!hasNonReturnUse) {
6112
- gutils->invertedPointers.erase(ifound);
6113
- gutils->erase(placeholder);
6114
- } else {
6081
+ if (ifound != gutils->invertedPointers.end()) {
6082
+ auto placeholder = cast<PHINode>(&*ifound->second);
6115
6083
if (invertedReturn && invertedReturn != placeholder) {
6116
6084
if (invertedReturn->getType() != orig->getType()) {
6117
6085
llvm::errs() << " o: " << *orig << "\n";
@@ -6126,50 +6094,143 @@ class AdjointGenerator
6126
6094
assert(invertedReturn->getType() == orig->getType());
6127
6095
placeholder->replaceAllUsesWith(invertedReturn);
6128
6096
gutils->erase(placeholder);
6129
- } else
6130
- invertedReturn = placeholder;
6131
-
6132
- invertedReturn = gutils->cacheForReverse(
6133
- BuilderZ, invertedReturn, getIndex(orig, CacheType::Shadow));
6134
-
6135
- gutils->invertedPointers.insert(std::make_pair(
6136
- (const Value *)orig, InvertedPointerVH(gutils, invertedReturn)));
6097
+ gutils->invertedPointers.insert(
6098
+ std::make_pair((const Value *)orig,
6099
+ InvertedPointerVH(gutils, invertedReturn)));
6100
+ } else {
6101
+ gutils->invertedPointers.erase(orig);
6102
+ gutils->erase(placeholder);
6103
+ }
6137
6104
}
6138
- }
6139
-
6140
- bool primalNeededInReverse;
6141
6105
6142
- if (gutils->knownRecomputeHeuristic.count(orig)) {
6143
- primalNeededInReverse = !gutils->knownRecomputeHeuristic[orig];
6144
- } else {
6145
- std::map<UsageKey, bool> Seen;
6146
- for (auto pair : gutils->knownRecomputeHeuristic)
6147
- if (!pair.second)
6148
- Seen[UsageKey(pair.first, ValueType::Primal)] = false;
6149
- primalNeededInReverse = is_value_needed_in_reverse<ValueType::Primal>(
6150
- TR, gutils, orig, Mode, Seen, oldUnreachable);
6151
- }
6152
- if (subretused && primalNeededInReverse) {
6153
- if (normalReturn != newCall) {
6154
- assert(normalReturn->getType() == newCall->getType());
6155
- gutils->replaceAWithB(newCall, normalReturn);
6156
- BuilderZ.SetInsertPoint(newCall->getNextNode());
6157
- gutils->erase(newCall);
6158
- }
6159
- normalReturn = gutils->cacheForReverse(BuilderZ, normalReturn,
6160
- getIndex(orig, CacheType::Self));
6161
- } else {
6162
6106
if (normalReturn && normalReturn != newCall) {
6163
6107
assert(normalReturn->getType() == newCall->getType());
6164
6108
assert(Mode != DerivativeMode::ReverseModeGradient);
6165
6109
gutils->replaceAWithB(newCall, normalReturn);
6166
- BuilderZ.SetInsertPoint(newCall->getNextNode());
6167
6110
gutils->erase(newCall);
6168
- } else if (!orig->mayWriteToMemory() ||
6169
- Mode == DerivativeMode::ReverseModeGradient)
6170
- eraseIfUnused(*orig, /*erase*/ true, /*check*/ false);
6111
+ }
6112
+ eraseIfUnused(*orig);
6113
+ return;
6114
+ }
6115
+ }
6116
+
6117
+ if (Mode == DerivativeMode::ReverseModePrimal ||
6118
+ Mode == DerivativeMode::ReverseModeCombined ||
6119
+ Mode == DerivativeMode::ReverseModeGradient) {
6120
+ auto found = customCallHandlers.find(funcName.str());
6121
+ if (found != customCallHandlers.end()) {
6122
+ IRBuilder<> Builder2(call.getParent());
6123
+ if (Mode == DerivativeMode::ReverseModeGradient ||
6124
+ Mode == DerivativeMode::ReverseModeCombined)
6125
+ getReverseBuilder(Builder2);
6126
+
6127
+ Value *invertedReturn = nullptr;
6128
+ bool hasNonReturnUse = false;
6129
+ auto ifound = gutils->invertedPointers.find(orig);
6130
+ if (ifound != gutils->invertedPointers.end()) {
6131
+ //! We only need the shadow pointer for non-forward Mode if it is used
6132
+ //! in a non return setting
6133
+ hasNonReturnUse = subretType == DIFFE_TYPE::DUP_ARG;
6134
+ if (hasNonReturnUse)
6135
+ invertedReturn = cast<PHINode>(&*ifound->second);
6136
+ }
6137
+
6138
+ Value *normalReturn = subretused ? newCall : nullptr;
6139
+
6140
+ Value *tape = nullptr;
6141
+
6142
+ if (Mode == DerivativeMode::ReverseModePrimal ||
6143
+ Mode == DerivativeMode::ReverseModeCombined) {
6144
+ found->second.first(BuilderZ, orig, *gutils, normalReturn,
6145
+ invertedReturn, tape);
6146
+ if (tape)
6147
+ gutils->cacheForReverse(BuilderZ, tape,
6148
+ getIndex(orig, CacheType::Tape));
6149
+ }
6150
+
6151
+ if (Mode == DerivativeMode::ReverseModeGradient ||
6152
+ Mode == DerivativeMode::ReverseModeCombined) {
6153
+ if (Mode == DerivativeMode::ReverseModeGradient &&
6154
+ augmentedReturn->tapeIndices.find(
6155
+ std::make_pair(orig, CacheType::Tape)) !=
6156
+ augmentedReturn->tapeIndices.end()) {
6157
+ tape = BuilderZ.CreatePHI(Type::getInt32Ty(orig->getContext()), 0);
6158
+ tape = gutils->cacheForReverse(BuilderZ, tape,
6159
+ getIndex(orig, CacheType::Tape),
6160
+ /*ignoreType*/ true);
6161
+ }
6162
+ if (tape)
6163
+ tape = gutils->lookupM(tape, Builder2);
6164
+ found->second.second(Builder2, orig, *(DiffeGradientUtils *)gutils,
6165
+ tape);
6166
+ }
6167
+
6168
+ if (ifound != gutils->invertedPointers.end()) {
6169
+ auto placeholder = cast<PHINode>(&*ifound->second);
6170
+ if (!hasNonReturnUse) {
6171
+ gutils->invertedPointers.erase(ifound);
6172
+ gutils->erase(placeholder);
6173
+ } else {
6174
+ if (invertedReturn && invertedReturn != placeholder) {
6175
+ if (invertedReturn->getType() != orig->getType()) {
6176
+ llvm::errs() << " o: " << *orig << "\n";
6177
+ llvm::errs() << " ot: " << *orig->getType() << "\n";
6178
+ llvm::errs() << " ir: " << *invertedReturn << "\n";
6179
+ llvm::errs() << " irt: " << *invertedReturn->getType() << "\n";
6180
+ llvm::errs() << " p: " << *placeholder << "\n";
6181
+ llvm::errs() << " PT: " << *placeholder->getType() << "\n";
6182
+ llvm::errs() << " newCall: " << *newCall << "\n";
6183
+ llvm::errs() << " newCallT: " << *newCall->getType() << "\n";
6184
+ }
6185
+ assert(invertedReturn->getType() == orig->getType());
6186
+ placeholder->replaceAllUsesWith(invertedReturn);
6187
+ gutils->erase(placeholder);
6188
+ } else
6189
+ invertedReturn = placeholder;
6190
+
6191
+ invertedReturn = gutils->cacheForReverse(
6192
+ BuilderZ, invertedReturn, getIndex(orig, CacheType::Shadow));
6193
+
6194
+ gutils->invertedPointers.insert(
6195
+ std::make_pair((const Value *)orig,
6196
+ InvertedPointerVH(gutils, invertedReturn)));
6197
+ }
6198
+ }
6199
+
6200
+ bool primalNeededInReverse;
6201
+
6202
+ if (gutils->knownRecomputeHeuristic.count(orig)) {
6203
+ primalNeededInReverse = !gutils->knownRecomputeHeuristic[orig];
6204
+ } else {
6205
+ std::map<UsageKey, bool> Seen;
6206
+ for (auto pair : gutils->knownRecomputeHeuristic)
6207
+ if (!pair.second)
6208
+ Seen[UsageKey(pair.first, ValueType::Primal)] = false;
6209
+ primalNeededInReverse = is_value_needed_in_reverse<ValueType::Primal>(
6210
+ TR, gutils, orig, Mode, Seen, oldUnreachable);
6211
+ }
6212
+ if (subretused && primalNeededInReverse) {
6213
+ if (normalReturn != newCall) {
6214
+ assert(normalReturn->getType() == newCall->getType());
6215
+ gutils->replaceAWithB(newCall, normalReturn);
6216
+ BuilderZ.SetInsertPoint(newCall->getNextNode());
6217
+ gutils->erase(newCall);
6218
+ }
6219
+ normalReturn = gutils->cacheForReverse(
6220
+ BuilderZ, normalReturn, getIndex(orig, CacheType::Self));
6221
+ } else {
6222
+ if (normalReturn && normalReturn != newCall) {
6223
+ assert(normalReturn->getType() == newCall->getType());
6224
+ assert(Mode != DerivativeMode::ReverseModeGradient);
6225
+ gutils->replaceAWithB(newCall, normalReturn);
6226
+ BuilderZ.SetInsertPoint(newCall->getNextNode());
6227
+ gutils->erase(newCall);
6228
+ } else if (!orig->mayWriteToMemory() ||
6229
+ Mode == DerivativeMode::ReverseModeGradient)
6230
+ eraseIfUnused(*orig, /*erase*/ true, /*check*/ false);
6231
+ }
6232
+ return;
6171
6233
}
6172
- return;
6173
6234
}
6174
6235
6175
6236
if (Mode != DerivativeMode::ReverseModePrimal && called) {
@@ -7875,6 +7936,9 @@ class AdjointGenerator
7875
7936
argsInverted.push_back(DIFFE_TYPE::DUP_ARG);
7876
7937
}
7877
7938
}
7939
+ if (!called)
7940
+ llvm::errs() << *called << "\n";
7941
+ assert(called);
7878
7942
7879
7943
auto newcalled = gutils->Logic.CreateForwardDiff(
7880
7944
cast<Function>(called), subretType, argsInverted, gutils->TLI,
0 commit comments