Skip to content

Commit e33e623

Browse files
authored
[NVPTX] Consistently check fast-math flags when lowering div (llvm#136890)
When choosing the `div.*` variant during ISel, check the instruction-level fast-math flags.
1 parent e3950a0 commit e33e623

File tree

8 files changed

+245
-95
lines changed

8 files changed

+245
-95
lines changed

llvm/lib/Target/NVPTX/NVPTX.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,14 @@ enum PrmtMode {
253253
RC16,
254254
};
255255
}
256-
}
256+
257+
enum class DivPrecisionLevel : unsigned {
258+
Approx = 0,
259+
Full = 1,
260+
IEEE754 = 2,
261+
};
262+
263+
} // namespace NVPTX
257264
void initializeNVPTXDAGToDAGISelLegacyPass(PassRegistry &);
258265
} // namespace llvm
259266

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,9 @@ bool NVPTXDAGToDAGISel::runOnMachineFunction(MachineFunction &MF) {
6666
return SelectionDAGISel::runOnMachineFunction(MF);
6767
}
6868

69-
int NVPTXDAGToDAGISel::getDivF32Level() const {
70-
return Subtarget->getTargetLowering()->getDivF32Level();
69+
NVPTX::DivPrecisionLevel
70+
NVPTXDAGToDAGISel::getDivF32Level(const SDNode *N) const {
71+
return Subtarget->getTargetLowering()->getDivF32Level(*MF, *N);
7172
}
7273

7374
bool NVPTXDAGToDAGISel::usePrecSqrtF32() const {

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
4343
// If true, generate mul.wide from sext and mul
4444
bool doMulWide;
4545

46-
int getDivF32Level() const;
46+
NVPTX::DivPrecisionLevel getDivF32Level(const SDNode *N) const;
4747
bool usePrecSqrtF32() const;
4848
bool useF32FTZ() const;
4949
bool allowFMA() const;

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,16 @@ static cl::opt<unsigned> FMAContractLevelOpt(
8585
" 1: do it 2: do it aggressively"),
8686
cl::init(2));
8787

88-
static cl::opt<int> UsePrecDivF32(
88+
static cl::opt<NVPTX::DivPrecisionLevel> UsePrecDivF32(
8989
"nvptx-prec-divf32", cl::Hidden,
9090
cl::desc("NVPTX Specifies: 0 use div.approx, 1 use div.full, 2 use"
9191
" IEEE Compliant F32 div.rnd if available."),
92-
cl::init(2));
92+
cl::values(clEnumValN(NVPTX::DivPrecisionLevel::Approx, "0",
93+
"Use div.approx"),
94+
clEnumValN(NVPTX::DivPrecisionLevel::Full, "1", "Use div.full"),
95+
clEnumValN(NVPTX::DivPrecisionLevel::IEEE754, "2",
96+
"Use IEEE Compliant F32 div.rnd if available")),
97+
cl::init(NVPTX::DivPrecisionLevel::IEEE754));
9398

9499
static cl::opt<bool> UsePrecSqrtF32(
95100
"nvptx-prec-sqrtf32", cl::Hidden,
@@ -109,17 +114,22 @@ static cl::opt<bool> ForceMinByValParamAlign(
109114
" params of device functions."),
110115
cl::init(false));
111116

112-
int NVPTXTargetLowering::getDivF32Level() const {
113-
if (UsePrecDivF32.getNumOccurrences() > 0) {
114-
// If nvptx-prec-div32=N is used on the command-line, always honor it
117+
NVPTX::DivPrecisionLevel
118+
NVPTXTargetLowering::getDivF32Level(const MachineFunction &MF,
119+
const SDNode &N) const {
120+
// If nvptx-prec-div32=N is used on the command-line, always honor it
121+
if (UsePrecDivF32.getNumOccurrences() > 0)
115122
return UsePrecDivF32;
116-
} else {
117-
// Otherwise, use div.approx if fast math is enabled
118-
if (getTargetMachine().Options.UnsafeFPMath)
119-
return 0;
120-
else
121-
return 2;
122-
}
123+
124+
// Otherwise, use div.approx if fast math is enabled
125+
if (allowUnsafeFPMath(MF))
126+
return NVPTX::DivPrecisionLevel::Approx;
127+
128+
const SDNodeFlags Flags = N.getFlags();
129+
if (Flags.hasApproximateFuncs())
130+
return NVPTX::DivPrecisionLevel::Approx;
131+
132+
return NVPTX::DivPrecisionLevel::IEEE754;
123133
}
124134

125135
bool NVPTXTargetLowering::usePrecSqrtF32() const {
@@ -4975,7 +4985,7 @@ bool NVPTXTargetLowering::allowFMA(MachineFunction &MF,
49754985
return allowUnsafeFPMath(MF);
49764986
}
49774987

4978-
bool NVPTXTargetLowering::allowUnsafeFPMath(MachineFunction &MF) const {
4988+
bool NVPTXTargetLowering::allowUnsafeFPMath(const MachineFunction &MF) const {
49794989
// Honor TargetOptions flags that explicitly say unsafe math is okay.
49804990
if (MF.getTarget().Options.UnsafeFPMath)
49814991
return true;

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -216,11 +216,8 @@ class NVPTXTargetLowering : public TargetLowering {
216216

217217
// Get the degree of precision we want from 32-bit floating point division
218218
// operations.
219-
//
220-
// 0 - Use ptx div.approx
221-
// 1 - Use ptx.div.full (approximate, but less so than div.approx)
222-
// 2 - Use IEEE-compliant div instructions, if available.
223-
int getDivF32Level() const;
219+
NVPTX::DivPrecisionLevel getDivF32Level(const MachineFunction &MF,
220+
const SDNode &N) const;
224221

225222
// Get whether we should use a precise or approximate 32-bit floating point
226223
// sqrt instruction.
@@ -237,7 +234,7 @@ class NVPTXTargetLowering : public TargetLowering {
237234
unsigned combineRepeatedFPDivisors() const override { return 2; }
238235

239236
bool allowFMA(MachineFunction &MF, CodeGenOptLevel OptLevel) const;
240-
bool allowUnsafeFPMath(MachineFunction &MF) const;
237+
bool allowUnsafeFPMath(const MachineFunction &MF) const;
241238

242239
bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
243240
EVT) const override {

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 61 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,6 @@ def doRsqrtOpt : Predicate<"doRsqrtOpt()">;
151151

152152
def doMulWide : Predicate<"doMulWide">;
153153

154-
def do_DIVF32_APPROX : Predicate<"getDivF32Level()==0">;
155-
def do_DIVF32_FULL : Predicate<"getDivF32Level()==1">;
156-
157154
def do_SQRTF32_APPROX : Predicate<"!usePrecSqrtF32()">;
158155
def do_SQRTF32_RN : Predicate<"usePrecSqrtF32()">;
159156

@@ -1119,26 +1116,19 @@ def INEG64 :
11191116
//-----------------------------------
11201117

11211118
// Constant 1.0f
1122-
def FloatConst1 : PatLeaf<(fpimm), [{
1123-
return &N->getValueAPF().getSemantics() == &llvm::APFloat::IEEEsingle() &&
1124-
N->getValueAPF().convertToFloat() == 1.0f;
1119+
def f32imm_1 : FPImmLeaf<f32, [{
1120+
return &Imm.getSemantics() == &llvm::APFloat::IEEEsingle() &&
1121+
Imm.convertToFloat() == 1.0f;
11251122
}]>;
11261123
// Constant 1.0 (double)
1127-
def DoubleConst1 : PatLeaf<(fpimm), [{
1128-
return &N->getValueAPF().getSemantics() == &llvm::APFloat::IEEEdouble() &&
1129-
N->getValueAPF().convertToDouble() == 1.0;
1124+
def f64imm_1 : FPImmLeaf<f64, [{
1125+
return &Imm.getSemantics() == &llvm::APFloat::IEEEdouble() &&
1126+
Imm.convertToDouble() == 1.0;
11301127
}]>;
11311128
// Constant -1.0 (double)
1132-
def DoubleConstNeg1 : PatLeaf<(fpimm), [{
1133-
return &N->getValueAPF().getSemantics() == &llvm::APFloat::IEEEdouble() &&
1134-
N->getValueAPF().convertToDouble() == -1.0;
1135-
}]>;
1136-
1137-
1138-
// Constant -X -> X (double)
1139-
def NegDoubleConst : SDNodeXForm<fpimm, [{
1140-
return CurDAG->getTargetConstantFP(-(N->getValueAPF()),
1141-
SDLoc(N), MVT::f64);
1129+
def f64imm_neg1 : FPImmLeaf<f64, [{
1130+
return &Imm.getSemantics() == &llvm::APFloat::IEEEdouble() &&
1131+
Imm.convertToDouble() == -1.0;
11421132
}]>;
11431133

11441134
defm FADD : F3_fma_component<"add", fadd>;
@@ -1189,11 +1179,11 @@ def BFNEG16x2 : FNEG_BF16_F16X2<"neg.bf16x2", v2bf16, Int32Regs, True>;
11891179
//
11901180
// F64 division
11911181
//
1192-
def FDIV641r :
1182+
def FRCP64r :
11931183
NVPTXInst<(outs Float64Regs:$dst),
1194-
(ins f64imm:$a, Float64Regs:$b),
1184+
(ins Float64Regs:$b),
11951185
"rcp.rn.f64 \t$dst, $b;",
1196-
[(set f64:$dst, (fdiv DoubleConst1:$a, f64:$b))]>;
1186+
[(set f64:$dst, (fdiv f64imm_1, f64:$b))]>;
11971187
def FDIV64rr :
11981188
NVPTXInst<(outs Float64Regs:$dst),
11991189
(ins Float64Regs:$a, Float64Regs:$b),
@@ -1207,109 +1197,114 @@ def FDIV64ri :
12071197

12081198
// fdiv will be converted to rcp
12091199
// fneg (fdiv 1.0, X) => fneg (rcp.rn X)
1210-
def : Pat<(fdiv DoubleConstNeg1:$a, f64:$b),
1211-
(FNEGf64 (FDIV641r (NegDoubleConst node:$a), $b))>;
1200+
def : Pat<(fdiv f64imm_neg1, f64:$b),
1201+
(FNEGf64 (FRCP64r $b))>;
12121202

12131203
//
12141204
// F32 Approximate reciprocal
12151205
//
1216-
def FDIV321r_ftz :
1206+
1207+
def fdiv_approx : PatFrag<(ops node:$a, node:$b),
1208+
(fdiv node:$a, node:$b), [{
1209+
return getDivF32Level(N) == NVPTX::DivPrecisionLevel::Approx;
1210+
}]>;
1211+
1212+
1213+
def FRCP32_approx_r_ftz :
12171214
NVPTXInst<(outs Float32Regs:$dst),
1218-
(ins f32imm:$a, Float32Regs:$b),
1215+
(ins Float32Regs:$b),
12191216
"rcp.approx.ftz.f32 \t$dst, $b;",
1220-
[(set f32:$dst, (fdiv FloatConst1:$a, f32:$b))]>,
1221-
Requires<[do_DIVF32_APPROX, doF32FTZ]>;
1222-
def FDIV321r :
1217+
[(set f32:$dst, (fdiv_approx f32imm_1, f32:$b))]>,
1218+
Requires<[doF32FTZ]>;
1219+
def FRCP32_approx_r :
12231220
NVPTXInst<(outs Float32Regs:$dst),
1224-
(ins f32imm:$a, Float32Regs:$b),
1221+
(ins Float32Regs:$b),
12251222
"rcp.approx.f32 \t$dst, $b;",
1226-
[(set f32:$dst, (fdiv FloatConst1:$a, f32:$b))]>,
1227-
Requires<[do_DIVF32_APPROX]>;
1223+
[(set f32:$dst, (fdiv_approx f32imm_1, f32:$b))]>;
1224+
12281225
//
12291226
// F32 Approximate division
12301227
//
12311228
def FDIV32approxrr_ftz :
12321229
NVPTXInst<(outs Float32Regs:$dst),
12331230
(ins Float32Regs:$a, Float32Regs:$b),
12341231
"div.approx.ftz.f32 \t$dst, $a, $b;",
1235-
[(set f32:$dst, (fdiv f32:$a, f32:$b))]>,
1236-
Requires<[do_DIVF32_APPROX, doF32FTZ]>;
1232+
[(set f32:$dst, (fdiv_approx f32:$a, f32:$b))]>,
1233+
Requires<[doF32FTZ]>;
12371234
def FDIV32approxri_ftz :
12381235
NVPTXInst<(outs Float32Regs:$dst),
12391236
(ins Float32Regs:$a, f32imm:$b),
12401237
"div.approx.ftz.f32 \t$dst, $a, $b;",
1241-
[(set f32:$dst, (fdiv f32:$a, fpimm:$b))]>,
1242-
Requires<[do_DIVF32_APPROX, doF32FTZ]>;
1238+
[(set f32:$dst, (fdiv_approx f32:$a, fpimm:$b))]>,
1239+
Requires<[doF32FTZ]>;
12431240
def FDIV32approxrr :
12441241
NVPTXInst<(outs Float32Regs:$dst),
12451242
(ins Float32Regs:$a, Float32Regs:$b),
12461243
"div.approx.f32 \t$dst, $a, $b;",
1247-
[(set f32:$dst, (fdiv f32:$a, f32:$b))]>,
1248-
Requires<[do_DIVF32_APPROX]>;
1244+
[(set f32:$dst, (fdiv_approx f32:$a, f32:$b))]>;
12491245
def FDIV32approxri :
12501246
NVPTXInst<(outs Float32Regs:$dst),
12511247
(ins Float32Regs:$a, f32imm:$b),
12521248
"div.approx.f32 \t$dst, $a, $b;",
1253-
[(set f32:$dst, (fdiv f32:$a, fpimm:$b))]>,
1254-
Requires<[do_DIVF32_APPROX]>;
1249+
[(set f32:$dst, (fdiv_approx f32:$a, fpimm:$b))]>;
12551250
//
12561251
// F32 Semi-accurate reciprocal
12571252
//
12581253
// rcp.approx gives the same result as div.full(1.0f, a) and is faster.
12591254
//
1260-
def FDIV321r_approx_ftz :
1261-
NVPTXInst<(outs Float32Regs:$dst),
1262-
(ins f32imm:$a, Float32Regs:$b),
1263-
"rcp.approx.ftz.f32 \t$dst, $b;",
1264-
[(set f32:$dst, (fdiv FloatConst1:$a, f32:$b))]>,
1265-
Requires<[do_DIVF32_FULL, doF32FTZ]>;
1266-
def FDIV321r_approx :
1267-
NVPTXInst<(outs Float32Regs:$dst),
1268-
(ins f32imm:$a, Float32Regs:$b),
1269-
"rcp.approx.f32 \t$dst, $b;",
1270-
[(set f32:$dst, (fdiv FloatConst1:$a, f32:$b))]>,
1271-
Requires<[do_DIVF32_FULL]>;
1255+
1256+
def fdiv_full : PatFrag<(ops node:$a, node:$b),
1257+
(fdiv node:$a, node:$b), [{
1258+
return getDivF32Level(N) == NVPTX::DivPrecisionLevel::Full;
1259+
}]>;
1260+
1261+
1262+
def : Pat<(fdiv_full f32imm_1, f32:$b),
1263+
(FRCP32_approx_r_ftz $b)>,
1264+
Requires<[doF32FTZ]>;
1265+
1266+
def : Pat<(fdiv_full f32imm_1, f32:$b),
1267+
(FRCP32_approx_r $b)>;
1268+
12721269
//
12731270
// F32 Semi-accurate division
12741271
//
12751272
def FDIV32rr_ftz :
12761273
NVPTXInst<(outs Float32Regs:$dst),
12771274
(ins Float32Regs:$a, Float32Regs:$b),
12781275
"div.full.ftz.f32 \t$dst, $a, $b;",
1279-
[(set f32:$dst, (fdiv Float32Regs:$a, f32:$b))]>,
1280-
Requires<[do_DIVF32_FULL, doF32FTZ]>;
1276+
[(set f32:$dst, (fdiv_full f32:$a, f32:$b))]>,
1277+
Requires<[doF32FTZ]>;
12811278
def FDIV32ri_ftz :
12821279
NVPTXInst<(outs Float32Regs:$dst),
12831280
(ins Float32Regs:$a, f32imm:$b),
12841281
"div.full.ftz.f32 \t$dst, $a, $b;",
1285-
[(set f32:$dst, (fdiv f32:$a, fpimm:$b))]>,
1286-
Requires<[do_DIVF32_FULL, doF32FTZ]>;
1282+
[(set f32:$dst, (fdiv_full f32:$a, fpimm:$b))]>,
1283+
Requires<[doF32FTZ]>;
12871284
def FDIV32rr :
12881285
NVPTXInst<(outs Float32Regs:$dst),
12891286
(ins Float32Regs:$a, Float32Regs:$b),
12901287
"div.full.f32 \t$dst, $a, $b;",
1291-
[(set f32:$dst, (fdiv f32:$a, f32:$b))]>,
1292-
Requires<[do_DIVF32_FULL]>;
1288+
[(set f32:$dst, (fdiv_full f32:$a, f32:$b))]>;
12931289
def FDIV32ri :
12941290
NVPTXInst<(outs Float32Regs:$dst),
12951291
(ins Float32Regs:$a, f32imm:$b),
12961292
"div.full.f32 \t$dst, $a, $b;",
1297-
[(set f32:$dst, (fdiv f32:$a, fpimm:$b))]>,
1298-
Requires<[do_DIVF32_FULL]>;
1293+
[(set f32:$dst, (fdiv_full f32:$a, fpimm:$b))]>;
12991294
//
13001295
// F32 Accurate reciprocal
13011296
//
1302-
def FDIV321r_prec_ftz :
1297+
def FRCP32r_prec_ftz :
13031298
NVPTXInst<(outs Float32Regs:$dst),
1304-
(ins f32imm:$a, Float32Regs:$b),
1299+
(ins Float32Regs:$b),
13051300
"rcp.rn.ftz.f32 \t$dst, $b;",
1306-
[(set f32:$dst, (fdiv FloatConst1:$a, f32:$b))]>,
1301+
[(set f32:$dst, (fdiv f32imm_1, f32:$b))]>,
13071302
Requires<[doF32FTZ]>;
1308-
def FDIV321r_prec :
1303+
def FRCP32r_prec :
13091304
NVPTXInst<(outs Float32Regs:$dst),
1310-
(ins f32imm:$a, Float32Regs:$b),
1305+
(ins Float32Regs:$b),
13111306
"rcp.rn.f32 \t$dst, $b;",
1312-
[(set f32:$dst, (fdiv FloatConst1:$a, f32:$b))]>;
1307+
[(set f32:$dst, (fdiv f32imm_1, f32:$b))]>;
13131308
//
13141309
// F32 Accurate division
13151310
//

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1615,24 +1615,24 @@ def INT_NVVM_RSQRT_APPROX_D : F_MATH_1<"rsqrt.approx.f64 \t$dst, $src0;",
16151615
F64RT, F64RT, int_nvvm_rsqrt_approx_d>;
16161616

16171617
// 1.0f / sqrt_approx -> rsqrt_approx
1618-
def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_approx_f f32:$a)),
1618+
def: Pat<(fdiv f32imm_1, (int_nvvm_sqrt_approx_f f32:$a)),
16191619
(INT_NVVM_RSQRT_APPROX_F $a)>,
16201620
Requires<[doRsqrtOpt]>;
1621-
def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_approx_ftz_f f32:$a)),
1621+
def: Pat<(fdiv f32imm_1, (int_nvvm_sqrt_approx_ftz_f f32:$a)),
16221622
(INT_NVVM_RSQRT_APPROX_FTZ_F $a)>,
16231623
Requires<[doRsqrtOpt]>;
16241624
// same for int_nvvm_sqrt_f when non-precision sqrt is requested
1625-
def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f f32:$a)),
1625+
def: Pat<(fdiv f32imm_1, (int_nvvm_sqrt_f f32:$a)),
16261626
(INT_NVVM_RSQRT_APPROX_F $a)>,
16271627
Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>;
1628-
def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f f32:$a)),
1628+
def: Pat<(fdiv f32imm_1, (int_nvvm_sqrt_f f32:$a)),
16291629
(INT_NVVM_RSQRT_APPROX_FTZ_F $a)>,
16301630
Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>;
16311631

1632-
def: Pat<(fdiv FloatConst1, (fsqrt f32:$a)),
1632+
def: Pat<(fdiv f32imm_1, (fsqrt f32:$a)),
16331633
(INT_NVVM_RSQRT_APPROX_F $a)>,
16341634
Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>;
1635-
def: Pat<(fdiv FloatConst1, (fsqrt f32:$a)),
1635+
def: Pat<(fdiv f32imm_1, (fsqrt f32:$a)),
16361636
(INT_NVVM_RSQRT_APPROX_FTZ_F $a)>,
16371637
Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>;
16381638
//

0 commit comments

Comments
 (0)