Skip to content

Commit 155e188

Browse files
authored
[NVPTX] Add intrinsics and clang builtins for conversions of f4x2 type (llvm#139244)
This change adds intrinsics and clang builtins for the `cvt` instruction variants of the FP4 type `.e2m1x2`, introduced in PTX 8.6 for `sm_100a`, `sm_101a`, and `sm_120a`. Tests are added in `NVPTX/convert-sm100a.ll` and `clang/test/CodeGen/builtins-nvptx.c` and verified through ptxas 12.8.0. PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
1 parent 4e63e04 commit 155e188

File tree

6 files changed

+149
-1
lines changed

6 files changed

+149
-1
lines changed

clang/include/clang/Basic/BuiltinsNVPTX.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,12 @@ def __nvvm_e2m3x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(sh
620620
def __nvvm_e3m2x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
621621
def __nvvm_e3m2x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
622622

623+
def __nvvm_ff_to_e2m1x2_rn_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
624+
def __nvvm_ff_to_e2m1x2_rn_relu_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
625+
626+
def __nvvm_e2m1x2_to_f16x2_rn : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
627+
def __nvvm_e2m1x2_to_f16x2_rn_relu : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(short)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
628+
623629
def __nvvm_ff_to_ue8m0x2_rz : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
624630
def __nvvm_ff_to_ue8m0x2_rz_satfinite : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;
625631
def __nvvm_ff_to_ue8m0x2_rp : NVPTXBuiltinSMAndPTX<"short(float, float)", SM<"100a", [SM_101a, SM_120a]>, PTX86>;

clang/test/CodeGen/builtins-nvptx.c

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1127,6 +1127,26 @@ __device__ void nvvm_cvt_sm100a_sm101a_sm120a() {
11271127
// CHECK_PTX86_SM120a: call <2 x half> @llvm.nvvm.e3m2x2.to.f16x2.rn.relu(i16 19532)
11281128
__nvvm_e3m2x2_to_f16x2_rn_relu(0x4C4C);
11291129

1130+
// CHECK_PTX86_SM100a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float 1.000000e+00, float 1.000000e+00)
1131+
// CHECK_PTX86_SM101a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float 1.000000e+00, float 1.000000e+00)
1132+
// CHECK_PTX86_SM120a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float 1.000000e+00, float 1.000000e+00)
1133+
__nvvm_ff_to_e2m1x2_rn_satfinite(1.0f, 1.0f);
1134+
1135+
// CHECK_PTX86_SM100a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float 1.000000e+00, float 1.000000e+00)
1136+
// CHECK_PTX86_SM101a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float 1.000000e+00, float 1.000000e+00)
1137+
// CHECK_PTX86_SM120a: call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float 1.000000e+00, float 1.000000e+00)
1138+
__nvvm_ff_to_e2m1x2_rn_relu_satfinite(1.0f, 1.0f);
1139+
1140+
// CHECK_PTX86_SM100a: call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 76)
1141+
// CHECK_PTX86_SM101a: call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 76)
1142+
// CHECK_PTX86_SM120a: call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 76)
1143+
__nvvm_e2m1x2_to_f16x2_rn(0x004C);
1144+
1145+
// CHECK_PTX86_SM100a: call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 76)
1146+
// CHECK_PTX86_SM101a: call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 76)
1147+
// CHECK_PTX86_SM120a: call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 76)
1148+
__nvvm_e2m1x2_to_f16x2_rn_relu(0x004C);
1149+
11301150
// CHECK_PTX86_SM100a: call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float 1.000000e+00, float 1.000000e+00)
11311151
// CHECK_PTX86_SM101a: call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float 1.000000e+00, float 1.000000e+00)
11321152
// CHECK_PTX86_SM120a: call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float 1.000000e+00, float 1.000000e+00)

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1293,10 +1293,19 @@ let TargetPrefix = "nvvm" in {
12931293
}
12941294
}
12951295

1296+
// FP4 conversions.
1297+
foreach relu = ["", "_relu"] in {
1298+
def int_nvvm_ff_to_e2m1x2_rn # relu # _satfinite : NVVMBuiltin,
1299+
DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
1300+
1301+
def int_nvvm_e2m1x2_to_f16x2_rn # relu : NVVMBuiltin,
1302+
DefaultAttrsIntrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
1303+
}
1304+
12961305
// UE8M0x2 conversions.
12971306
foreach rmode = ["_rz", "_rp"] in {
12981307
foreach satmode = ["", "_satfinite"] in {
1299-
defvar suffix = !strconcat(rmode, satmode);
1308+
defvar suffix = rmode # satmode;
13001309
def int_nvvm_ff_to_ue8m0x2 # suffix : NVVMBuiltin,
13011310
DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
13021311

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,23 @@ let hasSideEffects = false in {
714714
# type # " \t$dst, $src;", []>;
715715
}
716716

717+
// FP4 conversions.
718+
def CVT_e2m1x2_f32_sf : NVPTXInst<(outs Int16Regs:$dst),
719+
(ins Float32Regs:$src1, Float32Regs:$src2, CvtMode:$mode),
720+
!strconcat("{{ \n\t",
721+
".reg .b8 \t%e2m1x2_out; \n\t",
722+
"cvt${mode:base}.satfinite${mode:relu}.e2m1x2.f32 \t%e2m1x2_out, $src1, $src2; \n\t",
723+
"cvt.u16.u8 \t$dst, %e2m1x2_out; \n\t",
724+
"}}"), []>;
725+
726+
def CVT_f16x2_e2m1x2 : NVPTXInst<(outs Int32Regs:$dst),
727+
(ins Int16Regs:$src, CvtMode:$mode),
728+
!strconcat("{{ \n\t",
729+
".reg .b8 \t%e2m1x2_in; \n\t",
730+
"cvt.u8.u16 \t%e2m1x2_in, $src; \n\t",
731+
"cvt${mode:base}${mode:relu}.f16x2.e2m1x2 \t$dst, %e2m1x2_in; \n\t",
732+
"}}"), []>;
733+
717734
// UE8M0x2 conversions.
718735
class CVT_f32_to_ue8m0x2<string sat = ""> :
719736
NVPTXInst<(outs Int16Regs:$dst),

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2003,6 +2003,20 @@ def : Pat<(int_nvvm_e3m2x2_to_f16x2_rn i16:$a),
20032003
def : Pat<(int_nvvm_e3m2x2_to_f16x2_rn_relu i16:$a),
20042004
(CVT_f16x2_e3m2x2 $a, CvtRN_RELU)>,
20052005
Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>;
2006+
2007+
def : Pat<(int_nvvm_ff_to_e2m1x2_rn_satfinite f32:$a, f32:$b),
2008+
(CVT_e2m1x2_f32_sf $a, $b, CvtRN)>,
2009+
Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>;
2010+
def : Pat<(int_nvvm_ff_to_e2m1x2_rn_relu_satfinite f32:$a, f32:$b),
2011+
(CVT_e2m1x2_f32_sf $a, $b, CvtRN_RELU)>,
2012+
Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>;
2013+
2014+
def : Pat<(int_nvvm_e2m1x2_to_f16x2_rn Int16Regs:$a),
2015+
(CVT_f16x2_e2m1x2 $a, CvtRN)>,
2016+
Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>;
2017+
def : Pat<(int_nvvm_e2m1x2_to_f16x2_rn_relu Int16Regs:$a),
2018+
(CVT_f16x2_e2m1x2 $a, CvtRN_RELU)>,
2019+
Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>;
20062020

20072021
def : Pat<(int_nvvm_ff_to_ue8m0x2_rz f32:$a, f32:$b),
20082022
(CVT_ue8m0x2_f32 $a, $b, CvtRZ)>,

llvm/test/CodeGen/NVPTX/convert-sm100a.ll

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,3 +288,85 @@ define <2 x bfloat> @cvt_bf16x2_ue8m0x2(i16 %in) {
288288
%val = call <2 x bfloat> @llvm.nvvm.ue8m0x2.to.bf16x2(i16 %in)
289289
ret <2 x bfloat> %val
290290
}
291+
292+
define i16 @cvt_rn_sf_e2m1x2_f32(float %f1, float %f2) {
293+
; CHECK-LABEL: cvt_rn_sf_e2m1x2_f32(
294+
; CHECK: {
295+
; CHECK-NEXT: .reg .b16 %rs<2>;
296+
; CHECK-NEXT: .reg .b32 %r<2>;
297+
; CHECK-NEXT: .reg .b32 %f<3>;
298+
; CHECK-EMPTY:
299+
; CHECK-NEXT: // %bb.0:
300+
; CHECK-NEXT: ld.param.b32 %f1, [cvt_rn_sf_e2m1x2_f32_param_0];
301+
; CHECK-NEXT: ld.param.b32 %f2, [cvt_rn_sf_e2m1x2_f32_param_1];
302+
; CHECK-NEXT: {
303+
; CHECK-NEXT: .reg .b8 %e2m1x2_out;
304+
; CHECK-NEXT: cvt.rn.satfinite.e2m1x2.f32 %e2m1x2_out, %f1, %f2;
305+
; CHECK-NEXT: cvt.u16.u8 %rs1, %e2m1x2_out;
306+
; CHECK-NEXT: }
307+
; CHECK-NEXT: cvt.u32.u16 %r1, %rs1;
308+
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
309+
; CHECK-NEXT: ret;
310+
%val = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float %f1, float %f2)
311+
ret i16 %val
312+
}
313+
314+
define i16 @cvt_rn_relu_sf_e2m1x2_f32(float %f1, float %f2) {
315+
; CHECK-LABEL: cvt_rn_relu_sf_e2m1x2_f32(
316+
; CHECK: {
317+
; CHECK-NEXT: .reg .b16 %rs<2>;
318+
; CHECK-NEXT: .reg .b32 %r<2>;
319+
; CHECK-NEXT: .reg .b32 %f<3>;
320+
; CHECK-EMPTY:
321+
; CHECK-NEXT: // %bb.0:
322+
; CHECK-NEXT: ld.param.b32 %f1, [cvt_rn_relu_sf_e2m1x2_f32_param_0];
323+
; CHECK-NEXT: ld.param.b32 %f2, [cvt_rn_relu_sf_e2m1x2_f32_param_1];
324+
; CHECK-NEXT: {
325+
; CHECK-NEXT: .reg .b8 %e2m1x2_out;
326+
; CHECK-NEXT: cvt.rn.satfinite.relu.e2m1x2.f32 %e2m1x2_out, %f1, %f2;
327+
; CHECK-NEXT: cvt.u16.u8 %rs1, %e2m1x2_out;
328+
; CHECK-NEXT: }
329+
; CHECK-NEXT: cvt.u32.u16 %r1, %rs1;
330+
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
331+
; CHECK-NEXT: ret;
332+
%val = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float %f1, float %f2)
333+
ret i16 %val
334+
}
335+
336+
define <2 x half> @cvt_rn_f16x2_e2m1x2(i16 %in) {
337+
; CHECK-LABEL: cvt_rn_f16x2_e2m1x2(
338+
; CHECK: {
339+
; CHECK-NEXT: .reg .b16 %rs<2>;
340+
; CHECK-NEXT: .reg .b32 %r<2>;
341+
; CHECK-EMPTY:
342+
; CHECK-NEXT: // %bb.0:
343+
; CHECK-NEXT: ld.param.b16 %rs1, [cvt_rn_f16x2_e2m1x2_param_0];
344+
; CHECK-NEXT: {
345+
; CHECK-NEXT: .reg .b8 %e2m1x2_in;
346+
; CHECK-NEXT: cvt.u8.u16 %e2m1x2_in, %rs1;
347+
; CHECK-NEXT: cvt.rn.f16x2.e2m1x2 %r1, %e2m1x2_in;
348+
; CHECK-NEXT: }
349+
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
350+
; CHECK-NEXT: ret;
351+
%val = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 %in)
352+
ret <2 x half> %val
353+
}
354+
355+
define <2 x half> @cvt_rn_relu_f16x2_e2m1x2(i16 %in) {
356+
; CHECK-LABEL: cvt_rn_relu_f16x2_e2m1x2(
357+
; CHECK: {
358+
; CHECK-NEXT: .reg .b16 %rs<2>;
359+
; CHECK-NEXT: .reg .b32 %r<2>;
360+
; CHECK-EMPTY:
361+
; CHECK-NEXT: // %bb.0:
362+
; CHECK-NEXT: ld.param.b16 %rs1, [cvt_rn_relu_f16x2_e2m1x2_param_0];
363+
; CHECK-NEXT: {
364+
; CHECK-NEXT: .reg .b8 %e2m1x2_in;
365+
; CHECK-NEXT: cvt.u8.u16 %e2m1x2_in, %rs1;
366+
; CHECK-NEXT: cvt.rn.relu.f16x2.e2m1x2 %r1, %e2m1x2_in;
367+
; CHECK-NEXT: }
368+
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
369+
; CHECK-NEXT: ret;
370+
%val = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 %in)
371+
ret <2 x half> %val
372+
}

0 commit comments

Comments
 (0)