Skip to content

Commit b667e9c

Browse files
committed
[X86][BF16] Lower FP_ROUND for vector types under AVX512BF16
Reviewed By: RKSimon Differential Revision: https://reviews.llvm.org/D158952
1 parent 23fef2c commit b667e9c

File tree

5 files changed

+660
-252
lines changed

5 files changed

+660
-252
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2237,8 +2237,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
22372237

22382238
if (!Subtarget.useSoftFloat() &&
22392239
(Subtarget.hasAVXNECONVERT() || Subtarget.hasBF16())) {
2240-
addRegisterClass(MVT::v8bf16, &X86::VR128XRegClass);
2241-
addRegisterClass(MVT::v16bf16, &X86::VR256XRegClass);
2240+
addRegisterClass(MVT::v8bf16, Subtarget.hasAVX512() ? &X86::VR128XRegClass
2241+
: &X86::VR128RegClass);
2242+
addRegisterClass(MVT::v16bf16, Subtarget.hasAVX512() ? &X86::VR256XRegClass
2243+
: &X86::VR256RegClass);
22422244
// We set the type action of bf16 to TypeSoftPromoteHalf, but we don't
22432245
// provide the method to promote BUILD_VECTOR and INSERT_VECTOR_ELT.
22442246
// Set the operation action Custom to do the customization later.
@@ -2253,6 +2255,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
22532255
setOperationAction(ISD::BUILD_VECTOR, VT, Custom);
22542256
setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom);
22552257
}
2258+
setOperationAction(ISD::FP_ROUND, MVT::v8bf16, Custom);
22562259
addLegalFPImmediate(APFloat::getZero(APFloat::BFloat()));
22572260
}
22582261

@@ -2264,6 +2267,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
22642267
setOperationAction(ISD::FMUL, MVT::v32bf16, Expand);
22652268
setOperationAction(ISD::FDIV, MVT::v32bf16, Expand);
22662269
setOperationAction(ISD::BUILD_VECTOR, MVT::v32bf16, Custom);
2270+
setOperationAction(ISD::FP_ROUND, MVT::v16bf16, Custom);
22672271
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v32bf16, Custom);
22682272
}
22692273

@@ -21278,6 +21282,12 @@ SDValue X86TargetLowering::LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const {
2127821282
return Res;
2127921283
}
2128021284

21285+
if (VT.getScalarType() == MVT::bf16) {
21286+
if (SVT.getScalarType() == MVT::f32 && isTypeLegal(VT))
21287+
return Op;
21288+
return SDValue();
21289+
}
21290+
2128121291
if (VT.getScalarType() == MVT::f16 && !Subtarget.hasFP16()) {
2128221292
if (!Subtarget.hasF16C() || SVT.getScalarType() != MVT::f32)
2128321293
return SDValue();

llvm/lib/Target/X86/X86InstrAVX512.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12976,6 +12976,11 @@ let Predicates = [HasBF16, HasVLX] in {
1297612976
def : Pat<(v16bf16 (X86VBroadcast (v8bf16 VR128X:$src))),
1297712977
(VPBROADCASTWZ256rr VR128X:$src)>;
1297812978

12979+
def : Pat<(v8bf16 (X86vfpround (v8f32 VR256X:$src))),
12980+
(VCVTNEPS2BF16Z256rr VR256X:$src)>;
12981+
def : Pat<(v8bf16 (X86vfpround (loadv8f32 addr:$src))),
12982+
(VCVTNEPS2BF16Z256rm addr:$src)>;
12983+
1297912984
// TODO: No scalar broadcast due to we don't support legal scalar bf16 so far.
1298012985
}
1298112986

@@ -12985,6 +12990,11 @@ let Predicates = [HasBF16] in {
1298512990

1298612991
def : Pat<(v32bf16 (X86VBroadcast (v8bf16 VR128X:$src))),
1298712992
(VPBROADCASTWZrr VR128X:$src)>;
12993+
12994+
def : Pat<(v16bf16 (X86vfpround (v16f32 VR512:$src))),
12995+
(VCVTNEPS2BF16Zrr VR512:$src)>;
12996+
def : Pat<(v16bf16 (X86vfpround (loadv16f32 addr:$src))),
12997+
(VCVTNEPS2BF16Zrm addr:$src)>;
1298812998
// TODO: No scalar broadcast due to we don't support legal scalar bf16 so far.
1298912999
}
1299013000

llvm/lib/Target/X86/X86InstrSSE.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8289,6 +8289,11 @@ let Predicates = [HasAVXNECONVERT] in {
82898289
f256mem>, T8PS;
82908290
let checkVEXPredicate = 1 in
82918291
defm VCVTNEPS2BF16 : VCVTNEPS2BF16_BASE, VEX, T8XS, ExplicitVEXPrefix;
8292+
8293+
def : Pat<(v8bf16 (X86vfpround (v8f32 VR256:$src))),
8294+
(VCVTNEPS2BF16Yrr VR256:$src)>;
8295+
def : Pat<(v8bf16 (X86vfpround (loadv8f32 addr:$src))),
8296+
(VCVTNEPS2BF16Yrm addr:$src)>;
82928297
}
82938298

82948299
def : InstAlias<"vcvtneps2bf16x\t{$src, $dst|$dst, $src}",

llvm/test/CodeGen/X86/avxneconvert-intrinsics.ll

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,6 @@ define <8 x bfloat> @test_int_x86_vcvtneps2bf16128(<4 x float> %A) {
198198
; CHECK-LABEL: test_int_x86_vcvtneps2bf16128:
199199
; CHECK: # %bb.0:
200200
; CHECK-NEXT: {vex} vcvtneps2bf16 %xmm0, %xmm0 # encoding: [0xc4,0xe2,0x7a,0x72,0xc0]
201-
; CHECK-NEXT: # kill: def $xmm1 killed $xmm0
202201
; CHECK-NEXT: ret{{[l|q]}} # encoding: [0xc3]
203202
%ret = call <8 x bfloat> @llvm.x86.vcvtneps2bf16128(<4 x float> %A)
204203
ret <8 x bfloat> %ret
@@ -209,7 +208,6 @@ define <8 x bfloat> @test_int_x86_vcvtneps2bf16256(<8 x float> %A) {
209208
; CHECK-LABEL: test_int_x86_vcvtneps2bf16256:
210209
; CHECK: # %bb.0:
211210
; CHECK-NEXT: {vex} vcvtneps2bf16 %ymm0, %xmm0 # encoding: [0xc4,0xe2,0x7e,0x72,0xc0]
212-
; CHECK-NEXT: # kill: def $xmm1 killed $xmm0
213211
; CHECK-NEXT: vzeroupper # encoding: [0xc5,0xf8,0x77]
214212
; CHECK-NEXT: ret{{[l|q]}} # encoding: [0xc3]
215213
%ret = call <8 x bfloat> @llvm.x86.vcvtneps2bf16256(<8 x float> %A)

0 commit comments

Comments
 (0)