Skip to content

Commit c72a751

Browse files
authored
[X86][AMX] Support AMX-TRANSPOSE (llvm#113532)
Ref.: https://cdrdv2.intel.com/v1/dl/getContent/671368
1 parent 1e19f0f commit c72a751

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+2768
-139
lines changed

clang/docs/ReleaseNotes.rst

+1
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,7 @@ X86 Support
676676
- Supported intrinsics for ``MOVRS AND AVX10.2``.
677677
* Supported intrinsics of ``_mm(256|512)_(mask(z))_loadrs_epi(8|16|32|64)``.
678678
- Support ISA of ``AMX-FP8``.
679+
- Support ISA of ``AMX-TRANSPOSE``.
679680

680681
Arm and AArch64 Support
681682
^^^^^^^^^^^^^^^^^^^^^^^

clang/include/clang/Basic/BuiltinsX86_64.def

+11
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,11 @@ TARGET_BUILTIN(__builtin_ia32_tdpbf16ps_internal, "V256iUsUsUsV256iV256iV256i",
128128
TARGET_BUILTIN(__builtin_ia32_tdpfp16ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-fp16")
129129
TARGET_BUILTIN(__builtin_ia32_tcmmimfp16ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-complex")
130130
TARGET_BUILTIN(__builtin_ia32_tcmmrlfp16ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-complex")
131+
TARGET_BUILTIN(__builtin_ia32_t2rpntlvwz0_internal, "vUsUsUsV256i*V256i*vC*z", "n", "amx-transpose")
132+
TARGET_BUILTIN(__builtin_ia32_t2rpntlvwz0t1_internal, "vUsUsUsV256i*V256i*vC*z", "n", "amx-transpose")
133+
TARGET_BUILTIN(__builtin_ia32_t2rpntlvwz1_internal, "vUsUsUsV256i*V256i*vC*z", "n", "amx-transpose")
134+
TARGET_BUILTIN(__builtin_ia32_t2rpntlvwz1t1_internal, "vUsUsUsV256i*V256i*vC*z", "n", "amx-transpose")
135+
TARGET_BUILTIN(__builtin_ia32_ttransposed_internal, "V256iUsUsV256i", "n", "amx-transpose")
131136
// AMX
132137
TARGET_BUILTIN(__builtin_ia32_tile_loadconfig, "vvC*", "n", "amx-tile")
133138
TARGET_BUILTIN(__builtin_ia32_tile_storeconfig, "vvC*", "n", "amx-tile")
@@ -148,6 +153,12 @@ TARGET_BUILTIN(__builtin_ia32_ptwrite64, "vUOi", "n", "ptwrite")
148153
TARGET_BUILTIN(__builtin_ia32_tcmmimfp16ps, "vIUcIUcIUc", "n", "amx-complex")
149154
TARGET_BUILTIN(__builtin_ia32_tcmmrlfp16ps, "vIUcIUcIUc", "n", "amx-complex")
150155

156+
TARGET_BUILTIN(__builtin_ia32_t2rpntlvwz0, "vIUcvC*z", "n", "amx-transpose")
157+
TARGET_BUILTIN(__builtin_ia32_t2rpntlvwz0t1, "vIUcvC*z", "n","amx-transpose")
158+
TARGET_BUILTIN(__builtin_ia32_t2rpntlvwz1, "vIUcvC*z", "n", "amx-transpose")
159+
TARGET_BUILTIN(__builtin_ia32_t2rpntlvwz1t1, "vIUcvC*z", "n","amx-transpose")
160+
TARGET_BUILTIN(__builtin_ia32_ttransposed, "vIUcIUc", "n", "amx-transpose")
161+
151162
TARGET_BUILTIN(__builtin_ia32_prefetchi, "vvC*Ui", "nc", "prefetchi")
152163
TARGET_BUILTIN(__builtin_ia32_cmpccxadd32, "Siv*SiSiIi", "n", "cmpccxadd")
153164
TARGET_BUILTIN(__builtin_ia32_cmpccxadd64, "SLLiSLLi*SLLiSLLiIi", "n", "cmpccxadd")

clang/include/clang/Driver/Options.td

+2
Original file line numberDiff line numberDiff line change
@@ -6301,6 +6301,8 @@ def mamx_fp8 : Flag<["-"], "mamx-fp8">, Group<m_x86_Features_Group>;
63016301
def mno_amx_fp8 : Flag<["-"], "mno-amx-fp8">, Group<m_x86_Features_Group>;
63026302
def mamx_tile : Flag<["-"], "mamx-tile">, Group<m_x86_Features_Group>;
63036303
def mno_amx_tile : Flag<["-"], "mno-amx-tile">, Group<m_x86_Features_Group>;
6304+
def mamx_transpose : Flag<["-"], "mamx-transpose">, Group<m_x86_Features_Group>;
6305+
def mno_amx_transpose : Flag<["-"], "mno-amx-transpose">, Group<m_x86_Features_Group>;
63046306
def mcmpccxadd : Flag<["-"], "mcmpccxadd">, Group<m_x86_Features_Group>;
63056307
def mno_cmpccxadd : Flag<["-"], "mno-cmpccxadd">, Group<m_x86_Features_Group>;
63066308
def msse : Flag<["-"], "msse">, Group<m_x86_Features_Group>;

clang/lib/Basic/Targets/X86.cpp

+8-2
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,8 @@ bool X86TargetInfo::handleTargetFeatures(std::vector<std::string> &Features,
430430
HasAMXCOMPLEX = true;
431431
} else if (Feature == "+amx-fp8") {
432432
HasAMXFP8 = true;
433+
} else if (Feature == "+amx-transpose") {
434+
HasAMXTRANSPOSE = true;
433435
} else if (Feature == "+cmpccxadd") {
434436
HasCMPCCXADD = true;
435437
} else if (Feature == "+raoint") {
@@ -951,6 +953,8 @@ void X86TargetInfo::getTargetDefines(const LangOptions &Opts,
951953
Builder.defineMacro("__AMX_COMPLEX__");
952954
if (HasAMXFP8)
953955
Builder.defineMacro("__AMX_FP8__");
956+
if (HasAMXTRANSPOSE)
957+
Builder.defineMacro("__AMX_TRANSPOSE__");
954958
if (HasCMPCCXADD)
955959
Builder.defineMacro("__CMPCCXADD__");
956960
if (HasRAOINT)
@@ -1079,9 +1083,10 @@ bool X86TargetInfo::isValidFeatureName(StringRef Name) const {
10791083
.Case("amx-bf16", true)
10801084
.Case("amx-complex", true)
10811085
.Case("amx-fp16", true)
1086+
.Case("amx-fp8", true)
10821087
.Case("amx-int8", true)
10831088
.Case("amx-tile", true)
1084-
.Case("amx-fp8", true)
1089+
.Case("amx-transpose", true)
10851090
.Case("avx", true)
10861091
.Case("avx10.1-256", true)
10871092
.Case("avx10.1-512", true)
@@ -1198,9 +1203,10 @@ bool X86TargetInfo::hasFeature(StringRef Feature) const {
11981203
.Case("amx-bf16", HasAMXBF16)
11991204
.Case("amx-complex", HasAMXCOMPLEX)
12001205
.Case("amx-fp16", HasAMXFP16)
1206+
.Case("amx-fp8", HasAMXFP8)
12011207
.Case("amx-int8", HasAMXINT8)
12021208
.Case("amx-tile", HasAMXTILE)
1203-
.Case("amx-fp8", HasAMXFP8)
1209+
.Case("amx-transpose", HasAMXTRANSPOSE)
12041210
.Case("avx", SSELevel >= AVX)
12051211
.Case("avx10.1-256", HasAVX10_1)
12061212
.Case("avx10.1-512", HasAVX10_1_512)

clang/lib/Basic/Targets/X86.h

+1
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ class LLVM_LIBRARY_VISIBILITY X86TargetInfo : public TargetInfo {
158158
bool HasAMXBF16 = false;
159159
bool HasAMXCOMPLEX = false;
160160
bool HasAMXFP8 = false;
161+
bool HasAMXTRANSPOSE = false;
161162
bool HasSERIALIZE = false;
162163
bool HasTSXLDTRK = false;
163164
bool HasUSERMSR = false;

clang/lib/CodeGen/CGBuiltin.cpp

+52
Original file line numberDiff line numberDiff line change
@@ -16994,6 +16994,58 @@ Value *CodeGenFunction::EmitX86BuiltinExpr(unsigned BuiltinID,
1699416994
// instruction, but it will create a memset that won't be optimized away.
1699516995
return Builder.CreateMemSet(Ops[0], Ops[1], Ops[2], Align(1), true);
1699616996
}
16997+
// Corresponds to the intrinsics that return 2 tiles (tile0_tile1).
16998+
case X86::BI__builtin_ia32_t2rpntlvwz0_internal:
16999+
case X86::BI__builtin_ia32_t2rpntlvwz0t1_internal:
17000+
case X86::BI__builtin_ia32_t2rpntlvwz1_internal:
17001+
case X86::BI__builtin_ia32_t2rpntlvwz1t1_internal: {
17002+
Intrinsic::ID IID;
17003+
switch (BuiltinID) {
17004+
default:
17005+
llvm_unreachable("Unsupported intrinsic!");
17006+
case X86::BI__builtin_ia32_t2rpntlvwz0_internal:
17007+
IID = Intrinsic::x86_t2rpntlvwz0_internal;
17008+
break;
17009+
case X86::BI__builtin_ia32_t2rpntlvwz0t1_internal:
17010+
IID = Intrinsic::x86_t2rpntlvwz0t1_internal;
17011+
break;
17012+
case X86::BI__builtin_ia32_t2rpntlvwz1_internal:
17013+
IID = Intrinsic::x86_t2rpntlvwz1_internal;
17014+
break;
17015+
case X86::BI__builtin_ia32_t2rpntlvwz1t1_internal:
17016+
IID = Intrinsic::x86_t2rpntlvwz1t1_internal;
17017+
break;
17018+
}
17019+
17020+
// Ops = (Row0, Col0, Col1, DstPtr0, DstPtr1, SrcPtr, Stride)
17021+
Value *Call = Builder.CreateCall(CGM.getIntrinsic(IID),
17022+
{Ops[0], Ops[1], Ops[2], Ops[5], Ops[6]});
17023+
17024+
auto *PtrTy = E->getArg(3)->getType()->getAs<PointerType>();
17025+
assert(PtrTy && "arg3 must be of pointer type");
17026+
QualType PtreeTy = PtrTy->getPointeeType();
17027+
llvm::Type *TyPtee = ConvertType(PtreeTy);
17028+
17029+
// Bitcast amx type (x86_amx) to vector type (256 x i32)
17030+
// Then store tile0 into DstPtr0
17031+
Value *T0 = Builder.CreateExtractValue(Call, 0);
17032+
Value *VecT0 = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
17033+
{TyPtee}, {T0});
17034+
Builder.CreateDefaultAlignedStore(VecT0, Ops[3]);
17035+
17036+
// Then store tile1 into DstPtr1
17037+
Value *T1 = Builder.CreateExtractValue(Call, 1);
17038+
Value *VecT1 = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
17039+
{TyPtee}, {T1});
17040+
Value *Store = Builder.CreateDefaultAlignedStore(VecT1, Ops[4]);
17041+
17042+
// Note: We deliberately avoid using x86_tilestored64_internal to store the
17043+
// results here, because it cannot guarantee the scope of the memory written.
17044+
// That may cause shape reloads after the first AMX intrinsic, which the
17045+
// current AMX register allocation is not able to handle.
17046+
17047+
return Store;
17048+
}
1699717049
case X86::BI__ud2:
1699817050
// llvm.trap makes a ud2a instruction on x86.
1699917051
return EmitTrapCall(Intrinsic::trap);

clang/lib/Headers/CMakeLists.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,9 @@ set(x86_files
148148
ammintrin.h
149149
amxcomplexintrin.h
150150
amxfp16intrin.h
151-
amxintrin.h
152151
amxfp8intrin.h
152+
amxintrin.h
153+
amxtransposeintrin.h
153154
avx10_2_512bf16intrin.h
154155
avx10_2_512convertintrin.h
155156
avx10_2_512minmaxintrin.h

clang/lib/Headers/amxintrin.h

+2
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,8 @@ static __inline__ void __DEFAULT_FN_ATTRS_TILE _tile_release(void) {
232232
/// bytes. Since there is no 2D type in llvm IR, we use vector type to
233233
/// represent 2D tile and the fixed size is maximum amx tile register size.
234234
typedef int _tile1024i __attribute__((__vector_size__(1024), __aligned__(64)));
235+
typedef int _tile1024i_1024a
236+
__attribute__((__vector_size__(1024), __aligned__(1024)));
235237

236238
/// This is internal intrinsic. C/C++ user should avoid calling it directly.
237239
static __inline__ _tile1024i __DEFAULT_FN_ATTRS_INT8

0 commit comments

Comments
 (0)