@@ -1079,7 +1079,7 @@ def CVTFP6TypeAttr : EnumAttr<NVVM_Dialect, CVTFP6Type, "cvt_fp6_type"> {
1079
1079
let assemblyFormat = "`<` $value `>`";
1080
1080
}
1081
1081
1082
- def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
1082
+ def NVVM_CvtF32x2ToF6x2Op : NVVM_Op<"cvt.f32x2 .to.f6x2"> {
1083
1083
let summary = "Convert a pair of float inputs to f6x2";
1084
1084
let description = [{
1085
1085
This Op converts each of the given float inputs to the specified fp6 type.
@@ -1096,6 +1096,7 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
1096
1096
1097
1097
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
1098
1098
}];
1099
+
1099
1100
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
1100
1101
let arguments = (ins
1101
1102
CVTFP6TypeAttr:$type,
@@ -1110,7 +1111,7 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
1110
1111
}];
1111
1112
1112
1113
string llvmBuilder = [{
1113
- auto intId = NVVM::CvtToF6x2Op ::getIntrinsicID($type, $relu);
1114
+ auto intId = NVVM::CvtF32x2ToF6x2Op ::getIntrinsicID($type, $relu);
1114
1115
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
1115
1116
if(op.getDst().getType().isInteger(16))
1116
1117
$dst = packedI16;
@@ -1120,6 +1121,153 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
1120
1121
}];
1121
1122
}
1122
1123
1124
+ def CVTFP8E4M3 : I32EnumAttrCase<"E4M3", 0, "e4m3">;
1125
+ def CVTFP8E5M2 : I32EnumAttrCase<"E5M2", 1, "e5m2">;
1126
+ def CVTFP8UE8M0 : I32EnumAttrCase<"UE8M0", 2, "ue8m0">;
1127
+
1128
+ def CVTFP8Type : I32EnumAttr<"CVTFP8Type", "NVVM CVTFP8Type kind",
1129
+ [CVTFP8E4M3, CVTFP8E5M2, CVTFP8UE8M0]> {
1130
+ let genSpecializedAttr = 0;
1131
+ let cppNamespace = "::mlir::NVVM";
1132
+ }
1133
+ def CVTFP8TypeAttr : EnumAttr<NVVM_Dialect, CVTFP8Type, "cvt_fp8_type"> {
1134
+ let assemblyFormat = "`<` $value `>`";
1135
+ }
1136
+
1137
+ def NVVM_CvtF32x2ToF8x2Op : NVVM_Op<"cvt.f32x2.to.f8x2"> {
1138
+ let summary = "Convert a pair of float inputs to f8x2";
1139
+ let description = [{
1140
+ This Op converts each of the given float inputs to the specified fp8 type.
1141
+ The result `dst` is represented as an i16 type or as a vector
1142
+ of two i8 types.
1143
+ If `dst` is returned as an i16 type, the converted values are packed such
1144
+ that the value converted from `a` is stored in the upper 8 bits of `dst`
1145
+ and the value converted from `b` is stored in the lower 8 bits of `dst`.
1146
+ If `dst` is returned as a vector type, each converted value is stored as an
1147
+ i8 element in the vector.
1148
+ The `rnd` and `sat` attributes specify the rounding and saturation modes respectively.
1149
+ The `relu` attribute, when set, lowers to the '.relu' variant of
1150
+ the cvt instruction.
1151
+
1152
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
1153
+ }];
1154
+
1155
+ let hasVerifier = 1;
1156
+ let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
1157
+ let arguments = (ins
1158
+ CVTFP8TypeAttr:$type,
1159
+ F32:$a,
1160
+ F32:$b,
1161
+ DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
1162
+ DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
1163
+ DefaultValuedAttr<BoolAttr, "false">:$relu);
1164
+ let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";
1165
+
1166
+ let extraClassDeclaration = [{
1167
+ static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP8Type to,
1168
+ NVVM::FPRoundingMode rnd,
1169
+ NVVM::SaturationMode sat,
1170
+ bool hasRelu);
1171
+ }];
1172
+
1173
+ string llvmBuilder = [{
1174
+ auto intId = NVVM::CvtF32x2ToF8x2Op::getIntrinsicID($type, $rnd, $sat, $relu);
1175
+ llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
1176
+ if(op.getDst().getType().isInteger(16))
1177
+ $dst = packedI16;
1178
+ else
1179
+ $dst = builder.CreateBitCast(packedI16,
1180
+ llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
1181
+ }];
1182
+ }
1183
+
1184
+ def NVVM_CvtF16x2ToF8x2Op : NVVM_Op<"cvt.f16x2.to.f8x2"> {
1185
+ let summary = "Convert an f16x2 input to f8x2";
1186
+ let description = [{
1187
+ This Op converts the given f16 inputs in an f16x2 vector to the specified
1188
+ f8 type.
1189
+ The result `dst` is represented as an i16 type or as a vector
1190
+ of two i8 types.
1191
+ If `dst` is returned as an i16 type, the converted values from `a`
1192
+ are packed such that the value converted from the first element of `a`
1193
+ is stored in the upper 8 bits of `dst` and the value converted from the
1194
+ second element of `a` is stored in the lower 8 bits of `dst`.
1195
+ If `dst` is returned as a vector type, each converted value is stored as an
1196
+ i8 element in the vector.
1197
+ The `relu` attribute, when set, lowers to the '.relu' variant of
1198
+ the cvt instruction.
1199
+
1200
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
1201
+ }];
1202
+
1203
+ let hasVerifier = 1;
1204
+ let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
1205
+ let arguments = (ins
1206
+ CVTFP8TypeAttr:$type,
1207
+ VectorOfLengthAndType<[2], [F16]>:$a,
1208
+ DefaultValuedAttr<BoolAttr, "false">:$relu);
1209
+ let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)";
1210
+
1211
+ let extraClassDeclaration = [{
1212
+ static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP8Type to,
1213
+ bool hasRelu);
1214
+ }];
1215
+
1216
+ string llvmBuilder = [{
1217
+ auto intId = NVVM::CvtF16x2ToF8x2Op::getIntrinsicID($type, $relu);
1218
+ llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
1219
+ if(op.getDst().getType().isInteger(16))
1220
+ $dst = packedI16;
1221
+ else
1222
+ $dst = builder.CreateBitCast(packedI16,
1223
+ llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
1224
+ }];
1225
+ }
1226
+
1227
+ def NVVM_CvtBF16x2ToF8x2Op : NVVM_Op<"cvt.bf16x2.to.f8x2"> {
1228
+ let summary = "Convert a pair of bf16 inputs to f8x2";
1229
+ let description = [{
1230
+ This Op converts the given bf16 inputs in a bf16x2 vector to the specified
1231
+ f8 type.
1232
+ The result `dst` is represented as an i16 type or as a vector
1233
+ of two i8 types.
1234
+ If `dst` is returned as an i16 type, the converted values from `a`
1235
+ are packed such that the value converted from the first element of `a`
1236
+ is stored in the upper 8 bits of `dst` and the value converted from the
1237
+ second element of `a` is stored in the lower 8 bits of `dst`.
1238
+ If `dst` is returned as a vector type, each converted value is stored as an
1239
+ i8 element in the vector.
1240
+ The `rnd` and `sat` attributes specify the rounding and saturation modes
1241
+ respectively.
1242
+
1243
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
1244
+ }];
1245
+
1246
+ let hasVerifier = 1;
1247
+ let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
1248
+ let arguments = (ins
1249
+ CVTFP8TypeAttr:$type,
1250
+ VectorOfLengthAndType<[2], [BF16]>:$a,
1251
+ DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
1252
+ DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat);
1253
+ let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)";
1254
+
1255
+ let extraClassDeclaration = [{
1256
+ static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode rnd,
1257
+ NVVM::SaturationMode sat);
1258
+ }];
1259
+
1260
+ string llvmBuilder = [{
1261
+ auto intId = NVVM::CvtBF16x2ToF8x2Op::getIntrinsicID($rnd, $sat);
1262
+ llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
1263
+ if(op.getDst().getType().isInteger(16))
1264
+ $dst = packedI16;
1265
+ else
1266
+ $dst = builder.CreateBitCast(packedI16,
1267
+ llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
1268
+ }];
1269
+ }
1270
+
1123
1271
//===----------------------------------------------------------------------===//
1124
1272
// NVVM MMA Ops
1125
1273
//===----------------------------------------------------------------------===//
0 commit comments