Skip to content

Commit 03f22b0

Browse files
committed
[SLP] Remove LHS and RHS from OperationData.
These were only really used for 2 things. One was to check if the operand matches the phi if it exists. The other was for the createOp method to build the reduction. For the first case we still have the operation we just need to know how to index its operands. So I've modified getLHS/getRHS to just use the opcode/kind to know how to find the right operands on an instruction that is now passed in. For the other case we had to create an OperationData object to set the LHS/RHS values and copy the opcode/kind from another object. We would then just call createOp on that temporary object. Instead I've made LHS/RHS arguments to createOp and removed all these temporary objects. Differential Revision: https://reviews.llvm.org/D88193
1 parent d1419c9 commit 03f22b0

File tree

1 file changed

+50
-67
lines changed

1 file changed

+50
-67
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 50 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -6312,20 +6312,13 @@ class HorizontalReduction {
63126312
/// Opcode of the instruction.
63136313
unsigned Opcode = 0;
63146314

6315-
/// Left operand of the reduction operation.
6316-
Value *LHS = nullptr;
6317-
6318-
/// Right operand of the reduction operation.
6319-
Value *RHS = nullptr;
6320-
63216315
/// Kind of the reduction operation.
63226316
ReductionKind Kind = RK_None;
63236317

63246318
/// Checks if the reduction operation can be vectorized.
63256319
bool isVectorizable() const {
6326-
return LHS && RHS &&
6327-
// We currently only support add/mul/logical && min/max reductions.
6328-
((Kind == RK_Arithmetic &&
6320+
// We currently only support add/mul/logical && min/max reductions.
6321+
return ((Kind == RK_Arithmetic &&
63296322
(Opcode == Instruction::Add || Opcode == Instruction::FAdd ||
63306323
Opcode == Instruction::Mul || Opcode == Instruction::FMul ||
63316324
Opcode == Instruction::And || Opcode == Instruction::Or ||
@@ -6336,7 +6329,8 @@ class HorizontalReduction {
63366329
}
63376330

63386331
/// Creates reduction operation with the current opcode.
6339-
Value *createOp(IRBuilder<> &Builder, const Twine &Name) const {
6332+
Value *createOp(IRBuilder<> &Builder, Value *LHS, Value *RHS,
6333+
const Twine &Name) const {
63406334
assert(isVectorizable() &&
63416335
"Expected add|fadd or min/max reduction operation.");
63426336
Value *Cmp = nullptr;
@@ -6377,8 +6371,8 @@ class HorizontalReduction {
63776371

63786372
/// Constructor for reduction operations with opcode and its left and
63796373
/// right operands.
6380-
OperationData(unsigned Opcode, Value *LHS, Value *RHS, ReductionKind Kind)
6381-
: Opcode(Opcode), LHS(LHS), RHS(RHS), Kind(Kind) {
6374+
OperationData(unsigned Opcode, ReductionKind Kind)
6375+
: Opcode(Opcode), Kind(Kind) {
63826376
assert(Kind != RK_None && "One of the reduction operations is expected.");
63836377
}
63846378

@@ -6411,16 +6405,14 @@ class HorizontalReduction {
64116405

64126406
/// Total number of operands in the reduction operation.
64136407
unsigned getNumberOfOperands() const {
6414-
assert(Kind != RK_None && !!*this && LHS && RHS &&
6415-
"Expected reduction operation.");
6408+
assert(Kind != RK_None && !!*this && "Expected reduction operation.");
64166409
return isMinMax() ? 3 : 2;
64176410
}
64186411

64196412
/// Checks if the instruction is in basic block \p BB.
64206413
/// For a min/max reduction check that both compare and select are in \p BB.
64216414
bool hasSameParent(Instruction *I, BasicBlock *BB, bool IsRedOp) const {
6422-
assert(Kind != RK_None && !!*this && LHS && RHS &&
6423-
"Expected reduction operation.");
6415+
assert(Kind != RK_None && !!*this && "Expected reduction operation.");
64246416
if (IsRedOp && isMinMax()) {
64256417
auto *Cmp = cast<Instruction>(cast<SelectInst>(I)->getCondition());
64266418
return I->getParent() == BB && Cmp && Cmp->getParent() == BB;
@@ -6430,8 +6422,7 @@ class HorizontalReduction {
64306422

64316423
/// Expected number of uses for reduction operations/reduced values.
64326424
bool hasRequiredNumberOfUses(Instruction *I, bool IsReductionOp) const {
6433-
assert(Kind != RK_None && !!*this && LHS && RHS &&
6434-
"Expected reduction operation.");
6425+
assert(Kind != RK_None && !!*this && "Expected reduction operation.");
64356426
// SelectInst must be used twice while the condition op must have single
64366427
// use only.
64376428
if (isMinMax())
@@ -6445,8 +6436,7 @@ class HorizontalReduction {
64456436

64466437
/// Initializes the list of reduction operations.
64476438
void initReductionOps(ReductionOpsListType &ReductionOps) {
6448-
assert(Kind != RK_None && !!*this && LHS && RHS &&
6449-
"Expected reduction operation.");
6439+
assert(Kind != RK_None && !!*this && "Expected reduction operation.");
64506440
if (isMinMax())
64516441
ReductionOps.assign(2, ReductionOpsType());
64526442
else
@@ -6455,8 +6445,7 @@ class HorizontalReduction {
64556445

64566446
/// Add all reduction operations for the reduction instruction \p I.
64576447
void addReductionOps(Instruction *I, ReductionOpsListType &ReductionOps) {
6458-
assert(Kind != RK_None && !!*this && LHS && RHS &&
6459-
"Expected reduction operation.");
6448+
assert(Kind != RK_None && !!*this && "Expected reduction operation.");
64606449
if (isMinMax()) {
64616450
ReductionOps[0].emplace_back(cast<SelectInst>(I)->getCondition());
64626451
ReductionOps[1].emplace_back(I);
@@ -6467,8 +6456,7 @@ class HorizontalReduction {
64676456

64686457
/// Checks if instruction is associative and can be vectorized.
64696458
bool isAssociative(Instruction *I) const {
6470-
assert(Kind != RK_None && *this && LHS && RHS &&
6471-
"Expected reduction operation.");
6459+
assert(Kind != RK_None && *this && "Expected reduction operation.");
64726460
switch (Kind) {
64736461
case RK_Arithmetic:
64746462
return I->isAssociative();
@@ -6493,15 +6481,13 @@ class HorizontalReduction {
64936481
/// Checks if two operation data are both a reduction op or both a reduced
64946482
/// value.
64956483
bool operator==(const OperationData &OD) const {
6496-
assert(((Kind != OD.Kind) || ((!LHS == !OD.LHS) && (!RHS == !OD.RHS))) &&
6484+
assert(((Kind != OD.Kind) || (Opcode != 0 && OD.Opcode != 0)) &&
64976485
"One of the comparing operations is incorrect.");
6498-
return this == &OD || (Kind == OD.Kind && Opcode == OD.Opcode);
6486+
return Kind == OD.Kind && Opcode == OD.Opcode;
64996487
}
65006488
bool operator!=(const OperationData &OD) const { return !(*this == OD); }
65016489
void clear() {
65026490
Opcode = 0;
6503-
LHS = nullptr;
6504-
RHS = nullptr;
65056491
Kind = RK_None;
65066492
}
65076493

@@ -6513,19 +6499,25 @@ class HorizontalReduction {
65136499

65146500
/// Get kind of reduction data.
65156501
ReductionKind getKind() const { return Kind; }
6516-
Value *getLHS() const { return LHS; }
6517-
Value *getRHS() const { return RHS; }
6518-
Type *getConditionType() const {
6519-
return isMinMax() ? CmpInst::makeCmpResultType(LHS->getType()) : nullptr;
6502+
Value *getLHS(Instruction *I) const {
6503+
if (Kind == RK_None)
6504+
return nullptr;
6505+
return I->getOperand(getFirstOperandIndex());
6506+
}
6507+
Value *getRHS(Instruction *I) const {
6508+
if (Kind == RK_None)
6509+
return nullptr;
6510+
return I->getOperand(getFirstOperandIndex() + 1);
65206511
}
65216512

65226513
/// Creates reduction operation with the current opcode with the IR flags
65236514
/// from \p ReductionOps.
6524-
Value *createOp(IRBuilder<> &Builder, const Twine &Name,
6515+
Value *createOp(IRBuilder<> &Builder, Value *LHS, Value *RHS,
6516+
const Twine &Name,
65256517
const ReductionOpsListType &ReductionOps) const {
65266518
assert(isVectorizable() &&
65276519
"Expected add|fadd or min/max reduction operation.");
6528-
auto *Op = createOp(Builder, Name);
6520+
auto *Op = createOp(Builder, LHS, RHS, Name);
65296521
switch (Kind) {
65306522
case RK_Arithmetic:
65316523
propagateIRFlags(Op, ReductionOps[0]);
@@ -6545,11 +6537,11 @@ class HorizontalReduction {
65456537
}
65466538
/// Creates reduction operation with the current opcode with the IR flags
65476539
/// from \p I.
6548-
Value *createOp(IRBuilder<> &Builder, const Twine &Name,
6549-
Instruction *I) const {
6540+
Value *createOp(IRBuilder<> &Builder, Value *LHS, Value *RHS,
6541+
const Twine &Name, Instruction *I) const {
65506542
assert(isVectorizable() &&
65516543
"Expected add|fadd or min/max reduction operation.");
6552-
auto *Op = createOp(Builder, Name);
6544+
auto *Op = createOp(Builder, LHS, RHS, Name);
65536545
switch (Kind) {
65546546
case RK_Arithmetic:
65556547
propagateIRFlags(Op, I);
@@ -6637,19 +6629,18 @@ class HorizontalReduction {
66376629
Value *LHS;
66386630
Value *RHS;
66396631
if (m_BinOp(m_Value(LHS), m_Value(RHS)).match(I)) {
6640-
return OperationData(cast<BinaryOperator>(I)->getOpcode(), LHS, RHS,
6641-
RK_Arithmetic);
6632+
return OperationData(cast<BinaryOperator>(I)->getOpcode(), RK_Arithmetic);
66426633
}
66436634
if (auto *Select = dyn_cast<SelectInst>(I)) {
66446635
// Look for a min/max pattern.
66456636
if (m_UMin(m_Value(LHS), m_Value(RHS)).match(Select)) {
6646-
return OperationData(Instruction::ICmp, LHS, RHS, RK_UMin);
6637+
return OperationData(Instruction::ICmp, RK_UMin);
66476638
} else if (m_SMin(m_Value(LHS), m_Value(RHS)).match(Select)) {
6648-
return OperationData(Instruction::ICmp, LHS, RHS, RK_SMin);
6639+
return OperationData(Instruction::ICmp, RK_SMin);
66496640
} else if (m_UMax(m_Value(LHS), m_Value(RHS)).match(Select)) {
6650-
return OperationData(Instruction::ICmp, LHS, RHS, RK_UMax);
6641+
return OperationData(Instruction::ICmp, RK_UMax);
66516642
} else if (m_SMax(m_Value(LHS), m_Value(RHS)).match(Select)) {
6652-
return OperationData(Instruction::ICmp, LHS, RHS, RK_SMax);
6643+
return OperationData(Instruction::ICmp, RK_SMax);
66536644
} else {
66546645
// Try harder: look for min/max pattern based on instructions producing
66556646
// same values such as: select ((cmp Inst1, Inst2), Inst1, Inst2).
@@ -6693,19 +6684,19 @@ class HorizontalReduction {
66936684

66946685
case CmpInst::ICMP_ULT:
66956686
case CmpInst::ICMP_ULE:
6696-
return OperationData(Instruction::ICmp, LHS, RHS, RK_UMin);
6687+
return OperationData(Instruction::ICmp, RK_UMin);
66976688

66986689
case CmpInst::ICMP_SLT:
66996690
case CmpInst::ICMP_SLE:
6700-
return OperationData(Instruction::ICmp, LHS, RHS, RK_SMin);
6691+
return OperationData(Instruction::ICmp, RK_SMin);
67016692

67026693
case CmpInst::ICMP_UGT:
67036694
case CmpInst::ICMP_UGE:
6704-
return OperationData(Instruction::ICmp, LHS, RHS, RK_UMax);
6695+
return OperationData(Instruction::ICmp, RK_UMax);
67056696

67066697
case CmpInst::ICMP_SGT:
67076698
case CmpInst::ICMP_SGE:
6708-
return OperationData(Instruction::ICmp, LHS, RHS, RK_SMax);
6699+
return OperationData(Instruction::ICmp, RK_SMax);
67096700
}
67106701
}
67116702
}
@@ -6726,13 +6717,13 @@ class HorizontalReduction {
67266717
// r *= v1 + v2 + v3 + v4
67276718
// In such a case start looking for a tree rooted in the first '+'.
67286719
if (Phi) {
6729-
if (ReductionData.getLHS() == Phi) {
6720+
if (ReductionData.getLHS(B) == Phi) {
67306721
Phi = nullptr;
6731-
B = dyn_cast<Instruction>(ReductionData.getRHS());
6722+
B = dyn_cast<Instruction>(ReductionData.getRHS(B));
67326723
ReductionData = getOperationData(B);
6733-
} else if (ReductionData.getRHS() == Phi) {
6724+
} else if (ReductionData.getRHS(B) == Phi) {
67346725
Phi = nullptr;
6735-
B = dyn_cast<Instruction>(ReductionData.getLHS());
6726+
B = dyn_cast<Instruction>(ReductionData.getLHS(B));
67366727
ReductionData = getOperationData(B);
67376728
}
67386729
}
@@ -6984,11 +6975,8 @@ class HorizontalReduction {
69846975
} else {
69856976
// Update the final value in the reduction.
69866977
Builder.SetCurrentDebugLocation(Loc);
6987-
OperationData VectReductionData(ReductionData.getOpcode(),
6988-
VectorizedTree, ReducedSubTree,
6989-
ReductionData.getKind());
6990-
VectorizedTree =
6991-
VectReductionData.createOp(Builder, "op.rdx", ReductionOps);
6978+
VectorizedTree = ReductionData.createOp(
6979+
Builder, VectorizedTree, ReducedSubTree, "op.rdx", ReductionOps);
69926980
}
69936981
i += ReduxWidth;
69946982
ReduxWidth = PowerOf2Floor(NumReducedVals - i);
@@ -6999,19 +6987,15 @@ class HorizontalReduction {
69996987
for (; i < NumReducedVals; ++i) {
70006988
auto *I = cast<Instruction>(ReducedVals[i]);
70016989
Builder.SetCurrentDebugLocation(I->getDebugLoc());
7002-
OperationData VectReductionData(ReductionData.getOpcode(),
7003-
VectorizedTree, I,
7004-
ReductionData.getKind());
7005-
VectorizedTree = VectReductionData.createOp(Builder, "", ReductionOps);
6990+
VectorizedTree = ReductionData.createOp(Builder, VectorizedTree, I, "",
6991+
ReductionOps);
70066992
}
70076993
for (auto &Pair : ExternallyUsedValues) {
70086994
// Add each externally used value to the final reduction.
70096995
for (auto *I : Pair.second) {
70106996
Builder.SetCurrentDebugLocation(I->getDebugLoc());
7011-
OperationData VectReductionData(ReductionData.getOpcode(),
7012-
VectorizedTree, Pair.first,
7013-
ReductionData.getKind());
7014-
VectorizedTree = VectReductionData.createOp(Builder, "op.extra", I);
6997+
VectorizedTree = ReductionData.createOp(Builder, VectorizedTree,
6998+
Pair.first, "op.extra", I);
70156999
}
70167000
}
70177001

@@ -7133,9 +7117,8 @@ class HorizontalReduction {
71337117
Builder.CreateShuffleVector(TmpVec, LeftMask, "rdx.shuf.l");
71347118
Value *RightShuf =
71357119
Builder.CreateShuffleVector(TmpVec, RightMask, "rdx.shuf.r");
7136-
OperationData VectReductionData(ReductionData.getOpcode(), LeftShuf,
7137-
RightShuf, ReductionData.getKind());
7138-
TmpVec = VectReductionData.createOp(Builder, "op.rdx", ReductionOps);
7120+
TmpVec = ReductionData.createOp(Builder, LeftShuf, RightShuf, "op.rdx",
7121+
ReductionOps);
71397122
}
71407123

71417124
// The result is in the first element of the vector.

0 commit comments

Comments
 (0)