Skip to content

Commit 55eb93b

Browse files
authored
[RISCV] Remove RISCVISD::FP_EXTEND_BF16. (#106939)
I don't think we need this node; we can instruction-select (isel) fp_extend directly. Extending bf16 to f64 requires two instructions (bf16 -> f32, then f32 -> f64), but both can be emitted from a single isel pattern. I have not removed RISCVISD::FP_ROUND_BF16 because f64->bf16 needs more work to fix the double rounding.
1 parent 38ae53d commit 55eb93b

File tree

4 files changed

+11
-32
lines changed

4 files changed

+11
-32
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -452,8 +452,6 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
452452
setOperationAction(ISD::BITCAST, MVT::i16, Custom);
453453
setOperationAction(ISD::BITCAST, MVT::bf16, Custom);
454454
setOperationAction(ISD::FP_ROUND, MVT::bf16, Custom);
455-
setOperationAction(ISD::FP_EXTEND, MVT::f32, Custom);
456-
setOperationAction(ISD::FP_EXTEND, MVT::f64, Custom);
457455
setOperationAction(ISD::ConstantFP, MVT::bf16, Expand);
458456
setOperationAction(ISD::SELECT_CC, MVT::bf16, Expand);
459457
setOperationAction(ISD::BR_CC, MVT::bf16, Expand);
@@ -6500,18 +6498,6 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
65006498
return SplitVectorOp(Op, DAG);
65016499
return lowerFMAXIMUM_FMINIMUM(Op, DAG, Subtarget);
65026500
case ISD::FP_EXTEND: {
6503-
SDLoc DL(Op);
6504-
EVT VT = Op.getValueType();
6505-
SDValue Op0 = Op.getOperand(0);
6506-
EVT Op0VT = Op0.getValueType();
6507-
if (VT == MVT::f32 && Op0VT == MVT::bf16 && Subtarget.hasStdExtZfbfmin())
6508-
return DAG.getNode(RISCVISD::FP_EXTEND_BF16, DL, MVT::f32, Op0);
6509-
if (VT == MVT::f64 && Op0VT == MVT::bf16 && Subtarget.hasStdExtZfbfmin()) {
6510-
SDValue FloatVal =
6511-
DAG.getNode(RISCVISD::FP_EXTEND_BF16, DL, MVT::f32, Op0);
6512-
return DAG.getNode(ISD::FP_EXTEND, DL, MVT::f64, FloatVal);
6513-
}
6514-
65156501
if (!Op.getValueType().isVector())
65166502
return Op;
65176503
return lowerVectorFPExtendOrRoundLike(Op, DAG);
@@ -20463,7 +20449,6 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
2046320449
NODE_NAME_CASE(STRICT_FCVT_W_RV64)
2046420450
NODE_NAME_CASE(STRICT_FCVT_WU_RV64)
2046520451
NODE_NAME_CASE(FP_ROUND_BF16)
20466-
NODE_NAME_CASE(FP_EXTEND_BF16)
2046720452
NODE_NAME_CASE(FROUND)
2046820453
NODE_NAME_CASE(FCLASS)
2046920454
NODE_NAME_CASE(FSGNJX)

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,6 @@ enum NodeType : unsigned {
117117
FCVT_WU_RV64,
118118

119119
FP_ROUND_BF16,
120-
FP_EXTEND_BF16,
121120

122121
// Rounds an FP value to its corresponding integer in the same FP format.
123122
// First operand is the value to round, the second operand is the largest

llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -677,8 +677,7 @@ multiclass VPatWidenBinaryFPSDNode_VV_VF_WV_WF_RM<SDNode op,
677677
VPatWidenBinaryFPSDNode_WV_WF_RM<op, instruction_name>;
678678

679679
multiclass VPatWidenFPMulAccSDNode_VV_VF_RM<string instruction_name,
680-
list <VTypeInfoToWide> vtiToWtis,
681-
PatFrags extop> {
680+
list <VTypeInfoToWide> vtiToWtis> {
682681
foreach vtiToWti = vtiToWtis in {
683682
defvar vti = vtiToWti.Vti;
684683
defvar wti = vtiToWti.Wti;
@@ -702,7 +701,7 @@ multiclass VPatWidenFPMulAccSDNode_VV_VF_RM<string instruction_name,
702701
FRM_DYN,
703702
vti.AVL, vti.Log2SEW, TAIL_AGNOSTIC)>;
704703
def : Pat<(fma (wti.Vector (SplatFPOp
705-
(extop (vti.Scalar vti.ScalarRegClass:$rs1)))),
704+
(fpext_oneuse (vti.Scalar vti.ScalarRegClass:$rs1)))),
706705
(wti.Vector (riscv_fpextend_vl_oneuse
707706
(vti.Vector vti.RegClass:$rs2),
708707
(vti.Mask true_mask), (XLenVT srcvalue))),
@@ -1290,11 +1289,9 @@ foreach fvti = AllFloatVectors in {
12901289

12911290
// 13.7. Vector Widening Floating-Point Fused Multiply-Add Instructions
12921291
defm : VPatWidenFPMulAccSDNode_VV_VF_RM<"PseudoVFWMACC",
1293-
AllWidenableFloatVectors,
1294-
fpext_oneuse>;
1292+
AllWidenableFloatVectors>;
12951293
defm : VPatWidenFPMulAccSDNode_VV_VF_RM<"PseudoVFWMACCBF16",
1296-
AllWidenableBFloatToFloatVectors,
1297-
riscv_fpextend_bf16_oneuse>;
1294+
AllWidenableBFloatToFloatVectors>;
12981295
defm : VPatWidenFPNegMulAccSDNode_VV_VF_RM<"PseudoVFWNMACC">;
12991296
defm : VPatWidenFPMulSacSDNode_VV_VF_RM<"PseudoVFWMSAC">;
13001297
defm : VPatWidenFPNegMulSacSDNode_VV_VF_RM<"PseudoVFWNMSAC">;

llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,9 @@
1919

2020
def SDT_RISCVFP_ROUND_BF16
2121
: SDTypeProfile<1, 1, [SDTCisVT<0, bf16>, SDTCisVT<1, f32>]>;
22-
def SDT_RISCVFP_EXTEND_BF16
23-
: SDTypeProfile<1, 1, [SDTCisVT<0, f32>, SDTCisVT<1, bf16>]>;
2422

2523
def riscv_fpround_bf16
2624
: SDNode<"RISCVISD::FP_ROUND_BF16", SDT_RISCVFP_ROUND_BF16>;
27-
def riscv_fpextend_bf16
28-
: SDNode<"RISCVISD::FP_EXTEND_BF16", SDT_RISCVFP_EXTEND_BF16>;
29-
def riscv_fpextend_bf16_oneuse : PatFrag<(ops node:$A),
30-
(riscv_fpextend_bf16 node:$A), [{
31-
return N->hasOneUse();
32-
}]>;
3325

3426
//===----------------------------------------------------------------------===//
3527
// Instructions
@@ -57,7 +49,7 @@ def : StPat<store, FSH, FPR16, bf16>;
5749
// f32 -> bf16, bf16 -> f32
5850
def : Pat<(bf16 (riscv_fpround_bf16 FPR32:$rs1)),
5951
(FCVT_BF16_S FPR32:$rs1, FRM_DYN)>;
60-
def : Pat<(riscv_fpextend_bf16 (bf16 FPR16:$rs1)),
52+
def : Pat<(fpextend (bf16 FPR16:$rs1)),
6153
(FCVT_S_BF16 FPR16:$rs1, FRM_DYN)>;
6254

6355
// Moves (no conversion)
@@ -87,3 +79,9 @@ def : Pat<(i64 (any_fp_to_uint (bf16 FPR16:$rs1))), (FCVT_LU_S (FCVT_S_BF16 $rs1
8779
def : Pat<(bf16 (any_sint_to_fp (i64 GPR:$rs1))), (FCVT_BF16_S (FCVT_S_L $rs1, FRM_DYN), FRM_DYN)>;
8880
def : Pat<(bf16 (any_uint_to_fp (i64 GPR:$rs1))), (FCVT_BF16_S (FCVT_S_LU $rs1, FRM_DYN), FRM_DYN)>;
8981
}
82+
83+
let Predicates = [HasStdExtZfbfmin, HasStdExtD] in {
84+
// bf16 -> f64
85+
def : Pat<(fpextend (bf16 FPR16:$rs1)),
86+
(FCVT_D_S (FCVT_S_BF16 FPR16:$rs1, FRM_DYN), FRM_RNE)>;
87+
}

0 commit comments

Comments
 (0)