@@ -6312,20 +6312,13 @@ class HorizontalReduction {
6312
6312
// / Opcode of the instruction.
6313
6313
unsigned Opcode = 0 ;
6314
6314
6315
- // / Left operand of the reduction operation.
6316
- Value *LHS = nullptr ;
6317
-
6318
- // / Right operand of the reduction operation.
6319
- Value *RHS = nullptr ;
6320
-
6321
6315
// / Kind of the reduction operation.
6322
6316
ReductionKind Kind = RK_None;
6323
6317
6324
6318
// / Checks if the reduction operation can be vectorized.
6325
6319
bool isVectorizable () const {
6326
- return LHS && RHS &&
6327
- // We currently only support add/mul/logical && min/max reductions.
6328
- ((Kind == RK_Arithmetic &&
6320
+ // We currently only support add/mul/logical && min/max reductions.
6321
+ return ((Kind == RK_Arithmetic &&
6329
6322
(Opcode == Instruction::Add || Opcode == Instruction::FAdd ||
6330
6323
Opcode == Instruction::Mul || Opcode == Instruction::FMul ||
6331
6324
Opcode == Instruction::And || Opcode == Instruction::Or ||
@@ -6336,7 +6329,8 @@ class HorizontalReduction {
6336
6329
}
6337
6330
6338
6331
// / Creates reduction operation with the current opcode.
6339
- Value *createOp (IRBuilder<> &Builder, const Twine &Name) const {
6332
+ Value *createOp (IRBuilder<> &Builder, Value *LHS, Value *RHS,
6333
+ const Twine &Name) const {
6340
6334
assert (isVectorizable () &&
6341
6335
" Expected add|fadd or min/max reduction operation." );
6342
6336
Value *Cmp = nullptr ;
@@ -6377,8 +6371,8 @@ class HorizontalReduction {
6377
6371
6378
6372
// / Constructor for reduction operations with opcode and its left and
6379
6373
// / right operands.
6380
- OperationData (unsigned Opcode, Value *LHS, Value *RHS, ReductionKind Kind)
6381
- : Opcode(Opcode), LHS(LHS), RHS(RHS), Kind(Kind) {
6374
+ OperationData (unsigned Opcode, ReductionKind Kind)
6375
+ : Opcode(Opcode), Kind(Kind) {
6382
6376
assert (Kind != RK_None && " One of the reduction operations is expected." );
6383
6377
}
6384
6378
@@ -6411,16 +6405,14 @@ class HorizontalReduction {
6411
6405
6412
6406
// / Total number of operands in the reduction operation.
6413
6407
unsigned getNumberOfOperands () const {
6414
- assert (Kind != RK_None && !!*this && LHS && RHS &&
6415
- " Expected reduction operation." );
6408
+ assert (Kind != RK_None && !!*this && " Expected reduction operation." );
6416
6409
return isMinMax () ? 3 : 2 ;
6417
6410
}
6418
6411
6419
6412
// / Checks if the instruction is in basic block \p BB.
6420
6413
// / For a min/max reduction check that both compare and select are in \p BB.
6421
6414
bool hasSameParent (Instruction *I, BasicBlock *BB, bool IsRedOp) const {
6422
- assert (Kind != RK_None && !!*this && LHS && RHS &&
6423
- " Expected reduction operation." );
6415
+ assert (Kind != RK_None && !!*this && " Expected reduction operation." );
6424
6416
if (IsRedOp && isMinMax ()) {
6425
6417
auto *Cmp = cast<Instruction>(cast<SelectInst>(I)->getCondition ());
6426
6418
return I->getParent () == BB && Cmp && Cmp->getParent () == BB;
@@ -6430,8 +6422,7 @@ class HorizontalReduction {
6430
6422
6431
6423
// / Expected number of uses for reduction operations/reduced values.
6432
6424
bool hasRequiredNumberOfUses (Instruction *I, bool IsReductionOp) const {
6433
- assert (Kind != RK_None && !!*this && LHS && RHS &&
6434
- " Expected reduction operation." );
6425
+ assert (Kind != RK_None && !!*this && " Expected reduction operation." );
6435
6426
// SelectInst must be used twice while the condition op must have single
6436
6427
// use only.
6437
6428
if (isMinMax ())
@@ -6445,8 +6436,7 @@ class HorizontalReduction {
6445
6436
6446
6437
// / Initializes the list of reduction operations.
6447
6438
void initReductionOps (ReductionOpsListType &ReductionOps) {
6448
- assert (Kind != RK_None && !!*this && LHS && RHS &&
6449
- " Expected reduction operation." );
6439
+ assert (Kind != RK_None && !!*this && " Expected reduction operation." );
6450
6440
if (isMinMax ())
6451
6441
ReductionOps.assign (2 , ReductionOpsType ());
6452
6442
else
@@ -6455,8 +6445,7 @@ class HorizontalReduction {
6455
6445
6456
6446
// / Add all reduction operations for the reduction instruction \p I.
6457
6447
void addReductionOps (Instruction *I, ReductionOpsListType &ReductionOps) {
6458
- assert (Kind != RK_None && !!*this && LHS && RHS &&
6459
- " Expected reduction operation." );
6448
+ assert (Kind != RK_None && !!*this && " Expected reduction operation." );
6460
6449
if (isMinMax ()) {
6461
6450
ReductionOps[0 ].emplace_back (cast<SelectInst>(I)->getCondition ());
6462
6451
ReductionOps[1 ].emplace_back (I);
@@ -6467,8 +6456,7 @@ class HorizontalReduction {
6467
6456
6468
6457
// / Checks if instruction is associative and can be vectorized.
6469
6458
bool isAssociative (Instruction *I) const {
6470
- assert (Kind != RK_None && *this && LHS && RHS &&
6471
- " Expected reduction operation." );
6459
+ assert (Kind != RK_None && *this && " Expected reduction operation." );
6472
6460
switch (Kind) {
6473
6461
case RK_Arithmetic:
6474
6462
return I->isAssociative ();
@@ -6493,15 +6481,13 @@ class HorizontalReduction {
6493
6481
// / Checks if two operation data are both a reduction op or both a reduced
6494
6482
// / value.
6495
6483
bool operator ==(const OperationData &OD) const {
6496
- assert (((Kind != OD.Kind ) || ((!LHS == !OD. LHS ) && (!RHS == !OD. RHS ) )) &&
6484
+ assert (((Kind != OD.Kind ) || (Opcode != 0 && OD. Opcode != 0 )) &&
6497
6485
" One of the comparing operations is incorrect." );
6498
- return this == &OD || ( Kind == OD.Kind && Opcode == OD.Opcode ) ;
6486
+ return Kind == OD.Kind && Opcode == OD.Opcode ;
6499
6487
}
6500
6488
bool operator !=(const OperationData &OD) const { return !(*this == OD); }
6501
6489
void clear () {
6502
6490
Opcode = 0 ;
6503
- LHS = nullptr ;
6504
- RHS = nullptr ;
6505
6491
Kind = RK_None;
6506
6492
}
6507
6493
@@ -6513,19 +6499,25 @@ class HorizontalReduction {
6513
6499
6514
6500
// / Get kind of reduction data.
6515
6501
ReductionKind getKind () const { return Kind; }
6516
- Value *getLHS () const { return LHS; }
6517
- Value *getRHS () const { return RHS; }
6518
- Type *getConditionType () const {
6519
- return isMinMax () ? CmpInst::makeCmpResultType (LHS->getType ()) : nullptr ;
6502
+ Value *getLHS (Instruction *I) const {
6503
+ if (Kind == RK_None)
6504
+ return nullptr ;
6505
+ return I->getOperand (getFirstOperandIndex ());
6506
+ }
6507
+ Value *getRHS (Instruction *I) const {
6508
+ if (Kind == RK_None)
6509
+ return nullptr ;
6510
+ return I->getOperand (getFirstOperandIndex () + 1 );
6520
6511
}
6521
6512
6522
6513
// / Creates reduction operation with the current opcode with the IR flags
6523
6514
// / from \p ReductionOps.
6524
- Value *createOp (IRBuilder<> &Builder, const Twine &Name,
6515
+ Value *createOp (IRBuilder<> &Builder, Value *LHS, Value *RHS,
6516
+ const Twine &Name,
6525
6517
const ReductionOpsListType &ReductionOps) const {
6526
6518
assert (isVectorizable () &&
6527
6519
" Expected add|fadd or min/max reduction operation." );
6528
- auto *Op = createOp (Builder, Name);
6520
+ auto *Op = createOp (Builder, LHS, RHS, Name);
6529
6521
switch (Kind) {
6530
6522
case RK_Arithmetic:
6531
6523
propagateIRFlags (Op, ReductionOps[0 ]);
@@ -6545,11 +6537,11 @@ class HorizontalReduction {
6545
6537
}
6546
6538
// / Creates reduction operation with the current opcode with the IR flags
6547
6539
// / from \p I.
6548
- Value *createOp (IRBuilder<> &Builder, const Twine &Name ,
6549
- Instruction *I) const {
6540
+ Value *createOp (IRBuilder<> &Builder, Value *LHS, Value *RHS ,
6541
+ const Twine &Name, Instruction *I) const {
6550
6542
assert (isVectorizable () &&
6551
6543
" Expected add|fadd or min/max reduction operation." );
6552
- auto *Op = createOp (Builder, Name);
6544
+ auto *Op = createOp (Builder, LHS, RHS, Name);
6553
6545
switch (Kind) {
6554
6546
case RK_Arithmetic:
6555
6547
propagateIRFlags (Op, I);
@@ -6637,19 +6629,18 @@ class HorizontalReduction {
6637
6629
Value *LHS;
6638
6630
Value *RHS;
6639
6631
if (m_BinOp (m_Value (LHS), m_Value (RHS)).match (I)) {
6640
- return OperationData (cast<BinaryOperator>(I)->getOpcode (), LHS, RHS,
6641
- RK_Arithmetic);
6632
+ return OperationData (cast<BinaryOperator>(I)->getOpcode (), RK_Arithmetic);
6642
6633
}
6643
6634
if (auto *Select = dyn_cast<SelectInst>(I)) {
6644
6635
// Look for a min/max pattern.
6645
6636
if (m_UMin (m_Value (LHS), m_Value (RHS)).match (Select)) {
6646
- return OperationData (Instruction::ICmp, LHS, RHS, RK_UMin);
6637
+ return OperationData (Instruction::ICmp, RK_UMin);
6647
6638
} else if (m_SMin (m_Value (LHS), m_Value (RHS)).match (Select)) {
6648
- return OperationData (Instruction::ICmp, LHS, RHS, RK_SMin);
6639
+ return OperationData (Instruction::ICmp, RK_SMin);
6649
6640
} else if (m_UMax (m_Value (LHS), m_Value (RHS)).match (Select)) {
6650
- return OperationData (Instruction::ICmp, LHS, RHS, RK_UMax);
6641
+ return OperationData (Instruction::ICmp, RK_UMax);
6651
6642
} else if (m_SMax (m_Value (LHS), m_Value (RHS)).match (Select)) {
6652
- return OperationData (Instruction::ICmp, LHS, RHS, RK_SMax);
6643
+ return OperationData (Instruction::ICmp, RK_SMax);
6653
6644
} else {
6654
6645
// Try harder: look for min/max pattern based on instructions producing
6655
6646
// same values such as: select ((cmp Inst1, Inst2), Inst1, Inst2).
@@ -6693,19 +6684,19 @@ class HorizontalReduction {
6693
6684
6694
6685
case CmpInst::ICMP_ULT:
6695
6686
case CmpInst::ICMP_ULE:
6696
- return OperationData (Instruction::ICmp, LHS, RHS, RK_UMin);
6687
+ return OperationData (Instruction::ICmp, RK_UMin);
6697
6688
6698
6689
case CmpInst::ICMP_SLT:
6699
6690
case CmpInst::ICMP_SLE:
6700
- return OperationData (Instruction::ICmp, LHS, RHS, RK_SMin);
6691
+ return OperationData (Instruction::ICmp, RK_SMin);
6701
6692
6702
6693
case CmpInst::ICMP_UGT:
6703
6694
case CmpInst::ICMP_UGE:
6704
- return OperationData (Instruction::ICmp, LHS, RHS, RK_UMax);
6695
+ return OperationData (Instruction::ICmp, RK_UMax);
6705
6696
6706
6697
case CmpInst::ICMP_SGT:
6707
6698
case CmpInst::ICMP_SGE:
6708
- return OperationData (Instruction::ICmp, LHS, RHS, RK_SMax);
6699
+ return OperationData (Instruction::ICmp, RK_SMax);
6709
6700
}
6710
6701
}
6711
6702
}
@@ -6726,13 +6717,13 @@ class HorizontalReduction {
6726
6717
// r *= v1 + v2 + v3 + v4
6727
6718
// In such a case start looking for a tree rooted in the first '+'.
6728
6719
if (Phi) {
6729
- if (ReductionData.getLHS () == Phi) {
6720
+ if (ReductionData.getLHS (B ) == Phi) {
6730
6721
Phi = nullptr ;
6731
- B = dyn_cast<Instruction>(ReductionData.getRHS ());
6722
+ B = dyn_cast<Instruction>(ReductionData.getRHS (B ));
6732
6723
ReductionData = getOperationData (B);
6733
- } else if (ReductionData.getRHS () == Phi) {
6724
+ } else if (ReductionData.getRHS (B ) == Phi) {
6734
6725
Phi = nullptr ;
6735
- B = dyn_cast<Instruction>(ReductionData.getLHS ());
6726
+ B = dyn_cast<Instruction>(ReductionData.getLHS (B ));
6736
6727
ReductionData = getOperationData (B);
6737
6728
}
6738
6729
}
@@ -6984,11 +6975,8 @@ class HorizontalReduction {
6984
6975
} else {
6985
6976
// Update the final value in the reduction.
6986
6977
Builder.SetCurrentDebugLocation (Loc);
6987
- OperationData VectReductionData (ReductionData.getOpcode (),
6988
- VectorizedTree, ReducedSubTree,
6989
- ReductionData.getKind ());
6990
- VectorizedTree =
6991
- VectReductionData.createOp (Builder, " op.rdx" , ReductionOps);
6978
+ VectorizedTree = ReductionData.createOp (
6979
+ Builder, VectorizedTree, ReducedSubTree, " op.rdx" , ReductionOps);
6992
6980
}
6993
6981
i += ReduxWidth;
6994
6982
ReduxWidth = PowerOf2Floor (NumReducedVals - i);
@@ -6999,19 +6987,15 @@ class HorizontalReduction {
6999
6987
for (; i < NumReducedVals; ++i) {
7000
6988
auto *I = cast<Instruction>(ReducedVals[i]);
7001
6989
Builder.SetCurrentDebugLocation (I->getDebugLoc ());
7002
- OperationData VectReductionData (ReductionData.getOpcode (),
7003
- VectorizedTree, I,
7004
- ReductionData.getKind ());
7005
- VectorizedTree = VectReductionData.createOp (Builder, " " , ReductionOps);
6990
+ VectorizedTree = ReductionData.createOp (Builder, VectorizedTree, I, " " ,
6991
+ ReductionOps);
7006
6992
}
7007
6993
for (auto &Pair : ExternallyUsedValues) {
7008
6994
// Add each externally used value to the final reduction.
7009
6995
for (auto *I : Pair.second ) {
7010
6996
Builder.SetCurrentDebugLocation (I->getDebugLoc ());
7011
- OperationData VectReductionData (ReductionData.getOpcode (),
7012
- VectorizedTree, Pair.first ,
7013
- ReductionData.getKind ());
7014
- VectorizedTree = VectReductionData.createOp (Builder, " op.extra" , I);
6997
+ VectorizedTree = ReductionData.createOp (Builder, VectorizedTree,
6998
+ Pair.first , " op.extra" , I);
7015
6999
}
7016
7000
}
7017
7001
@@ -7133,9 +7117,8 @@ class HorizontalReduction {
7133
7117
Builder.CreateShuffleVector (TmpVec, LeftMask, " rdx.shuf.l" );
7134
7118
Value *RightShuf =
7135
7119
Builder.CreateShuffleVector (TmpVec, RightMask, " rdx.shuf.r" );
7136
- OperationData VectReductionData (ReductionData.getOpcode (), LeftShuf,
7137
- RightShuf, ReductionData.getKind ());
7138
- TmpVec = VectReductionData.createOp (Builder, " op.rdx" , ReductionOps);
7120
+ TmpVec = ReductionData.createOp (Builder, LeftShuf, RightShuf, " op.rdx" ,
7121
+ ReductionOps);
7139
7122
}
7140
7123
7141
7124
// The result is in the first element of the vector.
0 commit comments