22
22
#include " ../LLVMCommon/MemRefDescriptor.h"
23
23
24
24
#include " llvm/ADT/STLExtras.h"
25
+ #include " llvm/ADT/TypeSwitch.h"
25
26
#include < optional>
26
27
27
28
namespace mlir {
@@ -36,6 +37,7 @@ using namespace mlir::amdgpu;
36
37
constexpr Chipset kGfx908 = Chipset(9 , 0 , 8 );
37
38
constexpr Chipset kGfx90a = Chipset(9 , 0 , 0xa );
38
39
constexpr Chipset kGfx942 = Chipset(9 , 4 , 2 );
40
+ constexpr Chipset kGfx950 = Chipset(9 , 5 , 0 );
39
41
40
42
// / Convert an unsigned number `val` to i32.
41
43
static Value convertUnsignedToI32 (ConversionPatternRewriter &rewriter,
@@ -494,18 +496,33 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
494
496
// / and LLVM AMDGPU intrinsics convention.
495
497
// /
496
498
// / Specifically:
497
- // / 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
498
- // / 2. If the element type is bfloat16, bitcast it to i16.
499
+ // / 1. If the element type is bfloat16, bitcast it to i16.
500
+ // / 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
501
+ // / instead, which is what the f8f6f4 intrinsics use.
502
+ // / 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
503
+ // / integer.
504
+ // /
505
+ // / Note that the type of `input` has already been LLVM type converted:
506
+ // / therefore 8-bit and smaller floats are represented as their corresponding
507
+ // / `iN` integers.
499
508
static Value convertMFMAVectorOperand (ConversionPatternRewriter &rewriter,
500
509
Location loc, Value input) {
501
510
Type inputType = input.getType ();
502
511
if (auto vectorType = dyn_cast<VectorType>(inputType)) {
503
512
if (vectorType.getElementType ().isBF16 ())
504
513
return rewriter.create <LLVM::BitcastOp>(
505
514
loc, vectorType.clone (rewriter.getI16Type ()), input);
506
- if (vectorType.getElementType ().isInteger (8 )) {
515
+ if (vectorType.getElementType ().isInteger (8 ) &&
516
+ vectorType.getNumElements () <= 8 )
507
517
return rewriter.create <LLVM::BitcastOp>(
508
518
loc, rewriter.getIntegerType (vectorType.getNumElements () * 8 ), input);
519
+ if (isa<IntegerType>(vectorType.getElementType ()) &&
520
+ vectorType.getElementTypeBitWidth () <= 8 ) {
521
+ int64_t numWords = llvm::divideCeil (
522
+ vectorType.getNumElements () * vectorType.getElementTypeBitWidth (),
523
+ 32 );
524
+ return rewriter.create <LLVM::BitcastOp>(
525
+ loc, VectorType::get (numWords, rewriter.getI32Type ()), input);
509
526
}
510
527
}
511
528
return input;
@@ -622,12 +639,8 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
622
639
Chipset chipset) {
623
640
uint32_t m = mfma.getM (), n = mfma.getN (), k = mfma.getK (),
624
641
b = mfma.getBlocks ();
625
- Type sourceElem = mfma.getSourceA ().getType ();
626
- if (auto sourceType = dyn_cast<VectorType>(sourceElem))
627
- sourceElem = sourceType.getElementType ();
628
- Type destElem = mfma.getDestC ().getType ();
629
- if (auto destType = dyn_cast<VectorType>(destElem))
630
- destElem = destType.getElementType ();
642
+ Type sourceElem = getElementTypeOrSelf (mfma.getSourceA ().getType ());
643
+ Type destElem = getElementTypeOrSelf (mfma.getDestC ().getType ());
631
644
632
645
if (sourceElem.isF32 () && destElem.isF32 ()) {
633
646
if (mfma.getReducePrecision () && chipset >= kGfx942 ) {
@@ -649,6 +662,12 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
649
662
}
650
663
651
664
if (sourceElem.isF16 () && destElem.isF32 ()) {
665
+ if (chipset >= kGfx950 ) {
666
+ if (m == 32 && n == 32 && k == 16 && b == 1 )
667
+ return ROCDL::mfma_f32_32x32x16_f16::getOperationName ();
668
+ if (m == 16 && n == 16 && k == 32 && b == 1 )
669
+ return ROCDL::mfma_f32_16x16x32_f16::getOperationName ();
670
+ }
652
671
if (m == 32 && n == 32 && k == 4 && b == 2 )
653
672
return ROCDL::mfma_f32_32x32x4f16::getOperationName ();
654
673
if (m == 16 && n == 16 && k == 4 && b == 4 )
@@ -661,20 +680,25 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
661
680
return ROCDL::mfma_f32_16x16x16f16::getOperationName ();
662
681
}
663
682
664
- if (sourceElem.isBF16 () && destElem.isF32 () && chipset >= kGfx90a ) {
665
- if (m == 32 && n == 32 && k == 4 && b == 2 )
666
- return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName ();
667
- if (m == 16 && n == 16 && k == 4 && b == 4 )
668
- return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName ();
669
- if (m == 4 && n == 4 && k == 4 && b == 16 )
670
- return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName ();
671
- if (m == 32 && n == 32 && k == 8 && b == 1 )
672
- return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName ();
673
- if (m == 16 && n == 16 && k == 16 && b == 1 )
674
- return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName ();
675
- }
676
-
677
683
if (sourceElem.isBF16 () && destElem.isF32 ()) {
684
+ if (chipset >= kGfx950 ) {
685
+ if (m == 32 && n == 32 && k == 16 && b == 1 )
686
+ return ROCDL::mfma_f32_32x32x16_bf16::getOperationName ();
687
+ if (m == 16 && n == 16 && k == 32 && b == 1 )
688
+ return ROCDL::mfma_f32_16x16x32_bf16::getOperationName ();
689
+ }
690
+ if (chipset >= kGfx90a ) {
691
+ if (m == 32 && n == 32 && k == 4 && b == 2 )
692
+ return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName ();
693
+ if (m == 16 && n == 16 && k == 4 && b == 4 )
694
+ return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName ();
695
+ if (m == 4 && n == 4 && k == 4 && b == 16 )
696
+ return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName ();
697
+ if (m == 32 && n == 32 && k == 8 && b == 1 )
698
+ return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName ();
699
+ if (m == 16 && n == 16 && k == 16 && b == 1 )
700
+ return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName ();
701
+ }
678
702
if (m == 32 && n == 32 && k == 2 && b == 2 )
679
703
return ROCDL::mfma_f32_32x32x2bf16::getOperationName ();
680
704
if (m == 16 && n == 16 && k == 2 && b == 4 )
@@ -687,7 +711,13 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
687
711
return ROCDL::mfma_f32_16x16x8bf16::getOperationName ();
688
712
}
689
713
690
- if (isa<IntegerType>(sourceElem) && destElem.isInteger (32 )) {
714
+ if (sourceElem.isInteger (8 ) && destElem.isInteger (32 )) {
715
+ if (chipset >= kGfx950 ) {
716
+ if (m == 32 && n == 32 && k == 32 && b == 1 )
717
+ return ROCDL::mfma_i32_32x32x32_i8::getOperationName ();
718
+ if (m == 16 && n == 16 && k == 64 && b == 1 )
719
+ return ROCDL::mfma_i32_16x16x64_i8::getOperationName ();
720
+ }
691
721
if (m == 32 && n == 32 && k == 4 && b == 2 )
692
722
return ROCDL::mfma_i32_32x32x4i8::getOperationName ();
693
723
if (m == 16 && n == 16 && k == 4 && b == 4 )
@@ -750,6 +780,59 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
750
780
return std::nullopt;
751
781
}
752
782
783
+ static std::optional<uint32_t > mfmaTypeSelectCode (Type mlirElemType) {
784
+ return llvm::TypeSwitch<Type, std::optional<uint32_t >>(mlirElemType)
785
+ .Case ([](Float8E4M3FNType) { return 0u ; })
786
+ .Case ([](Float8E5M2Type) { return 1u ; })
787
+ .Case ([](Float6E2M3FNType) { return 2u ; })
788
+ .Case ([](Float6E3M2FNType) { return 3u ; })
789
+ .Case ([](Float4E2M1FNType) { return 4u ; })
790
+ .Default ([](Type) { return std::nullopt; });
791
+ }
792
+
793
+ // / If there is a scaled MFMA instruction for the input element types `aType`
794
+ // / and `bType`, output type `destType`, problem size M, N, K, and B (number of
795
+ // / blocks) on the given `chipset`, return a tuple consisting of the
796
+ // / OperationName of the intrinsic and the type codes that need to be passed to
797
+ // / that intrinsic. Note that this is also used to implement some un-scaled
798
+ // / MFMAs, since the compiler represents the ordinary instruction as a "scaled"
799
+ // / MFMA with a scale of 0.
800
+ static std::optional<std::tuple<StringRef, uint32_t , uint32_t >>
801
+ mfmaOpToScaledIntrinsic (Type aType, Type bType, Type destType, uint32_t m,
802
+ uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
803
+ aType = getElementTypeOrSelf (aType);
804
+ bType = getElementTypeOrSelf (bType);
805
+ destType = getElementTypeOrSelf (destType);
806
+
807
+ if (chipset < kGfx950 )
808
+ return std::nullopt;
809
+ if (!isa<Float32Type>(destType))
810
+ return std::nullopt;
811
+
812
+ std::optional<uint32_t > aTypeCode = mfmaTypeSelectCode (aType);
813
+ std::optional<uint32_t > bTypeCode = mfmaTypeSelectCode (bType);
814
+ if (!aTypeCode || !bTypeCode)
815
+ return std::nullopt;
816
+
817
+ if (m == 32 && n == 32 && k == 64 && b == 1 )
818
+ return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName (),
819
+ *aTypeCode, *bTypeCode};
820
+ if (m == 16 && n == 16 && k == 128 && b == 1 )
821
+ return std::tuple{
822
+ ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName (), *aTypeCode,
823
+ *bTypeCode};
824
+
825
+ return std::nullopt;
826
+ }
827
+
828
+ static std::optional<std::tuple<StringRef, uint32_t , uint32_t >>
829
+ mfmaOpToScaledIntrinsic (MFMAOp mfma, Chipset chipset) {
830
+ return mfmaOpToScaledIntrinsic (
831
+ mfma.getSourceA ().getType (), mfma.getSourceB ().getType (),
832
+ mfma.getDestC ().getType (), mfma.getM (), mfma.getN (), mfma.getK (),
833
+ mfma.getBlocks (), chipset);
834
+ }
835
+
753
836
// / Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
754
837
// / if one exists. This includes checking to ensure the intrinsic is supported
755
838
// / on the architecture you are compiling for.
@@ -829,16 +912,40 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
829
912
op.getNegateA () | (op.getNegateB () << 1 ) | (op.getNegateC () << 2 );
830
913
}
831
914
std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic (op, chipset);
832
- if (!maybeIntrinsic.has_value ())
915
+ std::optional<std::tuple<StringRef, uint32_t , uint32_t >>
916
+ maybeScaledIntrinsic = mfmaOpToScaledIntrinsic (op, chipset);
917
+ if (!maybeIntrinsic.has_value () && !maybeScaledIntrinsic.has_value ())
833
918
return op.emitOpError (" no intrinsic matching MFMA size on given chipset" );
834
- OperationState loweredOp (loc, *maybeIntrinsic);
919
+
920
+ bool isScaled =
921
+ !maybeIntrinsic.has_value () && maybeScaledIntrinsic.has_value ();
922
+ if (isScaled &&
923
+ (adaptor.getAbid () > 0 || getBlgpField > 0 || op.getCbsz () > 0 )) {
924
+ return op.emitOpError (
925
+ " non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
926
+ " be scaled as those fields are used for type information" );
927
+ }
928
+
929
+ StringRef intrinsicName =
930
+ isScaled ? std::get<0 >(*maybeScaledIntrinsic) : *maybeIntrinsic;
931
+ OperationState loweredOp (loc, intrinsicName);
835
932
loweredOp.addTypes (intrinsicOutType);
836
933
loweredOp.addOperands (
837
934
{convertMFMAVectorOperand (rewriter, loc, adaptor.getSourceA ()),
838
935
convertMFMAVectorOperand (rewriter, loc, adaptor.getSourceB ()),
839
- adaptor.getDestC (), createI32Constant (rewriter, loc, op.getCbsz ()),
840
- createI32Constant (rewriter, loc, op.getAbid ()),
841
- createI32Constant (rewriter, loc, getBlgpField)});
936
+ adaptor.getDestC ()});
937
+ if (isScaled) {
938
+ Value zero = createI32Constant (rewriter, loc, 0 );
939
+ auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
940
+ loweredOp.addOperands ({createI32Constant (rewriter, loc, aTypeCode),
941
+ createI32Constant (rewriter, loc, bTypeCode),
942
+ /* scale A byte=*/ zero, /* scale A=*/ zero,
943
+ /* scale B byte=*/ zero, /* scale B=*/ zero});
944
+ } else {
945
+ loweredOp.addOperands ({createI32Constant (rewriter, loc, op.getCbsz ()),
946
+ createI32Constant (rewriter, loc, op.getAbid ()),
947
+ createI32Constant (rewriter, loc, getBlgpField)});
948
+ };
842
949
Value lowered = rewriter.create (loweredOp)->getResult (0 );
843
950
if (outType != intrinsicOutType)
844
951
lowered = rewriter.create <LLVM::BitcastOp>(loc, outType, lowered);
0 commit comments