Skip to content

Commit 879e801

Browse files
committed
[RISCV] Apply promotion for f16 vector ops when only have zvfhmin
For most fp16 vector ops, we could promote it to fp32 vector when zvfhmin is enable but zvfh is not. But for nxv32f16, we need to split it first since nxv32f32 is not a valid MVT. Reviewed By: michaelmaitland Differential Revision: https://reviews.llvm.org/D153848
1 parent 7599035 commit 879e801

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

65 files changed

+26976
-5160
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2555,6 +2555,13 @@ class TargetLoweringBase {
25552555
setOperationAction(Opc, OrigVT, Promote);
25562556
AddPromotedToType(Opc, OrigVT, DestVT);
25572557
}
2558+
void setOperationPromotedToType(ArrayRef<unsigned> Ops, MVT OrigVT,
2559+
MVT DestVT) {
2560+
for (auto Op : Ops) {
2561+
setOperationAction(Op, OrigVT, Promote);
2562+
AddPromotedToType(Op, OrigVT, DestVT);
2563+
}
2564+
}
25582565

25592566
/// Targets should invoke this method for each target independent node that
25602567
/// they want to provide a custom DAG combiner for by implementing the

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5459,6 +5459,23 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
54595459
Results.push_back(NewAtomic.getValue(1));
54605460
break;
54615461
}
5462+
case ISD::SPLAT_VECTOR: {
5463+
SDValue Scalar = Node->getOperand(0);
5464+
MVT ScalarType = Scalar.getSimpleValueType();
5465+
MVT NewScalarType = NVT.getVectorElementType();
5466+
if (ScalarType.isInteger()) {
5467+
Tmp1 = DAG.getNode(ISD::ANY_EXTEND, dl, NewScalarType, Scalar);
5468+
Tmp2 = DAG.getNode(Node->getOpcode(), dl, NVT, Tmp1);
5469+
Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, OVT, Tmp2));
5470+
break;
5471+
}
5472+
Tmp1 = DAG.getNode(ISD::FP_EXTEND, dl, NewScalarType, Scalar);
5473+
Tmp2 = DAG.getNode(Node->getOpcode(), dl, NVT, Tmp1);
5474+
Results.push_back(
5475+
DAG.getNode(ISD::FP_ROUND, dl, OVT, Tmp2,
5476+
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)));
5477+
break;
5478+
}
54625479
}
54635480

54645481
// Replace the original node with the legalized result.

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,13 @@ class VectorLegalizer {
166166
/// truncated back to the original type.
167167
void PromoteFP_TO_INT(SDNode *Node, SmallVectorImpl<SDValue> &Results);
168168

169+
/// Implements vector reduce operation promotion.
170+
///
171+
/// All vector operands are promoted to a vector type with larger element
172+
/// type, and the start value is promoted to a larger scalar type. Then the
173+
/// result is truncated back to the original scalar type.
174+
void PromoteReduction(SDNode *Node, SmallVectorImpl<SDValue> &Results);
175+
169176
public:
170177
VectorLegalizer(SelectionDAG& dag) :
171178
DAG(dag), TLI(dag.getTargetLoweringInfo()) {}
@@ -551,6 +558,50 @@ bool VectorLegalizer::LowerOperationWrapper(SDNode *Node,
551558
return true;
552559
}
553560

561+
void VectorLegalizer::PromoteReduction(SDNode *Node,
562+
SmallVectorImpl<SDValue> &Results) {
563+
MVT VecVT = Node->getOperand(1).getSimpleValueType();
564+
MVT NewVecVT = TLI.getTypeToPromoteTo(Node->getOpcode(), VecVT);
565+
MVT ScalarVT = Node->getSimpleValueType(0);
566+
MVT NewScalarVT = NewVecVT.getVectorElementType();
567+
568+
SDLoc DL(Node);
569+
SmallVector<SDValue, 4> Operands(Node->getNumOperands());
570+
571+
// promote the initial value.
572+
if (Node->getOperand(0).getValueType().isFloatingPoint())
573+
Operands[0] =
574+
DAG.getNode(ISD::FP_EXTEND, DL, NewScalarVT, Node->getOperand(0));
575+
else
576+
Operands[0] =
577+
DAG.getNode(ISD::ANY_EXTEND, DL, NewScalarVT, Node->getOperand(0));
578+
579+
for (unsigned j = 1; j != Node->getNumOperands(); ++j)
580+
if (Node->getOperand(j).getValueType().isVector() &&
581+
!(ISD::isVPOpcode(Node->getOpcode()) &&
582+
ISD::getVPMaskIdx(Node->getOpcode()) == j)) // Skip mask operand.
583+
// promote the vector operand.
584+
if (Node->getOperand(j).getValueType().isFloatingPoint())
585+
Operands[j] =
586+
DAG.getNode(ISD::FP_EXTEND, DL, NewVecVT, Node->getOperand(j));
587+
else
588+
Operands[j] =
589+
DAG.getNode(ISD::ANY_EXTEND, DL, NewVecVT, Node->getOperand(j));
590+
else
591+
Operands[j] = Node->getOperand(j); // Skip VL operand.
592+
593+
SDValue Res = DAG.getNode(Node->getOpcode(), DL, NewScalarVT, Operands,
594+
Node->getFlags());
595+
596+
if (ScalarVT.isFloatingPoint())
597+
Res = DAG.getNode(ISD::FP_ROUND, DL, ScalarVT, Res,
598+
DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
599+
else
600+
Res = DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, Res);
601+
602+
Results.push_back(Res);
603+
}
604+
554605
void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
555606
// For a few operations there is a specific concept for promotion based on
556607
// the operand's type.
@@ -569,6 +620,23 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
569620
// Promote the operation by extending the operand.
570621
PromoteFP_TO_INT(Node, Results);
571622
return;
623+
case ISD::VP_REDUCE_ADD:
624+
case ISD::VP_REDUCE_MUL:
625+
case ISD::VP_REDUCE_AND:
626+
case ISD::VP_REDUCE_OR:
627+
case ISD::VP_REDUCE_XOR:
628+
case ISD::VP_REDUCE_SMAX:
629+
case ISD::VP_REDUCE_SMIN:
630+
case ISD::VP_REDUCE_UMAX:
631+
case ISD::VP_REDUCE_UMIN:
632+
case ISD::VP_REDUCE_FADD:
633+
case ISD::VP_REDUCE_FMUL:
634+
case ISD::VP_REDUCE_FMAX:
635+
case ISD::VP_REDUCE_FMIN:
636+
case ISD::VP_REDUCE_SEQ_FADD:
637+
// Promote the operation by extending the operand.
638+
PromoteReduction(Node, Results);
639+
return;
572640
case ISD::FP_ROUND:
573641
case ISD::FP_EXTEND:
574642
// These operations are used to do promotion so they can't be promoted
@@ -589,7 +657,10 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
589657
SmallVector<SDValue, 4> Operands(Node->getNumOperands());
590658

591659
for (unsigned j = 0; j != Node->getNumOperands(); ++j) {
592-
if (Node->getOperand(j).getValueType().isVector())
660+
// Do not promote the mask operand of a VP OP.
661+
bool SkipPromote = ISD::isVPOpcode(Node->getOpcode()) &&
662+
ISD::getVPMaskIdx(Node->getOpcode()) == j;
663+
if (Node->getOperand(j).getValueType().isVector() && !SkipPromote)
593664
if (Node->getOperand(j)
594665
.getValueType()
595666
.getVectorElementType()

0 commit comments

Comments
 (0)