Skip to content

Commit 235fc1d

Browse files
authored
Keep integer extract (rust-lang#914)
* Keep integer extract * Fixup
1 parent 2a8ba2d commit 235fc1d

File tree

4 files changed

+69
-52
lines changed

4 files changed

+69
-52
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 49 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,6 +1132,52 @@ class AdjointGenerator
11321132
bool constantval = gutils->isConstantValue(orig_val) ||
11331133
parseTBAA(I, DL).Inner0().isIntegral();
11341134

1135+
// TODO allow recognition of other types that could contain pointers [e.g.
1136+
// {void*, void*} or <2 x i64> ]
1137+
auto storeSize = DL.getTypeSizeInBits(valType) / 8;
1138+
1139+
auto vd = TR.query(orig_ptr).Lookup(storeSize, DL);
1140+
1141+
if (!vd.isKnown()) {
1142+
if (looseTypeAnalysis || true) {
1143+
vd = defaultTypeTreeForLLVM(valType, &I);
1144+
EmitWarning("CannotDeduceType", I, "failed to deduce type of xtore ",
1145+
I);
1146+
goto known;
1147+
}
1148+
if (CustomErrorHandler) {
1149+
std::string str;
1150+
raw_string_ostream ss(str);
1151+
ss << "Cannot deduce type of store " << I;
1152+
CustomErrorHandler(str.c_str(), wrap(&I), ErrorType::NoType,
1153+
&TR.analyzer);
1154+
}
1155+
EmitFailure("CannotDeduceType", I.getDebugLoc(), &I,
1156+
"failed to deduce type of store ", I);
1157+
1158+
TR.intType(storeSize, orig_ptr, /*errifnotfound*/ true,
1159+
/*pointerIntSame*/ true);
1160+
llvm_unreachable("bad mti");
1161+
known:;
1162+
}
1163+
1164+
auto dt = vd[{-1}];
1165+
for (size_t i = 0; i < storeSize; ++i) {
1166+
bool Legal = true;
1167+
dt.checkedOrIn(vd[{(int)i}], /*PointerIntSame*/ true, Legal);
1168+
if (!Legal) {
1169+
if (CustomErrorHandler) {
1170+
std::string str;
1171+
raw_string_ostream ss(str);
1172+
ss << "Cannot deduce single type of store " << I;
1173+
CustomErrorHandler(str.c_str(), wrap(&I), ErrorType::NoType,
1174+
&TR.analyzer);
1175+
}
1176+
EmitFailure("CannotDeduceType", I.getDebugLoc(), &I,
1177+
"failed to deduce single type of store ", I);
1178+
}
1179+
}
1180+
11351181
if (Mode == DerivativeMode::ForwardMode) {
11361182
IRBuilder<> Builder2(&I);
11371183
getForwardBuilder(Builder2);
@@ -1140,7 +1186,8 @@ class AdjointGenerator
11401186
// TODO type analyze
11411187
if (!constantval)
11421188
diff = gutils->invertPointerM(orig_val, Builder2, /*nullShadow*/ true);
1143-
else if (orig_val->getType()->isPointerTy())
1189+
else if (orig_val->getType()->isPointerTy() || dt == BaseType::Pointer ||
1190+
dt == BaseType::Integer)
11441191
diff = gutils->invertPointerM(orig_val, Builder2, /*nullShadow*/ false);
11451192
else
11461193
diff = gutils->invertPointerM(orig_val, Builder2, /*nullShadow*/ true);
@@ -1150,41 +1197,8 @@ class AdjointGenerator
11501197
return;
11511198
}
11521199

1153-
// TODO allow recognition of other types that could contain pointers [e.g.
1154-
// {void*, void*} or <2 x i64> ]
1155-
auto storeSize = DL.getTypeSizeInBits(valType) / 8;
1156-
11571200
//! Storing a floating point value
1158-
Type *FT = nullptr;
1159-
if (valType->isFPOrFPVectorTy()) {
1160-
FT = valType->getScalarType();
1161-
} else if (!valType->isPointerTy()) {
1162-
auto fp =
1163-
TR.firstPointer(storeSize, orig_ptr, &I, /*errifnotfound*/ false,
1164-
/*pointerIntSame*/ true);
1165-
if (fp.isKnown()) {
1166-
FT = fp.isFloat();
1167-
} else if (looseTypeAnalysis && (isa<ConstantInt>(orig_val) ||
1168-
valType->isIntOrIntVectorTy())) {
1169-
llvm::errs() << "assuming type as integral for store: " << I << "\n";
1170-
FT = nullptr;
1171-
} else {
1172-
1173-
if (CustomErrorHandler) {
1174-
std::string str;
1175-
raw_string_ostream ss(str);
1176-
ss << "Cannot deduce type of store " << I;
1177-
CustomErrorHandler(str.c_str(), wrap(&I), ErrorType::NoType,
1178-
&TR.analyzer);
1179-
}
1180-
EmitFailure("CannotDeduceType", I.getDebugLoc(), &I,
1181-
"failed to deduce type of store ", I);
1182-
TR.firstPointer(storeSize, orig_ptr, &I, /*errifnotfound*/ true,
1183-
/*pointerIntSame*/ true);
1184-
}
1185-
}
1186-
1187-
if (FT) {
1201+
if (Type *FT = dt.isFloat()) {
11881202
//! Only need to update the reverse function
11891203
switch (Mode) {
11901204
case DerivativeMode::ReverseModePrimal:

enzyme/Enzyme/DifferentialUseAnalysis.h

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -722,8 +722,10 @@ static inline bool is_value_needed_in_reverse(
722722
}
723723

724724
if (isa<ReturnInst>(user)) {
725-
if (gutils->ATA->ActiveReturns == DIFFE_TYPE::DUP_ARG ||
726-
gutils->ATA->ActiveReturns == DIFFE_TYPE::DUP_NONEED) {
725+
if ((gutils->ATA->ActiveReturns == DIFFE_TYPE::DUP_ARG ||
726+
gutils->ATA->ActiveReturns == DIFFE_TYPE::DUP_NONEED) &&
727+
((inst_cv && VT == ValueType::Primal) ||
728+
(!inst_cv && VT == ValueType::Shadow))) {
727729
if (EnzymePrintDiffUse)
728730
llvm::errs() << " Need: " << to_string(VT) << " of " << *inst
729731
<< " in reverse as shadow return " << *user << "\n";
@@ -755,10 +757,10 @@ static inline bool is_value_needed_in_reverse(
755757
if (user->getType()->isVoidTy())
756758
goto endShadow;
757759

758-
if (!TR.query(const_cast<Instruction *>(user))
759-
.Inner0()
760-
.isPossiblePointer())
760+
if (!TR.query(const_cast<Instruction *>(user))[{-1}]
761+
.isPossiblePointer()) {
761762
goto endShadow;
763+
}
762764

763765
if (!OneLevel && is_value_needed_in_reverse<ValueType::Shadow>(
764766
gutils, user, mode, seen, oldUnreachable)) {
@@ -884,14 +886,20 @@ static inline bool is_value_needed_in_reverse(
884886
bool valueIsIndex = false;
885887
for (unsigned i = 2; i < IVI->getNumOperands(); ++i) {
886888
if (IVI->getOperand(i) == inst) {
889+
if (inst == IVI->getInsertedValueOperand() &&
890+
TR.query(
891+
const_cast<Value *>(IVI->getInsertedValueOperand()))[{-1}]
892+
.isFloat()) {
893+
continue;
894+
}
887895
valueIsIndex = true;
888896
}
889897
}
890898
primalUsedInShadowPointer = valueIsIndex;
891899
}
892900
if (auto EVI = dyn_cast<ExtractValueInst>(user)) {
893901
bool valueIsIndex = false;
894-
for (unsigned i = 2; i < EVI->getNumOperands(); ++i) {
902+
for (unsigned i = 1; i < EVI->getNumOperands(); ++i) {
895903
if (EVI->getOperand(i) == inst) {
896904
valueIsIndex = true;
897905
}

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2828,11 +2828,7 @@ void createTerminator(DiffeGradientUtils *gutils, BasicBlock *oBB,
28282828
} else if (!gutils->isConstantValue(ret)) {
28292829
toret = gutils->diffe(ret, nBuilder);
28302830
} else {
2831-
IRBuilder<> eB(gutils->inversionAllocs);
2832-
Type *retTy = gutils->getShadowType(ret->getType());
2833-
auto al = eB.CreateAlloca(retTy);
2834-
ZeroMemory(eB, retTy, al, /*isTape*/ false);
2835-
toret = nBuilder.CreateLoad(al);
2831+
toret = gutils->invertPointerM(ret, nBuilder, /*nullInit*/ true);
28362832
}
28372833

28382834
break;
@@ -2853,9 +2849,8 @@ void createTerminator(DiffeGradientUtils *gutils, BasicBlock *oBB,
28532849
toret =
28542850
nBuilder.CreateInsertValue(toret, gutils->diffe(ret, nBuilder), 1);
28552851
} else {
2856-
Type *retTy = gutils->getShadowType(ret->getType());
2857-
toret =
2858-
nBuilder.CreateInsertValue(toret, Constant::getNullValue(retTy), 1);
2852+
toret = nBuilder.CreateInsertValue(
2853+
toret, gutils->invertPointerM(ret, nBuilder, /*nullInit*/ true), 1);
28592854
}
28602855
break;
28612856
}

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4060,7 +4060,8 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,
40604060
return applyChainRule(oval->getType(), BuilderM, rule);
40614061
}
40624062

4063-
if (isConstantValue(oval)) {
4063+
if (isConstantValue(oval) && !isa<InsertValueInst>(oval) &&
4064+
!isa<ExtractValueInst>(oval)) {
40644065
// NOTE, this is legal and the correct resolution, however, our activity
40654066
// analysis honeypot no longer exists
40664067

@@ -4084,7 +4085,6 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,
40844085

40854086
return applyChainRule(oval->getType(), BuilderM, rule);
40864087
}
4087-
assert(!isConstantValue(oval));
40884088

40894089
auto M = oldFunc->getParent();
40904090
assert(oval);
@@ -4477,7 +4477,7 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,
44774477
goto end;
44784478
} else if (auto arg = dyn_cast<ExtractValueInst>(oval)) {
44794479
IRBuilder<> bb(getNewFromOriginal(arg));
4480-
auto ip = invertPointerM(arg->getOperand(0), bb);
4480+
auto ip = invertPointerM(arg->getOperand(0), bb, nullShadow);
44814481

44824482
auto rule = [&bb, &arg](Value *ip) {
44834483
return bb.CreateExtractValue(ip, arg->getIndices(),

0 commit comments

Comments
 (0)