Skip to content

Commit bb2aa1a

Browse files
authored
[MLIR][NVVM] Update support for conversions to f8x2 and f6x2 types (llvm#137781)
This change:
- Adds the `cvt.f32x2.to.f8x2`, `cvt.f16x2.to.f8x2`, and `cvt.bf16x2.to.f8x2` Ops to the NVVM dialect for the conversions to `.e4m3x2`, `.e5m2x2`, and `.ue8m0x2` types.
- Renames the recently added `cvt.to.f6x2` Op to `cvt.f32x2.to.f6x2` for consistency with the other conversion Ops.

For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
1 parent 3f1eafa commit bb2aa1a

File tree

5 files changed

+456
-16
lines changed

5 files changed

+456
-16
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

+150-2
Original file line numberDiff line numberDiff line change
@@ -1079,7 +1079,7 @@ def CVTFP6TypeAttr : EnumAttr<NVVM_Dialect, CVTFP6Type, "cvt_fp6_type"> {
10791079
let assemblyFormat = "`<` $value `>`";
10801080
}
10811081

1082-
def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
1082+
def NVVM_CvtF32x2ToF6x2Op : NVVM_Op<"cvt.f32x2.to.f6x2"> {
10831083
let summary = "Convert a pair of float inputs to f6x2";
10841084
let description = [{
10851085
This Op converts each of the given float inputs to the specified fp6 type.
@@ -1096,6 +1096,7 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
10961096

10971097
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
10981098
}];
1099+
10991100
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
11001101
let arguments = (ins
11011102
CVTFP6TypeAttr:$type,
@@ -1110,7 +1111,7 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
11101111
}];
11111112

11121113
string llvmBuilder = [{
1113-
auto intId = NVVM::CvtToF6x2Op::getIntrinsicID($type, $relu);
1114+
auto intId = NVVM::CvtF32x2ToF6x2Op::getIntrinsicID($type, $relu);
11141115
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
11151116
if(op.getDst().getType().isInteger(16))
11161117
$dst = packedI16;
@@ -1120,6 +1121,153 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
11201121
}];
11211122
}
11221123

1124+
// Destination fp8 formats selectable on the cvt.*.to.f8x2 conversion Ops.
// Each case binds a C++ enumerator name, its integer value and the keyword
// used for it in the textual assembly.
def CVTFP8E4M3 : I32EnumAttrCase<"E4M3", 0, "e4m3">;
1125+
def CVTFP8E5M2 : I32EnumAttrCase<"E5M2", 1, "e5m2">;
1126+
def CVTFP8UE8M0 : I32EnumAttrCase<"UE8M0", 2, "ue8m0">;
1127+
1128+
// The enum itself, generated into the ::mlir::NVVM namespace. The specialized
// attribute class is not generated here (genSpecializedAttr = 0); the
// attribute is declared separately below.
def CVTFP8Type : I32EnumAttr<"CVTFP8Type", "NVVM CVTFP8Type kind",
1129+
[CVTFP8E4M3, CVTFP8E5M2, CVTFP8UE8M0]> {
1130+
let genSpecializedAttr = 0;
1131+
let cppNamespace = "::mlir::NVVM";
1132+
}
1133+
// Dialect attribute wrapping CVTFP8Type; its assembly form is `<` value `>`,
// e.g. `<e4m3>`.
def CVTFP8TypeAttr : EnumAttr<NVVM_Dialect, CVTFP8Type, "cvt_fp8_type"> {
1134+
let assemblyFormat = "`<` $value `>`";
1135+
}
1136+
1137+
// Converts two f32 operands to a packed pair of fp8 values. The legal
// combinations of `type`, `rnd`, `sat` and `relu` are checked by the Op
// verifier and are listed in the description below.
def NVVM_CvtF32x2ToF8x2Op : NVVM_Op<"cvt.f32x2.to.f8x2"> {
  let summary = "Convert a pair of float inputs to f8x2";
  let description = [{
    This Op converts each of the given float inputs to the specified fp8 type.
    The result `dst` is represented as an i16 type or as a vector
    of two i8 types.
    If `dst` is returned as an i16 type, the converted values are packed such
    that the value converted from `a` is stored in the upper 8 bits of `dst`
    and the value converted from `b` is stored in the lower 8 bits of `dst`.
    If `dst` is returned as a vector type, each converted value is stored as an
    i8 element in the vector.
    The `rnd` and `sat` attributes specify the rounding and saturation modes respectively.
    The `relu` attribute, when set, lowers to the '.relu' variant of
    the cvt instruction.

    The verifier only accepts the following combinations:
    - `e4m3` and `e5m2`: `rnd` must be `RN` and `sat` must be `SATFINITE`.
    - `ue8m0`: `rnd` must be `RZ` or `RP`, and `relu` must not be set.

    Note that the default value of `rnd` (`NONE`) is not accepted for any
    destination type, so `rnd` must always be specified explicitly.

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
  }];

  let hasVerifier = 1;
  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
  let arguments = (ins
    CVTFP8TypeAttr:$type,
    F32:$a,
    F32:$b,
    DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
    DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
    DefaultValuedAttr<BoolAttr, "false">:$relu);
  let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";

  let extraClassDeclaration = [{
    static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP8Type to,
                                              NVVM::FPRoundingMode rnd,
                                              NVVM::SaturationMode sat,
                                              bool hasRelu);
  }];

  // The intrinsic always yields an i16; it is bitcast to <2 x i8> only when
  // the Op's result type is the vector form.
  string llvmBuilder = [{
    auto intId = NVVM::CvtF32x2ToF8x2Op::getIntrinsicID($type, $rnd, $sat, $relu);
    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
    if(op.getDst().getType().isInteger(16))
      $dst = packedI16;
    else
      $dst = builder.CreateBitCast(packedI16,
        llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
  }];
}
1183+
1184+
// Converts an f16x2 vector operand to a packed pair of fp8 values. The Op
// verifier rejects the `ue8m0` destination type.
def NVVM_CvtF16x2ToF8x2Op : NVVM_Op<"cvt.f16x2.to.f8x2"> {
  let summary = "Convert an f16x2 input to f8x2";
  let description = [{
    This Op converts the given f16 inputs in an f16x2 vector to the specified
    f8 type.
    The result `dst` is represented as an i16 type or as a vector
    of two i8 types.
    If `dst` is returned as an i16 type, the converted values from `a`
    are packed such that the value converted from the first element of `a`
    is stored in the upper 8 bits of `dst` and the value converted from the
    second element of `a` is stored in the lower 8 bits of `dst`.
    If `dst` is returned as a vector type, each converted value is stored as an
    i8 element in the vector.
    The `relu` attribute, when set, lowers to the '.relu' variant of
    the cvt instruction.

    Only the `e4m3` and `e5m2` destination types are supported; the verifier
    rejects `ue8m0`.

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
  }];

  let hasVerifier = 1;
  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
  let arguments = (ins
    CVTFP8TypeAttr:$type,
    VectorOfLengthAndType<[2], [F16]>:$a,
    DefaultValuedAttr<BoolAttr, "false">:$relu);
  let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)";

  let extraClassDeclaration = [{
    static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP8Type to,
                                              bool hasRelu);
  }];

  // The intrinsic always yields an i16; it is bitcast to <2 x i8> only when
  // the Op's result type is the vector form.
  string llvmBuilder = [{
    auto intId = NVVM::CvtF16x2ToF8x2Op::getIntrinsicID($type, $relu);
    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
    if(op.getDst().getType().isInteger(16))
      $dst = packedI16;
    else
      $dst = builder.CreateBitCast(packedI16,
        llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
  }];
}
1226+
1227+
// Converts a bf16x2 vector operand to a packed pair of fp8 values. The Op
// verifier only accepts the `ue8m0` destination type with RZ or RP rounding.
def NVVM_CvtBF16x2ToF8x2Op : NVVM_Op<"cvt.bf16x2.to.f8x2"> {
  let summary = "Convert a bf16x2 input to f8x2";
  let description = [{
    This Op converts the given bf16 inputs in a bf16x2 vector to the specified
    f8 type.
    The result `dst` is represented as an i16 type or as a vector
    of two i8 types.
    If `dst` is returned as an i16 type, the converted values from `a`
    are packed such that the value converted from the first element of `a`
    is stored in the upper 8 bits of `dst` and the value converted from the
    second element of `a` is stored in the lower 8 bits of `dst`.
    If `dst` is returned as a vector type, each converted value is stored as an
    i8 element in the vector.
    The `rnd` and `sat` attributes specify the rounding and saturation modes
    respectively.

    Only the `ue8m0` destination type is supported, and `rnd` must be `RZ` or
    `RP`; the verifier rejects everything else. Because the default value of
    `rnd` is `NONE`, `rnd` must always be specified explicitly.

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
  }];

  let hasVerifier = 1;
  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
  let arguments = (ins
    CVTFP8TypeAttr:$type,
    VectorOfLengthAndType<[2], [BF16]>:$a,
    DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
    DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat);
  let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)";

  let extraClassDeclaration = [{
    static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode rnd,
                                              NVVM::SaturationMode sat);
  }];

  // The intrinsic always yields an i16; it is bitcast to <2 x i8> only when
  // the Op's result type is the vector form.
  string llvmBuilder = [{
    auto intId = NVVM::CvtBF16x2ToF8x2Op::getIntrinsicID($rnd, $sat);
    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
    if(op.getDst().getType().isInteger(16))
      $dst = packedI16;
    else
      $dst = builder.CreateBitCast(packedI16,
        llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
  }];
}
1270+
11231271
//===----------------------------------------------------------------------===//
11241272
// NVVM MMA Ops
11251273
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

+124-5
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,61 @@ LogicalResult CvtFloatToTF32Op::verify() {
133133
return success();
134134
}
135135

136+
/// Checks that the rounding mode, saturation mode and relu flag form a
/// combination that is accepted for the requested fp8 destination type:
///  - e4m3 / e5m2: rounding must be RN and saturation must be SATFINITE.
///  - ue8m0: rounding must be RZ or RP, and relu must not be set.
LogicalResult CvtF32x2ToF8x2Op::verify() {
  const auto dstType = getType();
  const auto roundingMode = getRnd();

  if (dstType == CVTFP8Type::E4M3 || dstType == CVTFP8Type::E5M2) {
    // The rounding mode is diagnosed before the saturation mode.
    if (roundingMode != NVVM::FPRoundingMode::RN)
      return emitOpError("Only RN rounding mode is supported for conversions "
                         "from f32x2 to .e4m3x2 or .e5m2x2 types");
    if (getSat() != NVVM::SaturationMode::SATFINITE)
      return emitOpError("Only SATFINITE saturation mode is supported for "
                         "conversions from f32x2 to .e4m3x2 or .e5m2x2 types");
    return success();
  }

  if (dstType == CVTFP8Type::UE8M0) {
    const bool roundsTowardZeroOrPositive =
        roundingMode == NVVM::FPRoundingMode::RZ ||
        roundingMode == NVVM::FPRoundingMode::RP;
    if (!roundsTowardZeroOrPositive)
      return emitOpError("Only RZ or RP rounding modes are supported for "
                         "conversions from f32x2 to .ue8m0x2 type");
    if (getRelu())
      return emitOpError("relu not supported for conversions to .ue8m0x2 type");
  }

  return success();
}
167+
168+
/// Accepts every fp8 destination type except ue8m0, which is diagnosed as
/// unsupported for conversions from f16x2.
LogicalResult CvtF16x2ToF8x2Op::verify() {
  const bool isSupportedType = getType() != CVTFP8Type::UE8M0;
  if (isSupportedType)
    return success();

  return emitOpError("Only .e4m3 or .e5m2 types are supported for "
                     "conversions from f16x2 to f8x2.");
}
175+
176+
/// Requires the destination type to be ue8m0 and the rounding mode to be
/// either RZ or RP. The destination type is diagnosed first.
LogicalResult CvtBF16x2ToF8x2Op::verify() {
  if (getType() != CVTFP8Type::UE8M0)
    return emitOpError(
        "Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.");

  switch (getRnd()) {
  case NVVM::FPRoundingMode::RZ:
  case NVVM::FPRoundingMode::RP:
    return success();
  default:
    return emitOpError("Only RZ and RP rounding modes are supported for "
                       "conversions from bf16x2 to f8x2.");
  }
}
190+
136191
LogicalResult BulkStoreOp::verify() {
137192
if (getInitVal() != 0)
138193
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -1290,17 +1345,81 @@ llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
12901345
}
12911346
}
12921347

1293-
#define CVT_TO_F6X2_ID_IMPL(type, has_relu) \
1348+
#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
12941349
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
12951350
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
12961351

1297-
llvm::Intrinsic::ID CvtToF6x2Op::getIntrinsicID(NVVM::CVTFP6Type type,
1298-
bool hasRelu) {
1352+
llvm::Intrinsic::ID CvtF32x2ToF6x2Op::getIntrinsicID(NVVM::CVTFP6Type type,
1353+
bool hasRelu) {
12991354
switch (type) {
13001355
case NVVM::CVTFP6Type::E2M3:
1301-
return CVT_TO_F6X2_ID_IMPL(e2m3x2, hasRelu);
1356+
return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
13021357
case NVVM::CVTFP6Type::E3M2:
1303-
return CVT_TO_F6X2_ID_IMPL(e3m2x2, hasRelu);
1358+
return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
1359+
}
1360+
}
1361+
1362+
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
1363+
has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
1364+
: llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
1365+
1366+
#define GET_F32x2_TO_F8X2_S_ID(type, has_relu) \
1367+
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
1368+
: llvm::Intrinsic::nvvm_ff_to_##type##_rn
1369+
1370+
llvm::Intrinsic::ID CvtF32x2ToF8x2Op::getIntrinsicID(NVVM::CVTFP8Type type,
1371+
NVVM::FPRoundingMode rnd,
1372+
NVVM::SaturationMode sat,
1373+
bool hasRelu) {
1374+
bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
1375+
bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
1376+
bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
1377+
1378+
switch (type) {
1379+
case NVVM::CVTFP8Type::E4M3:
1380+
return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
1381+
case NVVM::CVTFP8Type::E5M2:
1382+
return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
1383+
case NVVM::CVTFP8Type::UE8M0:
1384+
if (hasRoundingModeRZ)
1385+
return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
1386+
else if (hasRoundingModeRP)
1387+
return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
1388+
}
1389+
llvm_unreachable("Invalid conversion in CvtFloatToF8x2Op");
1390+
}
1391+
1392+
#define GET_F16x2_TO_F8X2_ID(type, has_relu) \
1393+
has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
1394+
: llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
1395+
1396+
llvm::Intrinsic::ID CvtF16x2ToF8x2Op::getIntrinsicID(NVVM::CVTFP8Type type,
1397+
bool hasRelu) {
1398+
switch (type) {
1399+
case NVVM::CVTFP8Type::E4M3:
1400+
return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
1401+
case NVVM::CVTFP8Type::E5M2:
1402+
return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
1403+
default:
1404+
llvm_unreachable("Invalid CVTFP8Type for CvtF16x2ToF8x2Op");
1405+
}
1406+
}
1407+
1408+
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
1409+
has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite \
1410+
: llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd
1411+
1412+
llvm::Intrinsic::ID
1413+
CvtBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
1414+
NVVM::SaturationMode sat) {
1415+
bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
1416+
switch (rnd) {
1417+
case NVVM::FPRoundingMode::RZ:
1418+
return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
1419+
case NVVM::FPRoundingMode::RP:
1420+
return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
1421+
default:
1422+
llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
13041423
}
13051424
}
13061425

+8-9
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,21 @@
11
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
22

3-
// CHECK-LABEL: @convert_float_to_fp6x2_packed
4-
llvm.func @convert_float_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
3+
// CHECK-LABEL: @convert_f32x2_to_fp6x2_packed
4+
llvm.func @convert_f32x2_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
55
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
6-
%res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : i16
6+
%res1 = nvvm.cvt.f32x2.to.f6x2 <e2m3> %srcA, %srcB : i16
77
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
8-
%res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : i16
8+
%res2 = nvvm.cvt.f32x2.to.f6x2 <e3m2> %srcA, %srcB : i16
99
llvm.return
1010
}
1111

12-
// CHECK-LABEL: @convert_float_to_fp6x2_vector
13-
llvm.func @convert_float_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
12+
// CHECK-LABEL: @convert_f32x2_to_fp6x2_vector
13+
llvm.func @convert_f32x2_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
1414
//CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
1515
//CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
16-
%res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8>
16+
%res1 = nvvm.cvt.f32x2.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8>
1717
//CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
1818
//CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
19-
%res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
19+
%res2 = nvvm.cvt.f32x2.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
2020
llvm.return
2121
}
22-

0 commit comments

Comments
 (0)