Skip to content

Commit 25622aa

Browse files
krzysz00kuhar
andauthored
[mlir][AMDGPU] Add gfx950 MFMAs to the amdgpu.mfma op (llvm#133553)
This commit extends the lowering of amdgpu.mfma to handle the new double-rate MFMAs in gfx950 and adds tests for these operations. It also adds support for MFMAs on small floats (f6 and f4), which are implented using the "scaled" MFMA intrinsic with a scale value of 0 in order to have an unscaled MFMA. This commit does not add a `amdgpu.scaled_mfma` operation, as that is future work. --------- Co-authored-by: Jakub Kuderski <[email protected]>
1 parent 69c5049 commit 25622aa

File tree

5 files changed

+204
-40
lines changed

5 files changed

+204
-40
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -650,10 +650,12 @@ def AMDGPU_MFMAPermBAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_MFMAPermB,
650650
// mfma
651651
def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64,
652652
VectorOfLengthAndType<[2], [F32]>,
653-
VectorOfLengthAndType<[4], [F16]>,
654-
VectorOfLengthAndType<[2, 4], [BF16]>,
655-
VectorOfLengthAndType<[4, 8], [I8]>,
656-
VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>;
653+
VectorOfLengthAndType<[4, 8], [F16]>,
654+
VectorOfLengthAndType<[2, 4, 8], [BF16]>,
655+
VectorOfLengthAndType<[4, 8, 16], [I8]>,
656+
VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>,
657+
VectorOfLengthAndType<[8, 32], [F8E5M2, F8E4M3FN]>,
658+
VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
657659
def MFMAOutTypes : AnyTypeOf<[F64,
658660
VectorOfLengthAndType<[4, 16, 32], [F32]>,
659661
VectorOfLengthAndType<[4, 16, 32], [I32]>,

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 135 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "../LLVMCommon/MemRefDescriptor.h"
2323

2424
#include "llvm/ADT/STLExtras.h"
25+
#include "llvm/ADT/TypeSwitch.h"
2526
#include <optional>
2627

2728
namespace mlir {
@@ -36,6 +37,7 @@ using namespace mlir::amdgpu;
3637
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
3738
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
3839
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
40+
constexpr Chipset kGfx950 = Chipset(9, 5, 0);
3941

4042
/// Convert an unsigned number `val` to i32.
4143
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
@@ -494,18 +496,33 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
494496
/// and LLVM AMDGPU intrinsics convention.
495497
///
496498
/// Specifically:
497-
/// 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
498-
/// 2. If the element type is bfloat16, bitcast it to i16.
499+
/// 1. If the element type is bfloat16, bitcast it to i16.
500+
/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
501+
/// instead, which is what the f8f6f4 intrinsics use.
502+
/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
503+
/// integer.
504+
///
505+
/// Note that the type of `input` has already been LLVM type converted:
506+
/// therefore 8-bit and smaller floats are represented as their corresponding
507+
/// `iN` integers.
499508
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
500509
Location loc, Value input) {
501510
Type inputType = input.getType();
502511
if (auto vectorType = dyn_cast<VectorType>(inputType)) {
503512
if (vectorType.getElementType().isBF16())
504513
return rewriter.create<LLVM::BitcastOp>(
505514
loc, vectorType.clone(rewriter.getI16Type()), input);
506-
if (vectorType.getElementType().isInteger(8)) {
515+
if (vectorType.getElementType().isInteger(8) &&
516+
vectorType.getNumElements() <= 8)
507517
return rewriter.create<LLVM::BitcastOp>(
508518
loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
519+
if (isa<IntegerType>(vectorType.getElementType()) &&
520+
vectorType.getElementTypeBitWidth() <= 8) {
521+
int64_t numWords = llvm::divideCeil(
522+
vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
523+
32);
524+
return rewriter.create<LLVM::BitcastOp>(
525+
loc, VectorType::get(numWords, rewriter.getI32Type()), input);
509526
}
510527
}
511528
return input;
@@ -622,12 +639,8 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
622639
Chipset chipset) {
623640
uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
624641
b = mfma.getBlocks();
625-
Type sourceElem = mfma.getSourceA().getType();
626-
if (auto sourceType = dyn_cast<VectorType>(sourceElem))
627-
sourceElem = sourceType.getElementType();
628-
Type destElem = mfma.getDestC().getType();
629-
if (auto destType = dyn_cast<VectorType>(destElem))
630-
destElem = destType.getElementType();
642+
Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
643+
Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());
631644

632645
if (sourceElem.isF32() && destElem.isF32()) {
633646
if (mfma.getReducePrecision() && chipset >= kGfx942) {
@@ -649,6 +662,12 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
649662
}
650663

651664
if (sourceElem.isF16() && destElem.isF32()) {
665+
if (chipset >= kGfx950) {
666+
if (m == 32 && n == 32 && k == 16 && b == 1)
667+
return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
668+
if (m == 16 && n == 16 && k == 32 && b == 1)
669+
return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
670+
}
652671
if (m == 32 && n == 32 && k == 4 && b == 2)
653672
return ROCDL::mfma_f32_32x32x4f16::getOperationName();
654673
if (m == 16 && n == 16 && k == 4 && b == 4)
@@ -661,20 +680,25 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
661680
return ROCDL::mfma_f32_16x16x16f16::getOperationName();
662681
}
663682

664-
if (sourceElem.isBF16() && destElem.isF32() && chipset >= kGfx90a) {
665-
if (m == 32 && n == 32 && k == 4 && b == 2)
666-
return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
667-
if (m == 16 && n == 16 && k == 4 && b == 4)
668-
return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
669-
if (m == 4 && n == 4 && k == 4 && b == 16)
670-
return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
671-
if (m == 32 && n == 32 && k == 8 && b == 1)
672-
return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
673-
if (m == 16 && n == 16 && k == 16 && b == 1)
674-
return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
675-
}
676-
677683
if (sourceElem.isBF16() && destElem.isF32()) {
684+
if (chipset >= kGfx950) {
685+
if (m == 32 && n == 32 && k == 16 && b == 1)
686+
return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
687+
if (m == 16 && n == 16 && k == 32 && b == 1)
688+
return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
689+
}
690+
if (chipset >= kGfx90a) {
691+
if (m == 32 && n == 32 && k == 4 && b == 2)
692+
return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
693+
if (m == 16 && n == 16 && k == 4 && b == 4)
694+
return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
695+
if (m == 4 && n == 4 && k == 4 && b == 16)
696+
return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
697+
if (m == 32 && n == 32 && k == 8 && b == 1)
698+
return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
699+
if (m == 16 && n == 16 && k == 16 && b == 1)
700+
return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
701+
}
678702
if (m == 32 && n == 32 && k == 2 && b == 2)
679703
return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
680704
if (m == 16 && n == 16 && k == 2 && b == 4)
@@ -687,7 +711,13 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
687711
return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
688712
}
689713

690-
if (isa<IntegerType>(sourceElem) && destElem.isInteger(32)) {
714+
if (sourceElem.isInteger(8) && destElem.isInteger(32)) {
715+
if (chipset >= kGfx950) {
716+
if (m == 32 && n == 32 && k == 32 && b == 1)
717+
return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
718+
if (m == 16 && n == 16 && k == 64 && b == 1)
719+
return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
720+
}
691721
if (m == 32 && n == 32 && k == 4 && b == 2)
692722
return ROCDL::mfma_i32_32x32x4i8::getOperationName();
693723
if (m == 16 && n == 16 && k == 4 && b == 4)
@@ -750,6 +780,59 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
750780
return std::nullopt;
751781
}
752782

783+
static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
784+
return llvm::TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType)
785+
.Case([](Float8E4M3FNType) { return 0u; })
786+
.Case([](Float8E5M2Type) { return 1u; })
787+
.Case([](Float6E2M3FNType) { return 2u; })
788+
.Case([](Float6E3M2FNType) { return 3u; })
789+
.Case([](Float4E2M1FNType) { return 4u; })
790+
.Default([](Type) { return std::nullopt; });
791+
}
792+
793+
/// If there is a scaled MFMA instruction for the input element types `aType`
794+
/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
795+
/// blocks) on the given `chipset`, return a tuple consisting of the
796+
/// OperationName of the intrinsic and the type codes that need to be passed to
797+
/// that intrinsic. Note that this is also used to implement some un-scaled
798+
/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
799+
/// MFMA with a scale of 0.
800+
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
801+
mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
802+
uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
803+
aType = getElementTypeOrSelf(aType);
804+
bType = getElementTypeOrSelf(bType);
805+
destType = getElementTypeOrSelf(destType);
806+
807+
if (chipset < kGfx950)
808+
return std::nullopt;
809+
if (!isa<Float32Type>(destType))
810+
return std::nullopt;
811+
812+
std::optional<uint32_t> aTypeCode = mfmaTypeSelectCode(aType);
813+
std::optional<uint32_t> bTypeCode = mfmaTypeSelectCode(bType);
814+
if (!aTypeCode || !bTypeCode)
815+
return std::nullopt;
816+
817+
if (m == 32 && n == 32 && k == 64 && b == 1)
818+
return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
819+
*aTypeCode, *bTypeCode};
820+
if (m == 16 && n == 16 && k == 128 && b == 1)
821+
return std::tuple{
822+
ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
823+
*bTypeCode};
824+
825+
return std::nullopt;
826+
}
827+
828+
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
829+
mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
830+
return mfmaOpToScaledIntrinsic(
831+
mfma.getSourceA().getType(), mfma.getSourceB().getType(),
832+
mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
833+
mfma.getBlocks(), chipset);
834+
}
835+
753836
/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
754837
/// if one exists. This includes checking to ensure the intrinsic is supported
755838
/// on the architecture you are compiling for.
@@ -829,16 +912,40 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
829912
op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
830913
}
831914
std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
832-
if (!maybeIntrinsic.has_value())
915+
std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
916+
maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
917+
if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
833918
return op.emitOpError("no intrinsic matching MFMA size on given chipset");
834-
OperationState loweredOp(loc, *maybeIntrinsic);
919+
920+
bool isScaled =
921+
!maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
922+
if (isScaled &&
923+
(adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
924+
return op.emitOpError(
925+
"non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
926+
"be scaled as those fields are used for type information");
927+
}
928+
929+
StringRef intrinsicName =
930+
isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
931+
OperationState loweredOp(loc, intrinsicName);
835932
loweredOp.addTypes(intrinsicOutType);
836933
loweredOp.addOperands(
837934
{convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
838935
convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
839-
adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
840-
createI32Constant(rewriter, loc, op.getAbid()),
841-
createI32Constant(rewriter, loc, getBlgpField)});
936+
adaptor.getDestC()});
937+
if (isScaled) {
938+
Value zero = createI32Constant(rewriter, loc, 0);
939+
auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
940+
loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
941+
createI32Constant(rewriter, loc, bTypeCode),
942+
/*scale A byte=*/zero, /*scale A=*/zero,
943+
/*scale B byte=*/zero, /*scale B=*/zero});
944+
} else {
945+
loweredOp.addOperands({createI32Constant(rewriter, loc, op.getCbsz()),
946+
createI32Constant(rewriter, loc, op.getAbid()),
947+
createI32Constant(rewriter, loc, getBlgpField)});
948+
};
842949
Value lowered = rewriter.create(loweredOp)->getResult(0);
843950
if (outType != intrinsicOutType)
844951
lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -341,22 +341,24 @@ LogicalResult MFMAOp::verify() {
341341
}
342342

343343
Type sourceBType = getSourceB().getType();
344-
if (sourceElem.isFloat(8)) {
344+
if (sourceElem.isFloat(8) || sourceElem.isFloat(6) || sourceElem.isFloat(4)) {
345345
int64_t sourceBLen = 1;
346346
Type sourceBElem = sourceBType;
347347
if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
348348
sourceBLen = sourceBVector.getNumElements();
349349
sourceBElem = sourceBVector.getElementType();
350350
}
351-
if (!sourceBElem.isFloat(8))
352-
return emitOpError("expected both source operands to have f8 elements");
351+
if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
352+
!sourceBElem.isFloat(4))
353+
return emitOpError("expected both source operands to have small-float "
354+
"elements if one does");
353355
if (sourceLen != sourceBLen)
354356
return emitOpError(
355-
"expected both f8 source vectors to have the same length");
357+
"expected both small-float source vectors to have the same length");
356358
} else {
357359
if (sourceType != sourceBType)
358-
return emitOpError(
359-
"expected both non-f8 source operand types to match exactly");
360+
return emitOpError("expected both non-small-float source operand types "
361+
"to match exactly");
360362
}
361363
// Normalize the wider integer types the compiler expects to i8
362364
if (sourceElem.isInteger(32)) {

0 commit comments

Comments
 (0)