Skip to content

Commit e4b7724

Browse files
committed
[SLP] Do extra analysis in minbitwidth if some checks return false.
The instruction itself can be considered good for minbitwidth casting, even if one of the operand checks returns false. Reviewers: RKSimon. Reviewed By: RKSimon. Pull Request: #84363.
1 parent af2bf86 commit e4b7724

File tree

2 files changed

+80
-46
lines changed

2 files changed

+80
-46
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 69 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -10226,9 +10226,11 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
1022610226
for (const TreeEntry *TE : ForRemoval)
1022710227
Set.erase(TE);
1022810228
}
10229+
bool NeedToRemapValues = false;
1022910230
for (auto *It = UsedTEs.begin(); It != UsedTEs.end();) {
1023010231
if (It->empty()) {
1023110232
UsedTEs.erase(It);
10233+
NeedToRemapValues = true;
1023210234
continue;
1023310235
}
1023410236
std::advance(It, 1);
@@ -10237,6 +10239,19 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
1023710239
Entries.clear();
1023810240
return std::nullopt;
1023910241
}
10242+
// Recalculate the mapping between the values and entries sets.
10243+
if (NeedToRemapValues) {
10244+
DenseMap<Value *, int> PrevUsedValuesEntry;
10245+
PrevUsedValuesEntry.swap(UsedValuesEntry);
10246+
for (auto [Idx, Set] : enumerate(UsedTEs)) {
10247+
DenseSet<Value *> Values;
10248+
for (const TreeEntry *E : Set)
10249+
Values.insert(E->Scalars.begin(), E->Scalars.end());
10250+
for (const auto &P : PrevUsedValuesEntry)
10251+
if (Values.contains(P.first))
10252+
UsedValuesEntry.try_emplace(P.first, Idx);
10253+
}
10254+
}
1024010255
}
1024110256

1024210257
unsigned VF = 0;
@@ -11935,7 +11950,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1193511950
Builder.SetCurrentDebugLocation(PH->getDebugLoc());
1193611951
Value *Vec = vectorizeOperand(E, I, /*PostponedPHIs=*/true);
1193711952
if (VecTy != Vec->getType()) {
11938-
assert((getOperandEntry(E, I)->State == TreeEntry::NeedToGather ||
11953+
assert((It != MinBWs.end() ||
11954+
getOperandEntry(E, I)->State == TreeEntry::NeedToGather ||
1193911955
MinBWs.contains(getOperandEntry(E, I))) &&
1194011956
"Expected item in MinBWs.");
1194111957
Vec = Builder.CreateIntCast(Vec, VecTy, It->second.second);
@@ -12193,7 +12209,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1219312209
return E->VectorizedValue;
1219412210
}
1219512211
if (L->getType() != R->getType()) {
12196-
assert((getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
12212+
assert((It != MinBWs.end() ||
12213+
getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
1219712214
getOperandEntry(E, 1)->State == TreeEntry::NeedToGather ||
1219812215
MinBWs.contains(getOperandEntry(E, 0)) ||
1219912216
MinBWs.contains(getOperandEntry(E, 1))) &&
@@ -12232,7 +12249,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1223212249
return E->VectorizedValue;
1223312250
}
1223412251
if (True->getType() != False->getType()) {
12235-
assert((getOperandEntry(E, 1)->State == TreeEntry::NeedToGather ||
12252+
assert((It != MinBWs.end() ||
12253+
getOperandEntry(E, 1)->State == TreeEntry::NeedToGather ||
1223612254
getOperandEntry(E, 2)->State == TreeEntry::NeedToGather ||
1223712255
MinBWs.contains(getOperandEntry(E, 1)) ||
1223812256
MinBWs.contains(getOperandEntry(E, 2))) &&
@@ -12302,7 +12320,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1230212320
return E->VectorizedValue;
1230312321
}
1230412322
if (LHS->getType() != RHS->getType()) {
12305-
assert((getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
12323+
assert((It != MinBWs.end() ||
12324+
getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
1230612325
getOperandEntry(E, 1)->State == TreeEntry::NeedToGather ||
1230712326
MinBWs.contains(getOperandEntry(E, 0)) ||
1230812327
MinBWs.contains(getOperandEntry(E, 1))) &&
@@ -12540,7 +12559,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1254012559
return E->VectorizedValue;
1254112560
}
1254212561
if (LHS && RHS && LHS->getType() != RHS->getType()) {
12543-
assert((getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
12562+
assert((It != MinBWs.end() ||
12563+
getOperandEntry(E, 0)->State == TreeEntry::NeedToGather ||
1254412564
getOperandEntry(E, 1)->State == TreeEntry::NeedToGather ||
1254512565
MinBWs.contains(getOperandEntry(E, 0)) ||
1254612566
MinBWs.contains(getOperandEntry(E, 1))) &&
@@ -14002,6 +14022,33 @@ bool BoUpSLP::collectValuesToDemote(
1400214022
};
1400314023
unsigned Start = 0;
1400414024
unsigned End = I->getNumOperands();
14025+
14026+
auto FinalAnalysis = [&](const TreeEntry *ITE = nullptr) {
14027+
if (!IsProfitableToDemote)
14028+
return false;
14029+
return (ITE && ITE->UserTreeIndices.size() > 1) ||
14030+
IsPotentiallyTruncated(I, BitWidth);
14031+
};
14032+
auto ProcessOperands = [&](ArrayRef<Value *> Operands, bool &NeedToExit) {
14033+
NeedToExit = false;
14034+
unsigned InitLevel = MaxDepthLevel;
14035+
for (Value *IncValue : Operands) {
14036+
unsigned Level = InitLevel;
14037+
if (!collectValuesToDemote(IncValue, IsProfitableToDemoteRoot, BitWidth,
14038+
ToDemote, DemotedConsts, Visited, Level,
14039+
IsProfitableToDemote, IsTruncRoot)) {
14040+
if (!IsProfitableToDemote)
14041+
return false;
14042+
NeedToExit = true;
14043+
if (!FinalAnalysis(ITE))
14044+
return false;
14045+
continue;
14046+
}
14047+
MaxDepthLevel = std::max(MaxDepthLevel, Level);
14048+
}
14049+
return true;
14050+
};
14051+
bool NeedToExit = false;
1400514052
switch (I->getOpcode()) {
1400614053

1400714054
// We can always demote truncations and extensions. Since truncations can
@@ -14027,35 +14074,21 @@ bool BoUpSLP::collectValuesToDemote(
1402714074
case Instruction::And:
1402814075
case Instruction::Or:
1402914076
case Instruction::Xor: {
14030-
unsigned Level1, Level2;
14031-
if ((ITE->UserTreeIndices.size() > 1 &&
14032-
!IsPotentiallyTruncated(I, BitWidth)) ||
14033-
!collectValuesToDemote(I->getOperand(0), IsProfitableToDemoteRoot,
14034-
BitWidth, ToDemote, DemotedConsts, Visited,
14035-
Level1, IsProfitableToDemote, IsTruncRoot) ||
14036-
!collectValuesToDemote(I->getOperand(1), IsProfitableToDemoteRoot,
14037-
BitWidth, ToDemote, DemotedConsts, Visited,
14038-
Level2, IsProfitableToDemote, IsTruncRoot))
14077+
if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
14078+
return false;
14079+
if (!ProcessOperands({I->getOperand(0), I->getOperand(1)}, NeedToExit))
1403914080
return false;
14040-
MaxDepthLevel = std::max(Level1, Level2);
1404114081
break;
1404214082
}
1404314083

1404414084
// We can demote selects if we can demote their true and false values.
1404514085
case Instruction::Select: {
14086+
if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
14087+
return false;
1404614088
Start = 1;
14047-
unsigned Level1, Level2;
14048-
SelectInst *SI = cast<SelectInst>(I);
14049-
if ((ITE->UserTreeIndices.size() > 1 &&
14050-
!IsPotentiallyTruncated(I, BitWidth)) ||
14051-
!collectValuesToDemote(SI->getTrueValue(), IsProfitableToDemoteRoot,
14052-
BitWidth, ToDemote, DemotedConsts, Visited,
14053-
Level1, IsProfitableToDemote, IsTruncRoot) ||
14054-
!collectValuesToDemote(SI->getFalseValue(), IsProfitableToDemoteRoot,
14055-
BitWidth, ToDemote, DemotedConsts, Visited,
14056-
Level2, IsProfitableToDemote, IsTruncRoot))
14089+
auto *SI = cast<SelectInst>(I);
14090+
if (!ProcessOperands({SI->getTrueValue(), SI->getFalseValue()}, NeedToExit))
1405714091
return false;
14058-
MaxDepthLevel = std::max(Level1, Level2);
1405914092
break;
1406014093
}
1406114094

@@ -14066,22 +14099,20 @@ bool BoUpSLP::collectValuesToDemote(
1406614099
MaxDepthLevel = 0;
1406714100
if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
1406814101
return false;
14069-
for (Value *IncValue : PN->incoming_values()) {
14070-
unsigned Level;
14071-
if (!collectValuesToDemote(IncValue, IsProfitableToDemoteRoot, BitWidth,
14072-
ToDemote, DemotedConsts, Visited, Level,
14073-
IsProfitableToDemote, IsTruncRoot))
14074-
return false;
14075-
MaxDepthLevel = std::max(MaxDepthLevel, Level);
14076-
}
14102+
SmallVector<Value *> Ops(PN->incoming_values().begin(),
14103+
PN->incoming_values().end());
14104+
if (!ProcessOperands(Ops, NeedToExit))
14105+
return false;
1407714106
break;
1407814107
}
1407914108

1408014109
// Otherwise, conservatively give up.
1408114110
default:
1408214111
MaxDepthLevel = 1;
14083-
return IsProfitableToDemote && IsPotentiallyTruncated(I, BitWidth);
14112+
return FinalAnalysis();
1408414113
}
14114+
if (NeedToExit)
14115+
return true;
1408514116

1408614117
++MaxDepthLevel;
1408714118
// Gather demoted constant operands.
@@ -14120,15 +14151,17 @@ void BoUpSLP::computeMinimumValueSizes() {
1412014151

1412114152
// The first value node for store/insertelement is sext/zext/trunc? Skip it,
1412214153
// resize to the final type.
14154+
bool IsTruncRoot = false;
1412314155
bool IsProfitableToDemoteRoot = !IsStoreOrInsertElt;
1412414156
if (NodeIdx != 0 &&
1412514157
VectorizableTree[NodeIdx]->State == TreeEntry::Vectorize &&
1412614158
(VectorizableTree[NodeIdx]->getOpcode() == Instruction::ZExt ||
1412714159
VectorizableTree[NodeIdx]->getOpcode() == Instruction::SExt ||
1412814160
VectorizableTree[NodeIdx]->getOpcode() == Instruction::Trunc)) {
1412914161
assert(IsStoreOrInsertElt && "Expected store/insertelement seeded graph.");
14130-
++NodeIdx;
14162+
IsTruncRoot = VectorizableTree[NodeIdx]->getOpcode() == Instruction::Trunc;
1413114163
IsProfitableToDemoteRoot = true;
14164+
++NodeIdx;
1413214165
}
1413314166

1413414167
// Analyzed in reduction already and not profitable - exit.
@@ -14260,7 +14293,6 @@ void BoUpSLP::computeMinimumValueSizes() {
1426014293
ReductionBitWidth = bit_ceil(ReductionBitWidth);
1426114294
}
1426214295
bool IsTopRoot = NodeIdx == 0;
14263-
bool IsTruncRoot = false;
1426414296
while (NodeIdx < VectorizableTree.size() &&
1426514297
VectorizableTree[NodeIdx]->State == TreeEntry::Vectorize &&
1426614298
VectorizableTree[NodeIdx]->getOpcode() == Instruction::Trunc) {

llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ for.end: ; preds = %for.end.loopexit, %
228228
; YAML-NEXT: Function: test_unrolled_select
229229
; YAML-NEXT: Args:
230230
; YAML-NEXT: - String: 'Vectorized horizontal reduction with cost '
231-
; YAML-NEXT: - Cost: '-36'
231+
; YAML-NEXT: - Cost: '-40'
232232
; YAML-NEXT: - String: ' and with tree size '
233233
; YAML-NEXT: - TreeSize: '10'
234234

@@ -246,15 +246,17 @@ define i32 @test_unrolled_select(ptr noalias nocapture readonly %blk1, ptr noali
246246
; CHECK-NEXT: [[P2_045:%.*]] = phi ptr [ [[BLK2:%.*]], [[FOR_BODY_LR_PH]] ], [ [[ADD_PTR88:%.*]], [[IF_END_86]] ]
247247
; CHECK-NEXT: [[P1_044:%.*]] = phi ptr [ [[BLK1:%.*]], [[FOR_BODY_LR_PH]] ], [ [[ADD_PTR:%.*]], [[IF_END_86]] ]
248248
; CHECK-NEXT: [[TMP0:%.*]] = load <8 x i8>, ptr [[P1_044]], align 1
249-
; CHECK-NEXT: [[TMP1:%.*]] = zext <8 x i8> [[TMP0]] to <8 x i32>
249+
; CHECK-NEXT: [[TMP1:%.*]] = zext <8 x i8> [[TMP0]] to <8 x i16>
250250
; CHECK-NEXT: [[TMP2:%.*]] = load <8 x i8>, ptr [[P2_045]], align 1
251-
; CHECK-NEXT: [[TMP3:%.*]] = zext <8 x i8> [[TMP2]] to <8 x i32>
252-
; CHECK-NEXT: [[TMP4:%.*]] = sub nsw <8 x i32> [[TMP1]], [[TMP3]]
253-
; CHECK-NEXT: [[TMP5:%.*]] = icmp slt <8 x i32> [[TMP4]], zeroinitializer
254-
; CHECK-NEXT: [[TMP6:%.*]] = sub nsw <8 x i32> zeroinitializer, [[TMP4]]
255-
; CHECK-NEXT: [[TMP7:%.*]] = select <8 x i1> [[TMP5]], <8 x i32> [[TMP6]], <8 x i32> [[TMP4]]
256-
; CHECK-NEXT: [[TMP8:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP7]])
257-
; CHECK-NEXT: [[OP_RDX]] = add i32 [[TMP8]], [[S_047]]
251+
; CHECK-NEXT: [[TMP3:%.*]] = zext <8 x i8> [[TMP2]] to <8 x i16>
252+
; CHECK-NEXT: [[TMP4:%.*]] = sub <8 x i16> [[TMP1]], [[TMP3]]
253+
; CHECK-NEXT: [[TMP5:%.*]] = trunc <8 x i16> [[TMP4]] to <8 x i1>
254+
; CHECK-NEXT: [[TMP6:%.*]] = icmp slt <8 x i1> [[TMP5]], zeroinitializer
255+
; CHECK-NEXT: [[TMP7:%.*]] = sub <8 x i16> zeroinitializer, [[TMP4]]
256+
; CHECK-NEXT: [[TMP8:%.*]] = select <8 x i1> [[TMP6]], <8 x i16> [[TMP7]], <8 x i16> [[TMP4]]
257+
; CHECK-NEXT: [[TMP9:%.*]] = zext <8 x i16> [[TMP8]] to <8 x i32>
258+
; CHECK-NEXT: [[TMP10:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP9]])
259+
; CHECK-NEXT: [[OP_RDX]] = add i32 [[TMP10]], [[S_047]]
258260
; CHECK-NEXT: [[CMP83:%.*]] = icmp slt i32 [[OP_RDX]], [[LIM:%.*]]
259261
; CHECK-NEXT: br i1 [[CMP83]], label [[IF_END_86]], label [[FOR_END_LOOPEXIT:%.*]]
260262
; CHECK: if.end.86:

0 commit comments

Comments
 (0)