@@ -10226,9 +10226,11 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
10226
10226
for (const TreeEntry *TE : ForRemoval)
10227
10227
Set.erase(TE);
10228
10228
}
10229
+ bool NeedToRemapValues = false;
10229
10230
for (auto *It = UsedTEs.begin(); It != UsedTEs.end();) {
10230
10231
if (It->empty()) {
10231
10232
UsedTEs.erase(It);
10233
+ NeedToRemapValues = true;
10232
10234
continue;
10233
10235
}
10234
10236
std::advance(It, 1);
@@ -10237,6 +10239,19 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
10237
10239
Entries.clear();
10238
10240
return std::nullopt;
10239
10241
}
10242
+ // Recalculate the mapping between the values and entries sets.
10243
+ if (NeedToRemapValues) {
10244
+ DenseMap<Value *, int> PrevUsedValuesEntry;
10245
+ PrevUsedValuesEntry.swap(UsedValuesEntry);
10246
+ for (auto [Idx, Set] : enumerate(UsedTEs)) {
10247
+ DenseSet<Value *> Values;
10248
+ for (const TreeEntry *E : Set)
10249
+ Values.insert(E->Scalars.begin(), E->Scalars.end());
10250
+ for (const auto &P : PrevUsedValuesEntry)
10251
+ if (Values.contains(P.first))
10252
+ UsedValuesEntry.try_emplace(P.first, Idx);
10253
+ }
10254
+ }
10240
10255
}
10241
10256
10242
10257
unsigned VF = 0;
@@ -11935,7 +11950,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
11935
11950
Builder.SetCurrentDebugLocation(PH->getDebugLoc());
11936
11951
Value *Vec = vectorizeOperand(E, I, /*PostponedPHIs=*/true);
11937
11952
if (VecTy != Vec->getType()) {
11938
- assert((getOperandEntry(E, I)->State == TreeEntry::NeedToGather ||
11953
+ assert((It != MinBWs.end() ||
11954
+ getOperandEntry(E, I)->State == TreeEntry::NeedToGather ||
11939
11955
MinBWs.contains(getOperandEntry(E, I))) &&
11940
11956
"Expected item in MinBWs.");
11941
11957
Vec = Builder.CreateIntCast(Vec, VecTy, It->second.second);
@@ -12193,7 +12209,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
12193
12209
return E->VectorizedValue;
12194
12210
}
12195
12211
if (L->getType() != R->getType()) {
12196
- assert((getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
12212
+ assert((It != MinBWs.end() ||
12213
+ getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
12197
12214
getOperandEntry(E, 1)->State == TreeEntry::NeedToGather ||
12198
12215
MinBWs.contains(getOperandEntry(E, 0)) ||
12199
12216
MinBWs.contains(getOperandEntry(E, 1))) &&
@@ -12232,7 +12249,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
12232
12249
return E->VectorizedValue;
12233
12250
}
12234
12251
if (True->getType() != False->getType()) {
12235
- assert((getOperandEntry(E, 1)->State == TreeEntry::NeedToGather ||
12252
+ assert((It != MinBWs.end() ||
12253
+ getOperandEntry(E, 1)->State == TreeEntry::NeedToGather ||
12236
12254
getOperandEntry(E, 2)->State == TreeEntry::NeedToGather ||
12237
12255
MinBWs.contains(getOperandEntry(E, 1)) ||
12238
12256
MinBWs.contains(getOperandEntry(E, 2))) &&
@@ -12302,7 +12320,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
12302
12320
return E->VectorizedValue;
12303
12321
}
12304
12322
if (LHS->getType() != RHS->getType()) {
12305
- assert((getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
12323
+ assert((It != MinBWs.end() ||
12324
+ getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
12306
12325
getOperandEntry(E, 1)->State == TreeEntry::NeedToGather ||
12307
12326
MinBWs.contains(getOperandEntry(E, 0)) ||
12308
12327
MinBWs.contains(getOperandEntry(E, 1))) &&
@@ -12540,7 +12559,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
12540
12559
return E->VectorizedValue;
12541
12560
}
12542
12561
if (LHS && RHS && LHS->getType() != RHS->getType()) {
12543
- assert((getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
12562
+ assert((It != MinBWs.end() ||
12563
+ getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
12544
12564
getOperandEntry(E, 1)->State == TreeEntry::NeedToGather ||
12545
12565
MinBWs.contains(getOperandEntry(E, 0)) ||
12546
12566
MinBWs.contains(getOperandEntry(E, 1))) &&
@@ -14002,6 +14022,33 @@ bool BoUpSLP::collectValuesToDemote(
14002
14022
};
14003
14023
unsigned Start = 0;
14004
14024
unsigned End = I->getNumOperands();
14025
+
14026
+ auto FinalAnalysis = [&](const TreeEntry *ITE = nullptr) {
14027
+ if (!IsProfitableToDemote)
14028
+ return false;
14029
+ return (ITE && ITE->UserTreeIndices.size() > 1) ||
14030
+ IsPotentiallyTruncated(I, BitWidth);
14031
+ };
14032
+ auto ProcessOperands = [&](ArrayRef<Value *> Operands, bool &NeedToExit) {
14033
+ NeedToExit = false;
14034
+ unsigned InitLevel = MaxDepthLevel;
14035
+ for (Value *IncValue : Operands) {
14036
+ unsigned Level = InitLevel;
14037
+ if (!collectValuesToDemote(IncValue, IsProfitableToDemoteRoot, BitWidth,
14038
+ ToDemote, DemotedConsts, Visited, Level,
14039
+ IsProfitableToDemote, IsTruncRoot)) {
14040
+ if (!IsProfitableToDemote)
14041
+ return false;
14042
+ NeedToExit = true;
14043
+ if (!FinalAnalysis(ITE))
14044
+ return false;
14045
+ continue;
14046
+ }
14047
+ MaxDepthLevel = std::max(MaxDepthLevel, Level);
14048
+ }
14049
+ return true;
14050
+ };
14051
+ bool NeedToExit = false;
14005
14052
switch (I->getOpcode()) {
14006
14053
14007
14054
// We can always demote truncations and extensions. Since truncations can
@@ -14027,35 +14074,21 @@ bool BoUpSLP::collectValuesToDemote(
14027
14074
case Instruction::And:
14028
14075
case Instruction::Or:
14029
14076
case Instruction::Xor: {
14030
- unsigned Level1, Level2;
14031
- if ((ITE->UserTreeIndices.size() > 1 &&
14032
- !IsPotentiallyTruncated(I, BitWidth)) ||
14033
- !collectValuesToDemote(I->getOperand(0), IsProfitableToDemoteRoot,
14034
- BitWidth, ToDemote, DemotedConsts, Visited,
14035
- Level1, IsProfitableToDemote, IsTruncRoot) ||
14036
- !collectValuesToDemote(I->getOperand(1), IsProfitableToDemoteRoot,
14037
- BitWidth, ToDemote, DemotedConsts, Visited,
14038
- Level2, IsProfitableToDemote, IsTruncRoot))
14077
+ if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
14078
+ return false;
14079
+ if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit))
14039
14080
return false;
14040
- MaxDepthLevel = std::max(Level1, Level2);
14041
14081
break;
14042
14082
}
14043
14083
14044
14084
// We can demote selects if we can demote their true and false values.
14045
14085
case Instruction::Select: {
14086
+ if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
14087
+ return false;
14046
14088
Start = 1;
14047
- unsigned Level1, Level2;
14048
- SelectInst *SI = cast<SelectInst>(I);
14049
- if ((ITE->UserTreeIndices.size() > 1 &&
14050
- !IsPotentiallyTruncated(I, BitWidth)) ||
14051
- !collectValuesToDemote(SI->getTrueValue(), IsProfitableToDemoteRoot,
14052
- BitWidth, ToDemote, DemotedConsts, Visited,
14053
- Level1, IsProfitableToDemote, IsTruncRoot) ||
14054
- !collectValuesToDemote(SI->getFalseValue(), IsProfitableToDemoteRoot,
14055
- BitWidth, ToDemote, DemotedConsts, Visited,
14056
- Level2, IsProfitableToDemote, IsTruncRoot))
14089
+ auto *SI = cast<SelectInst>(I);
14090
+ if (!ProcessOperands({SI->getTrueValue(), SI->getFalseValue()}, NeedToExit))
14057
14091
return false;
14058
- MaxDepthLevel = std::max(Level1, Level2);
14059
14092
break;
14060
14093
}
14061
14094
@@ -14066,22 +14099,20 @@ bool BoUpSLP::collectValuesToDemote(
14066
14099
MaxDepthLevel = 0;
14067
14100
if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
14068
14101
return false;
14069
- for (Value *IncValue : PN->incoming_values()) {
14070
- unsigned Level;
14071
- if (!collectValuesToDemote(IncValue, IsProfitableToDemoteRoot, BitWidth,
14072
- ToDemote, DemotedConsts, Visited, Level,
14073
- IsProfitableToDemote, IsTruncRoot))
14074
- return false;
14075
- MaxDepthLevel = std::max(MaxDepthLevel, Level);
14076
- }
14102
+ SmallVector<Value *> Ops(PN->incoming_values().begin(),
14103
+ PN->incoming_values().end());
14104
+ if (!ProcessOperands(Ops, NeedToExit))
14105
+ return false;
14077
14106
break;
14078
14107
}
14079
14108
14080
14109
// Otherwise, conservatively give up.
14081
14110
default:
14082
14111
MaxDepthLevel = 1;
14083
- return IsProfitableToDemote && IsPotentiallyTruncated(I, BitWidth );
14112
+ return FinalAnalysis( );
14084
14113
}
14114
+ if (NeedToExit)
14115
+ return true;
14085
14116
14086
14117
++MaxDepthLevel;
14087
14118
// Gather demoted constant operands.
@@ -14120,15 +14151,17 @@ void BoUpSLP::computeMinimumValueSizes() {
14120
14151
14121
14152
// The first value node for store/insertelement is sext/zext/trunc? Skip it,
14122
14153
// resize to the final type.
14154
+ bool IsTruncRoot = false;
14123
14155
bool IsProfitableToDemoteRoot = !IsStoreOrInsertElt;
14124
14156
if (NodeIdx != 0 &&
14125
14157
VectorizableTree[NodeIdx]->State == TreeEntry::Vectorize &&
14126
14158
(VectorizableTree[NodeIdx]->getOpcode() == Instruction::ZExt ||
14127
14159
VectorizableTree[NodeIdx]->getOpcode() == Instruction::SExt ||
14128
14160
VectorizableTree[NodeIdx]->getOpcode() == Instruction::Trunc)) {
14129
14161
assert(IsStoreOrInsertElt && "Expected store/insertelement seeded graph.");
14130
- ++ NodeIdx;
14162
+ IsTruncRoot = VectorizableTree[ NodeIdx]->getOpcode() == Instruction::Trunc ;
14131
14163
IsProfitableToDemoteRoot = true;
14164
+ ++NodeIdx;
14132
14165
}
14133
14166
14134
14167
// Analyzed in reduction already and not profitable - exit.
@@ -14260,7 +14293,6 @@ void BoUpSLP::computeMinimumValueSizes() {
14260
14293
ReductionBitWidth = bit_ceil(ReductionBitWidth);
14261
14294
}
14262
14295
bool IsTopRoot = NodeIdx == 0;
14263
- bool IsTruncRoot = false;
14264
14296
while (NodeIdx < VectorizableTree.size() &&
14265
14297
VectorizableTree[NodeIdx]->State == TreeEntry::Vectorize &&
14266
14298
VectorizableTree[NodeIdx]->getOpcode() == Instruction::Trunc) {
0 commit comments