
Commit 113f077

[X86] Pass to transform tdpbf16ps intrinsics to scalar operation.
In the previous patch, https://reviews.llvm.org/D93594, we only scalarized tilezero, tileload, tilestore, and tiledpbssd. In this patch we scalarize the tdpbf16ps intrinsic as well.

Reviewed By: pengfei

Differential Revision: https://reviews.llvm.org/D96110
1 parent 8fab9f8 commit 113f077
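
For orientation, the element-wise computation that the lowered tdpbf16ps loops perform can be sketched in plain C++. This is an illustrative sketch only, not code from the patch: the helper name bf16_to_float, the fixed 16x16 dword tile shape, and the accumulation order are assumptions made for the example, and hardware rounding details are ignored.

    #include <cstdint>
    #include <cstring>

    // Hypothetical helper: widen a bf16 bit pattern to float by placing it in
    // the upper 16 bits of a 32-bit value (bf16 is the high half of an IEEE
    // binary32). This mirrors the shuffle-with-zero + bitcast the pass emits.
    static float bf16_to_float(uint16_t Bits) {
      uint32_t Widened = static_cast<uint32_t>(Bits) << 16;
      float F;
      std::memcpy(&F, &Widened, sizeof(F));
      return F;
    }

    // Per-element update for tdpbf16ps: each 32-bit element of A and B packs a
    // pair of bf16 values; every (row, col, k) step multiplies the two pairs
    // elementwise and accumulates the sum into the fp32 element of C.
    static void ScalarTDPBF16PS(float C[16][16], const uint32_t A[16][16],
                                const uint32_t B[16][16], int Rows, int Cols,
                                int K) {
      for (int M = 0; M < Rows; ++M)
        for (int N = 0; N < Cols; ++N)
          for (int I = 0; I < K; ++I) {
            uint32_t EltA = A[M][I], EltB = B[I][N];
            C[M][N] +=
                bf16_to_float(uint16_t(EltA)) * bf16_to_float(uint16_t(EltB)) +
                bf16_to_float(uint16_t(EltA >> 16)) *
                    bf16_to_float(uint16_t(EltB >> 16));
          }
    }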

2 files changed: +200, -56 lines


llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp

+119, -54
@@ -71,12 +71,20 @@ class X86LowerAMXIntrinsics {
   Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End,
                                   IRBuilderBase &B, Value *Row, Value *Col,
                                   Value *Ptr, Value *Stride, Value *Tile);
-  Value *createTileDPBSSDLoops(BasicBlock *Start, BasicBlock *End,
-                               IRBuilderBase &B, Value *Row, Value *Col,
-                               Value *K, Value *Acc, Value *LHS, Value *RHS);
+  template <Intrinsic::ID IntrID>
+  typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
+                              IntrID == Intrinsic::x86_tdpbf16ps_internal,
+                          Value *>::type
+  createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B,
+                    Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS,
+                    Value *RHS);
   template <bool IsTileLoad>
   bool lowerTileLoadStore(Instruction *TileLoadStore);
-  bool lowerTileDPBSSD(Instruction *TileDPBSSD);
+  template <Intrinsic::ID IntrID>
+  typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
+                              IntrID == Intrinsic::x86_tdpbf16ps_internal,
+                          bool>::type
+  lowerTileDP(Instruction *TileDP);
   bool lowerTileZero(Instruction *TileZero);
 };
 
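The new declarations constrain createTileDPLoops and lowerTileDP to exactly the two supported intrinsic IDs through std::enable_if over a non-type template parameter. A minimal, self-contained C++ illustration of the same pattern follows; the enum and function names here are invented for the example and are not part of LLVM.

    #include <type_traits>

    enum class IntrID { tdpbssd, tdpbf16ps, tilezero };

    // Only the two dot-product IDs have a valid enable_if<...>::type, so any
    // other instantiation fails to compile.
    template <IntrID ID>
    typename std::enable_if<ID == IntrID::tdpbssd || ID == IntrID::tdpbf16ps,
                            bool>::type
    lowerTileDP() {
      return true;
    }

    int main() {
      bool LoweredSSD = lowerTileDP<IntrID::tdpbssd>();    // OK
      bool LoweredBF16 = lowerTileDP<IntrID::tdpbf16ps>(); // OK
      // lowerTileDP<IntrID::tilezero>();  // error: no enable_if<...>::type
      return LoweredSSD && LoweredBF16 ? 0 : 1;
    }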

@@ -213,9 +221,16 @@ Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops(
     }
   }
 
-Value *X86LowerAMXIntrinsics::createTileDPBSSDLoops(
-    BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row,
-    Value *Col, Value *K, Value *Acc, Value *LHS, Value *RHS) {
+template <Intrinsic::ID IntrID>
+typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
+                            IntrID == Intrinsic::x86_tdpbf16ps_internal,
+                        Value *>::type
+X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End,
+                                         IRBuilderBase &B, Value *Row,
+                                         Value *Col, Value *K, Value *Acc,
+                                         Value *LHS, Value *RHS) {
+  std::string IntrinName =
+      IntrID == Intrinsic::x86_tdpbssd_internal ? "tiledpbssd" : "tdpbf16ps";
   Loop *RowLoop = nullptr;
   Loop *ColLoop = nullptr;
   Loop *InnerLoop = nullptr;
@@ -232,17 +247,18 @@ Value *X86LowerAMXIntrinsics::createTileDPBSSDLoops(
   }
 
   BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
-                                   "tiledpbssd.scalarize.rows", B, RowLoop);
+                                   IntrinName + ".scalarize.rows", B, RowLoop);
   BasicBlock *RowLatch = RowBody->getSingleSuccessor();
 
   BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
-                                   "tiledpbssd.scalarize.cols", B, ColLoop);
+                                   IntrinName + ".scalarize.cols", B, ColLoop);
+
   BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
 
   B.SetInsertPoint(ColBody->getTerminator());
   BasicBlock *InnerBody =
       createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
-                 "tiledpbssd.scalarize.inner", B, InnerLoop);
+                 IntrinName + ".scalarize.inner", B, InnerLoop);
 
   BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
   BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
@@ -306,39 +322,82 @@ Value *X86LowerAMXIntrinsics::createTileDPBSSDLoops(
   PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi");
   VecCPhi->addIncoming(VecCPhiColLoop, ColBody);
 
-  // tiledpbssd.scalarize.inner.body:
-  // calculate idxa, idxb
-  // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
-  // %elta = extractelement <256 x i32> %veca, i16 %idxa
-  // %eltav4i8 = bitcast i32 %elta to <4 x i8>
-  // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
-  // %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
-  // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32>
-  // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32>
-  // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32
-  // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %131)
-  // %neweltc = add i32 %elt, %acc
-  // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
-  // i16 %idxc
-
   B.SetInsertPoint(InnerBody->getTerminator());
   Value *IdxA =
       B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
   Value *IdxB =
       B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
-
-  FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
-  FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
-  Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
-  Value *EltA = B.CreateExtractElement(VecA, IdxA);
-  Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
-  Value *EltB = B.CreateExtractElement(VecB, IdxB);
-  Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
-  Value *SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
-  Value *SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
-  Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB));
-  Value *ResElt = B.CreateAdd(EltC, SubVecR);
-  Value *NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
+  Value *NewVecC = nullptr;
+
+  if (IntrID == Intrinsic::x86_tdpbssd_internal) {
+    // tiledpbssd.scalarize.inner.body:
+    // calculate idxa, idxb
+    // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
+    // %elta = extractelement <256 x i32> %veca, i16 %idxa
+    // %eltav4i8 = bitcast i32 %elta to <4 x i8>
+    // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
+    // %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
+    // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32>
+    // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32>
+    // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32
+    // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %131)
+    // %neweltc = add i32 %elt, %acc
+    // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
+    // i16 %idxc
+    FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
+    FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
+    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
+    Value *EltA = B.CreateExtractElement(VecA, IdxA);
+    Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
+    Value *EltB = B.CreateExtractElement(VecB, IdxB);
+    Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
+    Value *SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
+    Value *SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
+    Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB));
+    Value *ResElt = B.CreateAdd(EltC, SubVecR);
+    NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
+  } else if (IntrID == Intrinsic::x86_tdpbf16ps_internal) {
+    // tiledpbf16ps.scalarize.inner.body:
+    // calculate idxa, idxb, idxc
+    // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
+    // %eltcf32 = bitcast i32 %eltc to float
+    // %elta = extractelement <256 x i32> %veca, i16 %idxa
+    // %eltav2i16 = bitcast i32 %elta to <2 x i16>
+    // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
+    // %eltbv2i16 = bitcast i32 %eltb to <2 x i16>
+    // %shufflea = shufflevector <2 x i16> %elta, <2 x i16> zeroinitializer, <4
+    // x i32> <i32 2, i32 0, i32 3, i32 1>
+    // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float>
+    // %shuffleb = shufflevector <2 x i16> %eltb, <2 xi16> zeroinitializer, <4 x
+    // i32> <i32 2, i32 0, i32 3, i32 1>
+    // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float>
+    // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32
+    // %acc = call float
+    // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab)
+    // %neweltc = bitcast float %acc to i32
+    // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
+    // i16 %idxc
+    // %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc,
+    // i16 %idxc
+    FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2);
+    FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2);
+    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
+    Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy());
+    Value *EltA = B.CreateExtractElement(VecA, IdxA);
+    Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty);
+    Value *EltB = B.CreateExtractElement(VecB, IdxB);
+    Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty);
+    Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty);
+    int ShuffleMask[4] = {2, 0, 3, 1};
+    auto ShuffleArray = makeArrayRef(ShuffleMask);
+    Value *AV2F32 = B.CreateBitCast(
+        B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty);
+    Value *BV2F32 = B.CreateBitCast(
+        B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty);
+    Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32));
+    Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty());
+    NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
+  }
 
   // tiledpbssd.scalarize.cols.latch:
   // %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc
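
Worth noting from the bf16 branch above: there is no explicit bf16-to-float conversion instruction. Each <2 x i16> element is interleaved with a zero vector using the mask <2, 0, 3, 1> and the resulting <4 x i16> is reinterpreted as <2 x float>, which lands every bf16 value in the high 16 bits of a 32-bit lane. Below is a small stand-alone C++ sketch of that lane arithmetic on a little-endian target; the concrete values and names are chosen only for the example.

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    int main() {
      // shufflevector <2 x i16> %a, <2 x i16> zeroinitializer,
      //               <4 x i32> <i32 2, i32 0, i32 3, i32 1>
      // Indices 0-1 pick lanes of %a, indices 2-3 pick lanes of the zero
      // vector, so the result is {0, a[0], 0, a[1]}.
      uint16_t A[2] = {0x3F80, 0x4000}; // bf16 bit patterns of 1.0f and 2.0f
      uint16_t Shuffled[4] = {0, A[0], 0, A[1]};

      // bitcast <4 x i16> to <2 x float>: on a little-endian target each bf16
      // value now occupies the high 16 bits of one float, i.e. it has been
      // widened to the corresponding fp32 value.
      float Widened[2];
      std::memcpy(Widened, Shuffled, sizeof(Widened));
      std::printf("%g %g\n", Widened[0], Widened[1]); // prints "1 2"
      return 0;
    }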
@@ -357,14 +416,17 @@ Value *X86LowerAMXIntrinsics::createTileDPBSSDLoops(
   return NewVecD;
 }
 
-bool X86LowerAMXIntrinsics::lowerTileDPBSSD(Instruction *TileDPBSSD) {
+template <Intrinsic::ID IntrID>
+typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
+                            IntrID == Intrinsic::x86_tdpbf16ps_internal,
+                        bool>::type
+X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) {
   Value *M, *N, *K, *C, *A, *B;
-  match(TileDPBSSD, m_Intrinsic<Intrinsic::x86_tdpbssd_internal>(
-                        m_Value(M), m_Value(N), m_Value(K), m_Value(C),
-                        m_Value(A), m_Value(B)));
-  Instruction *InsertI = TileDPBSSD;
-  IRBuilder<> PreBuilder(TileDPBSSD);
-  PreBuilder.SetInsertPoint(TileDPBSSD);
+  match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K),
+                                    m_Value(C), m_Value(A), m_Value(B)));
+  Instruction *InsertI = TileDP;
+  IRBuilder<> PreBuilder(TileDP);
+  PreBuilder.SetInsertPoint(TileDP);
   // We visit the loop with (m, n/4, k/4):
   // %n_dword = lshr i16 %n, 2
   // %k_dword = lshr i16 %k, 2
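
The comment above is why the column and K trip counts are shifted right by two before the loop nest is built: both arrive as byte counts, and each iteration of the scalar loops consumes one 32-bit element (four i8 products for tdpbssd, two bf16 products for tdpbf16ps). A tiny sketch of that bookkeeping, assuming the maximal 16-row by 64-byte tile shape:

    #include <cassert>
    #include <cstdint>

    int main() {
      // Row is a row count; Col and K are byte counts (values assumed here
      // for a full 16 x 64-byte tile).
      uint16_t Row = 16, ColBytes = 64, KBytes = 64;
      uint16_t NDWord = ColBytes >> 2; // %n_dword: 16 column iterations
      uint16_t KDWord = KBytes >> 2;   // %k_dword: 16 inner iterations
      assert(Row == 16 && NDWord == 16 && KDWord == 16);
      return 0;
    }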
@@ -373,26 +435,25 @@ bool X86LowerAMXIntrinsics::lowerTileDPBSSD(Instruction *TileDPBSSD) {
   BasicBlock *Start = InsertI->getParent();
   BasicBlock *End =
       SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
-  IRBuilder<> Builder(TileDPBSSD);
-  Value *ResVec =
-      createTileDPBSSDLoops(Start, End, Builder, M, NDWord, KDWord, C, A, B);
+  IRBuilder<> Builder(TileDP);
+  Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord,
+                                            KDWord, C, A, B);
   // we cannot assume there always be bitcast after tiledpbssd. So we need to
   // insert one bitcast as required
   Builder.SetInsertPoint(End->getFirstNonPHI());
   Value *ResAMX =
       Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
-  // Delete tiledpbssd intrinsic and do some clean-up.
-  for (auto UI = TileDPBSSD->use_begin(), UE = TileDPBSSD->use_end();
-       UI != UE;) {
+  // Delete TileDP intrinsic and do some clean-up.
+  for (auto UI = TileDP->use_begin(), UE = TileDP->use_end(); UI != UE;) {
     Instruction *I = cast<Instruction>((UI++)->getUser());
     Value *Vec;
     if (match(I, m_BitCast(m_Value(Vec)))) {
       I->replaceAllUsesWith(ResVec);
       I->eraseFromParent();
     }
   }
-  TileDPBSSD->replaceAllUsesWith(ResAMX);
-  TileDPBSSD->eraseFromParent();
+  TileDP->replaceAllUsesWith(ResAMX);
+  TileDP->eraseFromParent();
   return true;
 }
 
@@ -469,6 +530,7 @@ bool X86LowerAMXIntrinsics::visit() {
     case Intrinsic::x86_tileloadd64_internal:
     case Intrinsic::x86_tilestored64_internal:
     case Intrinsic::x86_tilezero_internal:
+    case Intrinsic::x86_tdpbf16ps_internal:
       WorkList.push_back(Inst);
       break;
     default:
@@ -481,7 +543,10 @@ bool X86LowerAMXIntrinsics::visit() {
   for (auto *Inst : WorkList) {
     switch (Inst->getIntrinsicID()) {
     case Intrinsic::x86_tdpbssd_internal:
-      C = lowerTileDPBSSD(Inst) || C;
+      C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C;
+      break;
+    case Intrinsic::x86_tdpbf16ps_internal:
+      C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C;
       break;
     case Intrinsic::x86_tileloadd64_internal:
       C = lowerTileLoadStore<true>(Inst) || C;

llvm/test/CodeGen/X86/AMX/amx-low-intrinsics.ll

+81, -2
@@ -97,8 +97,8 @@ entry:
   ret void
 }
 
-define dso_local void @test_amx_dp(i16 signext %row, i16 signext %col, i16 signext %k, <256 x i32> %c, <256 x i32> %a, <256 x i32> %b, <256 x i32>* %vptr) #0 {
-; CHECK-LABEL: @test_amx_dp(
+define dso_local void @test_amx_dpbssd(i16 signext %row, i16 signext %col, i16 signext %k, <256 x i32> %c, <256 x i32> %a, <256 x i32> %b, <256 x i32>* %vptr) #0 {
+; CHECK-LABEL: @test_amx_dpbssd(
 ; CHECK-NEXT: entry:
 ; CHECK-NEXT: [[A_AMX:%.*]] = bitcast <256 x i32> [[A:%.*]] to x86_amx
 ; CHECK-NEXT: [[B_AMX:%.*]] = bitcast <256 x i32> [[B:%.*]] to x86_amx
@@ -172,6 +172,84 @@ entry:
   ret void
 }
 
+define dso_local void @test_amx_dpbf16ps(i16 signext %row, i16 signext %col, i16 signext %k, <256 x i32> %c, <256 x i32> %a, <256 x i32> %b, <256 x i32>* %vptr) #0 {
+; CHECK-LABEL: @test_amx_dpbf16ps(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[A_AMX:%.*]] = bitcast <256 x i32> [[A:%.*]] to x86_amx
+; CHECK-NEXT: [[B_AMX:%.*]] = bitcast <256 x i32> [[B:%.*]] to x86_amx
+; CHECK-NEXT: [[C_AMX:%.*]] = bitcast <256 x i32> [[C:%.*]] to x86_amx
+; CHECK-NEXT: [[TMP0:%.*]] = lshr i16 [[COL:%.*]], 2
+; CHECK-NEXT: [[TMP1:%.*]] = lshr i16 [[K:%.*]], 2
+; CHECK-NEXT: br label [[TDPBF16PS_SCALARIZE_ROWS_HEADER:%.*]]
+; CHECK: tdpbf16ps.scalarize.rows.header:
+; CHECK-NEXT: [[TDPBF16PS_SCALARIZE_ROWS_IV:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[TDPBF16PS_SCALARIZE_ROWS_STEP:%.*]], [[TDPBF16PS_SCALARIZE_ROWS_LATCH:%.*]] ]
+; CHECK-NEXT: [[VEC_C_PHI_ROW:%.*]] = phi <256 x i32> [ [[C]], [[ENTRY]] ], [ [[TMP21:%.*]], [[TDPBF16PS_SCALARIZE_ROWS_LATCH]] ]
+; CHECK-NEXT: [[VEC_D_PHI_ROW:%.*]] = phi <256 x i32> [ zeroinitializer, [[ENTRY]] ], [ [[TMP23:%.*]], [[TDPBF16PS_SCALARIZE_ROWS_LATCH]] ]
+; CHECK-NEXT: br label [[TDPBF16PS_SCALARIZE_ROWS_BODY:%.*]]
+; CHECK: tdpbf16ps.scalarize.rows.body:
+; CHECK-NEXT: br label [[TDPBF16PS_SCALARIZE_COLS_HEADER:%.*]]
+; CHECK: tdpbf16ps.scalarize.cols.header:
+; CHECK-NEXT: [[TDPBF16PS_SCALARIZE_COLS_IV:%.*]] = phi i16 [ 0, [[TDPBF16PS_SCALARIZE_ROWS_BODY]] ], [ [[TDPBF16PS_SCALARIZE_COLS_STEP:%.*]], [[TDPBF16PS_SCALARIZE_COLS_LATCH:%.*]] ]
+; CHECK-NEXT: [[VEC_C_PHI_COL:%.*]] = phi <256 x i32> [ [[VEC_C_PHI_ROW]], [[TDPBF16PS_SCALARIZE_ROWS_BODY]] ], [ [[TMP21]], [[TDPBF16PS_SCALARIZE_COLS_LATCH]] ]
+; CHECK-NEXT: [[VEC_D_PHI_COL:%.*]] = phi <256 x i32> [ [[VEC_D_PHI_ROW]], [[TDPBF16PS_SCALARIZE_ROWS_BODY]] ], [ [[TMP23]], [[TDPBF16PS_SCALARIZE_COLS_LATCH]] ]
+; CHECK-NEXT: [[TMP2:%.*]] = mul i16 [[TDPBF16PS_SCALARIZE_ROWS_IV]], 16
+; CHECK-NEXT: [[TMP3:%.*]] = add i16 [[TMP2]], [[TDPBF16PS_SCALARIZE_COLS_IV]]
+; CHECK-NEXT: br label [[TDPBF16PS_SCALARIZE_COLS_BODY:%.*]]
+; CHECK: tdpbf16ps.scalarize.cols.body:
+; CHECK-NEXT: br label [[TDPBF16PS_SCALARIZE_INNER_HEADER:%.*]]
+; CHECK: tdpbf16ps.scalarize.inner.header:
+; CHECK-NEXT: [[TDPBF16PS_SCALARIZE_INNER_IV:%.*]] = phi i16 [ 0, [[TDPBF16PS_SCALARIZE_COLS_BODY]] ], [ [[TDPBF16PS_SCALARIZE_INNER_STEP:%.*]], [[TDPBF16PS_SCALARIZE_INNER_LATCH:%.*]] ]
+; CHECK-NEXT: [[VEC_C_INNER_PHI:%.*]] = phi <256 x i32> [ [[VEC_C_PHI_COL]], [[TDPBF16PS_SCALARIZE_COLS_BODY]] ], [ [[TMP21]], [[TDPBF16PS_SCALARIZE_INNER_LATCH]] ]
+; CHECK-NEXT: br label [[TDPBF16PS_SCALARIZE_INNER_BODY:%.*]]
+; CHECK: tdpbf16ps.scalarize.inner.body:
+; CHECK-NEXT: [[TMP4:%.*]] = mul i16 [[TDPBF16PS_SCALARIZE_ROWS_IV]], 16
+; CHECK-NEXT: [[TMP5:%.*]] = add i16 [[TMP4]], [[TDPBF16PS_SCALARIZE_INNER_IV]]
+; CHECK-NEXT: [[TMP6:%.*]] = mul i16 [[TDPBF16PS_SCALARIZE_INNER_IV]], 16
+; CHECK-NEXT: [[TMP7:%.*]] = add i16 [[TMP6]], [[TDPBF16PS_SCALARIZE_COLS_IV]]
+; CHECK-NEXT: [[TMP8:%.*]] = extractelement <256 x i32> [[VEC_C_INNER_PHI]], i16 [[TMP3]]
+; CHECK-NEXT: [[TMP9:%.*]] = bitcast i32 [[TMP8]] to float
+; CHECK-NEXT: [[TMP10:%.*]] = extractelement <256 x i32> [[A]], i16 [[TMP5]]
+; CHECK-NEXT: [[TMP11:%.*]] = bitcast i32 [[TMP10]] to <2 x i16>
+; CHECK-NEXT: [[TMP12:%.*]] = extractelement <256 x i32> [[B]], i16 [[TMP7]]
+; CHECK-NEXT: [[TMP13:%.*]] = bitcast i32 [[TMP12]] to <2 x i16>
+; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <2 x i16> [[TMP11]], <2 x i16> zeroinitializer, <4 x i32> <i32 2, i32 0, i32 3, i32 1>
+; CHECK-NEXT: [[TMP15:%.*]] = bitcast <4 x i16> [[TMP14]] to <2 x float>
+; CHECK-NEXT: [[TMP16:%.*]] = shufflevector <2 x i16> [[TMP13]], <2 x i16> zeroinitializer, <4 x i32> <i32 2, i32 0, i32 3, i32 1>
+; CHECK-NEXT: [[TMP17:%.*]] = bitcast <4 x i16> [[TMP16]] to <2 x float>
+; CHECK-NEXT: [[TMP18:%.*]] = fmul <2 x float> [[TMP15]], [[TMP17]]
+; CHECK-NEXT: [[TMP19:%.*]] = call float @llvm.vector.reduce.fadd.v2f32(float [[TMP9]], <2 x float> [[TMP18]])
+; CHECK-NEXT: [[TMP20:%.*]] = bitcast float [[TMP19]] to i32
+; CHECK-NEXT: [[TMP21]] = insertelement <256 x i32> [[VEC_C_INNER_PHI]], i32 [[TMP20]], i16 [[TMP3]]
+; CHECK-NEXT: br label [[TDPBF16PS_SCALARIZE_INNER_LATCH]]
+; CHECK: tdpbf16ps.scalarize.inner.latch:
+; CHECK-NEXT: [[TDPBF16PS_SCALARIZE_INNER_STEP]] = add i16 [[TDPBF16PS_SCALARIZE_INNER_IV]], 1
+; CHECK-NEXT: [[TDPBF16PS_SCALARIZE_INNER_COND:%.*]] = icmp ne i16 [[TDPBF16PS_SCALARIZE_INNER_STEP]], [[TMP1]]
+; CHECK-NEXT: br i1 [[TDPBF16PS_SCALARIZE_INNER_COND]], label [[TDPBF16PS_SCALARIZE_INNER_HEADER]], label [[TDPBF16PS_SCALARIZE_COLS_LATCH]]
+; CHECK: tdpbf16ps.scalarize.cols.latch:
+; CHECK-NEXT: [[TDPBF16PS_SCALARIZE_COLS_STEP]] = add i16 [[TDPBF16PS_SCALARIZE_COLS_IV]], 1
+; CHECK-NEXT: [[TDPBF16PS_SCALARIZE_COLS_COND:%.*]] = icmp ne i16 [[TDPBF16PS_SCALARIZE_COLS_STEP]], [[TMP0]]
+; CHECK-NEXT: [[TMP22:%.*]] = extractelement <256 x i32> [[TMP21]], i16 [[TMP3]]
+; CHECK-NEXT: [[TMP23]] = insertelement <256 x i32> [[VEC_D_PHI_COL]], i32 [[TMP22]], i16 [[TMP3]]
+; CHECK-NEXT: br i1 [[TDPBF16PS_SCALARIZE_COLS_COND]], label [[TDPBF16PS_SCALARIZE_COLS_HEADER]], label [[TDPBF16PS_SCALARIZE_ROWS_LATCH]]
+; CHECK: tdpbf16ps.scalarize.rows.latch:
+; CHECK-NEXT: [[TDPBF16PS_SCALARIZE_ROWS_STEP]] = add i16 [[TDPBF16PS_SCALARIZE_ROWS_IV]], 1
+; CHECK-NEXT: [[TDPBF16PS_SCALARIZE_ROWS_COND:%.*]] = icmp ne i16 [[TDPBF16PS_SCALARIZE_ROWS_STEP]], [[ROW:%.*]]
+; CHECK-NEXT: br i1 [[TDPBF16PS_SCALARIZE_ROWS_COND]], label [[TDPBF16PS_SCALARIZE_ROWS_HEADER]], label [[CONTINUE:%.*]]
+; CHECK: continue:
+; CHECK-NEXT: [[TMP24:%.*]] = bitcast <256 x i32> [[TMP23]] to x86_amx
+; CHECK-NEXT: store <256 x i32> [[TMP23]], <256 x i32>* [[VPTR:%.*]], align 64
+; CHECK-NEXT: ret void
+;
+entry:
+  %a.amx = bitcast <256 x i32> %a to x86_amx
+  %b.amx = bitcast <256 x i32> %b to x86_amx
+  %c.amx = bitcast <256 x i32> %c to x86_amx
+  %acc = call x86_amx @llvm.x86.tdpbf16ps.internal(i16 %row, i16 %col, i16 %k, x86_amx %c.amx, x86_amx %a.amx, x86_amx %b.amx)
+  %vec = bitcast x86_amx %acc to <256 x i32>
+  store <256 x i32> %vec, <256 x i32>* %vptr, align 64
+  ret void
+}
+
 define dso_local void @test_amx_store(i16 signext %row, i16 signext %col, i8 *%ptr, i64 %stride, <256 x i32>* %vptr, <256 x i32> %vec) #0 {
 ; CHECK-LABEL: @test_amx_store(
 ; CHECK-NEXT: entry:
@@ -232,6 +310,7 @@ entry:
 declare x86_amx @llvm.x86.tilezero.internal(i16, i16)
 declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
 declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
+declare x86_amx @llvm.x86.tdpbf16ps.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
 declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)
 
 attributes #0 = { noinline nounwind optnone }
