@@ -1148,18 +1148,6 @@ class BoUpSLP {
1148
1148
/// Construct a vectorizable tree that starts at \p Roots.
1149
1149
void buildTree(ArrayRef<Value *> Roots);
1150
1150
1151
- /// Checks if the very first tree node is going to be vectorized.
1152
- bool isVectorizedFirstNode() const {
1153
- return !VectorizableTree.empty() &&
1154
- VectorizableTree.front()->State == TreeEntry::Vectorize;
1155
- }
1156
-
1157
- /// Returns the main instruction for the very first node.
1158
- Instruction *getFirstNodeMainOp() const {
1159
- assert(!VectorizableTree.empty() && "No tree to get the first node from");
1160
- return VectorizableTree.front()->getMainOp();
1161
- }
1162
-
1163
1151
/// Returns whether the root node has in-tree uses.
1164
1152
bool doesRootHaveInTreeUses() const {
1165
1153
return !VectorizableTree.empty() &&
@@ -13340,22 +13328,7 @@ class HorizontalReduction {
13340
13328
// Estimate cost.
13341
13329
InstructionCost TreeCost = V.getTreeCost(VL);
13342
13330
InstructionCost ReductionCost =
13343
- getReductionCost(TTI, VL, ReduxWidth, RdxFMF);
13344
- if (V.isVectorizedFirstNode() && isa<LoadInst>(VL.front())) {
13345
- Instruction *MainOp = V.getFirstNodeMainOp();
13346
- for (Value *V : VL) {
13347
- auto *VI = dyn_cast<LoadInst>(V);
13348
- // Add the costs of scalar GEP pointers, to be removed from the
13349
- // code.
13350
- if (!VI || VI == MainOp)
13351
- continue;
13352
- auto *Ptr = dyn_cast<GetElementPtrInst>(VI->getPointerOperand());
13353
- if (!Ptr || !Ptr->hasOneUse() || Ptr->hasAllConstantIndices())
13354
- continue;
13355
- TreeCost -= TTI->getArithmeticInstrCost(
13356
- Instruction::Add, Ptr->getType(), TTI::TCK_RecipThroughput);
13357
- }
13358
- }
13331
+ getReductionCost(TTI, VL, IsCmpSelMinMax, ReduxWidth, RdxFMF);
13359
13332
InstructionCost Cost = TreeCost + ReductionCost;
13360
13333
LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost << " for reduction\n");
13361
13334
if (!Cost.isValid())
@@ -13591,7 +13564,8 @@ class HorizontalReduction {
13591
13564
/// Calculate the cost of a reduction.
13592
13565
InstructionCost getReductionCost(TargetTransformInfo *TTI,
13593
13566
ArrayRef<Value *> ReducedVals,
13594
- unsigned ReduxWidth, FastMathFlags FMF) {
13567
+ bool IsCmpSelMinMax, unsigned ReduxWidth,
13568
+ FastMathFlags FMF) {
13595
13569
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
13596
13570
Value *FirstReducedVal = ReducedVals.front();
13597
13571
Type *ScalarTy = FirstReducedVal->getType();
@@ -13600,6 +13574,35 @@ class HorizontalReduction {
13600
13574
// If all of the reduced values are constant, the vector cost is 0, since
13601
13575
// the reduction value can be calculated at the compile time.
13602
13576
bool AllConsts = allConstant(ReducedVals);
13577
+ auto EvaluateScalarCost = [&](function_ref<InstructionCost()> GenCostFn) { // Sums a scalar cost over the reduced values; GenCostFn supplies the fallback per-value estimate.
13578
+ InstructionCost Cost = 0;
13579
+ // Scalar cost is repeated for N-1 elements: Cnt counts down from the number of reduced values and the loop stops once it reaches 1, so the last value adds nothing.
13580
+ int Cnt = ReducedVals.size();
13581
+ for (Value *RdxVal : ReducedVals) {
13582
+ if (Cnt == 1)
13583
+ break;
13584
+ --Cnt;
13585
+ if (RdxVal->hasNUsesOrMore(IsCmpSelMinMax ? 3 : 2)) { // Use-count threshold is 3 for cmp+select min/max, 2 otherwise; at or above it the fallback estimate is used.
13586
+ Cost += GenCostFn();
13587
+ continue;
13588
+ }
13589
+ InstructionCost ScalarCost = 0; // Local to the lambda; NOTE(review): same name as the ScalarCost the enclosing function assigns from this lambda's result.
13590
+ for (User *U : RdxVal->users()) {
13591
+ auto *RdxOp = cast<Instruction>(U); // cast<> (not dyn_cast<>) assumes every user is an Instruction - TODO confirm this holds for constant reduced values.
13592
+ if (hasRequiredNumberOfUses(IsCmpSelMinMax, RdxOp)) {
13593
+ ScalarCost += TTI->getInstructionCost(RdxOp, CostKind); // Accumulate the TTI cost of this user instruction.
13594
+ continue;
13595
+ }
13596
+ ScalarCost = InstructionCost::getInvalid(); // A user failed the check: mark the per-user sum invalid so the fallback is taken below.
13597
+ break;
13598
+ }
13599
+ if (ScalarCost.isValid())
13600
+ Cost += ScalarCost;
13601
+ else
13602
+ Cost += GenCostFn(); // Fallback estimate when any user was rejected.
13603
+ }
13604
+ return Cost; // Total over all but the last reduced value.
13605
+ };
13603
13606
switch (RdxKind) {
13604
13607
case RecurKind::Add:
13605
13608
case RecurKind::Mul:
@@ -13612,7 +13615,9 @@ class HorizontalReduction {
13612
13615
if (!AllConsts)
13613
13616
VectorCost =
13614
13617
TTI->getArithmeticReductionCost(RdxOpcode, VectorTy, FMF, CostKind);
13615
- ScalarCost = TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy, CostKind);
13618
+ ScalarCost = EvaluateScalarCost([&]() {
13619
+ return TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy, CostKind);
13620
+ });
13616
13621
break;
13617
13622
}
13618
13623
case RecurKind::FMax:
@@ -13626,10 +13631,12 @@ class HorizontalReduction {
13626
13631
/*IsUnsigned=*/false, CostKind);
13627
13632
}
13628
13633
CmpInst::Predicate RdxPred = getMinMaxReductionPredicate(RdxKind);
13629
- ScalarCost = TTI->getCmpSelInstrCost(Instruction::FCmp, ScalarTy,
13630
- SclCondTy, RdxPred, CostKind) +
13631
- TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy,
13632
- SclCondTy, RdxPred, CostKind);
13634
+ ScalarCost = EvaluateScalarCost([&]() {
13635
+ return TTI->getCmpSelInstrCost(Instruction::FCmp, ScalarTy, SclCondTy,
13636
+ RdxPred, CostKind) +
13637
+ TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy, SclCondTy,
13638
+ RdxPred, CostKind);
13639
+ });
13633
13640
break;
13634
13641
}
13635
13642
case RecurKind::SMax:
@@ -13646,18 +13653,18 @@ class HorizontalReduction {
13646
13653
IsUnsigned, CostKind);
13647
13654
}
13648
13655
CmpInst::Predicate RdxPred = getMinMaxReductionPredicate(RdxKind);
13649
- ScalarCost = TTI->getCmpSelInstrCost(Instruction::ICmp, ScalarTy,
13650
- SclCondTy, RdxPred, CostKind) +
13651
- TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy,
13652
- SclCondTy, RdxPred, CostKind);
13656
+ ScalarCost = EvaluateScalarCost([&]() {
13657
+ return TTI->getCmpSelInstrCost(Instruction::ICmp, ScalarTy, SclCondTy,
13658
+ RdxPred, CostKind) +
13659
+ TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy, SclCondTy,
13660
+ RdxPred, CostKind);
13661
+ });
13653
13662
break;
13654
13663
}
13655
13664
default:
13656
13665
llvm_unreachable("Expected arithmetic or min/max reduction operation");
13657
13666
}
13658
13667
13659
- // Scalar cost is repeated for N-1 elements.
13660
- ScalarCost *= (ReduxWidth - 1);
13661
13668
LLVM_DEBUG(dbgs() << "SLP: Adding cost " << VectorCost - ScalarCost
13662
13669
<< " for reduction that starts with " << *FirstReducedVal
13663
13670
<< " (It is a splitting reduction)\n");
0 commit comments