@@ -166,6 +166,13 @@ class VectorLegalizer {
166
166
// / truncated back to the original type.
167
167
void PromoteFP_TO_INT (SDNode *Node, SmallVectorImpl<SDValue> &Results);
168
168
169
+ // / Implements vector reduce operation promotion.
170
+ // /
171
+ // / All vector operands are promoted to a vector type with larger element
172
+ // / type, and the start value is promoted to a larger scalar type. Then the
173
+ // / result is truncated back to the original scalar type.
174
+ void PromoteReduction (SDNode *Node, SmallVectorImpl<SDValue> &Results);
175
+
169
176
public:
170
177
VectorLegalizer (SelectionDAG& dag) :
171
178
DAG (dag), TLI(dag.getTargetLoweringInfo()) {}
@@ -551,6 +558,50 @@ bool VectorLegalizer::LowerOperationWrapper(SDNode *Node,
551
558
return true ;
552
559
}
553
560
561
+ void VectorLegalizer::PromoteReduction (SDNode *Node,
562
+ SmallVectorImpl<SDValue> &Results) {
563
+ MVT VecVT = Node->getOperand (1 ).getSimpleValueType ();
564
+ MVT NewVecVT = TLI.getTypeToPromoteTo (Node->getOpcode (), VecVT);
565
+ MVT ScalarVT = Node->getSimpleValueType (0 );
566
+ MVT NewScalarVT = NewVecVT.getVectorElementType ();
567
+
568
+ SDLoc DL (Node);
569
+ SmallVector<SDValue, 4 > Operands (Node->getNumOperands ());
570
+
571
+ // promote the initial value.
572
+ if (Node->getOperand (0 ).getValueType ().isFloatingPoint ())
573
+ Operands[0 ] =
574
+ DAG.getNode (ISD::FP_EXTEND, DL, NewScalarVT, Node->getOperand (0 ));
575
+ else
576
+ Operands[0 ] =
577
+ DAG.getNode (ISD::ANY_EXTEND, DL, NewScalarVT, Node->getOperand (0 ));
578
+
579
+ for (unsigned j = 1 ; j != Node->getNumOperands (); ++j)
580
+ if (Node->getOperand (j).getValueType ().isVector () &&
581
+ !(ISD::isVPOpcode (Node->getOpcode ()) &&
582
+ ISD::getVPMaskIdx (Node->getOpcode ()) == j)) // Skip mask operand.
583
+ // promote the vector operand.
584
+ if (Node->getOperand (j).getValueType ().isFloatingPoint ())
585
+ Operands[j] =
586
+ DAG.getNode (ISD::FP_EXTEND, DL, NewVecVT, Node->getOperand (j));
587
+ else
588
+ Operands[j] =
589
+ DAG.getNode (ISD::ANY_EXTEND, DL, NewVecVT, Node->getOperand (j));
590
+ else
591
+ Operands[j] = Node->getOperand (j); // Skip VL operand.
592
+
593
+ SDValue Res = DAG.getNode (Node->getOpcode (), DL, NewScalarVT, Operands,
594
+ Node->getFlags ());
595
+
596
+ if (ScalarVT.isFloatingPoint ())
597
+ Res = DAG.getNode (ISD::FP_ROUND, DL, ScalarVT, Res,
598
+ DAG.getIntPtrConstant (0 , DL, /* isTarget=*/ true ));
599
+ else
600
+ Res = DAG.getNode (ISD::TRUNCATE, DL, ScalarVT, Res);
601
+
602
+ Results.push_back (Res);
603
+ }
604
+
554
605
void VectorLegalizer::Promote (SDNode *Node, SmallVectorImpl<SDValue> &Results) {
555
606
// For a few operations there is a specific concept for promotion based on
556
607
// the operand's type.
@@ -569,6 +620,23 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
569
620
// Promote the operation by extending the operand.
570
621
PromoteFP_TO_INT (Node, Results);
571
622
return ;
623
+ case ISD::VP_REDUCE_ADD:
624
+ case ISD::VP_REDUCE_MUL:
625
+ case ISD::VP_REDUCE_AND:
626
+ case ISD::VP_REDUCE_OR:
627
+ case ISD::VP_REDUCE_XOR:
628
+ case ISD::VP_REDUCE_SMAX:
629
+ case ISD::VP_REDUCE_SMIN:
630
+ case ISD::VP_REDUCE_UMAX:
631
+ case ISD::VP_REDUCE_UMIN:
632
+ case ISD::VP_REDUCE_FADD:
633
+ case ISD::VP_REDUCE_FMUL:
634
+ case ISD::VP_REDUCE_FMAX:
635
+ case ISD::VP_REDUCE_FMIN:
636
+ case ISD::VP_REDUCE_SEQ_FADD:
637
+ // Promote the operation by extending the operand.
638
+ PromoteReduction (Node, Results);
639
+ return ;
572
640
case ISD::FP_ROUND:
573
641
case ISD::FP_EXTEND:
574
642
// These operations are used to do promotion so they can't be promoted
@@ -589,7 +657,10 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
589
657
SmallVector<SDValue, 4 > Operands (Node->getNumOperands ());
590
658
591
659
for (unsigned j = 0 ; j != Node->getNumOperands (); ++j) {
592
- if (Node->getOperand (j).getValueType ().isVector ())
660
+ // Do not promote the mask operand of a VP OP.
661
+ bool SkipPromote = ISD::isVPOpcode (Node->getOpcode ()) &&
662
+ ISD::getVPMaskIdx (Node->getOpcode ()) == j;
663
+ if (Node->getOperand (j).getValueType ().isVector () && !SkipPromote)
593
664
if (Node->getOperand (j)
594
665
.getValueType ()
595
666
.getVectorElementType ()
0 commit comments