@@ -71,12 +71,20 @@ class X86LowerAMXIntrinsics {
71
71
Value *createTileLoadStoreLoops (BasicBlock *Start, BasicBlock *End,
72
72
IRBuilderBase &B, Value *Row, Value *Col,
73
73
Value *Ptr , Value *Stride, Value *Tile);
74
- Value *createTileDPBSSDLoops (BasicBlock *Start, BasicBlock *End,
75
- IRBuilderBase &B, Value *Row, Value *Col,
76
- Value *K, Value *Acc, Value *LHS, Value *RHS);
74
+ template <Intrinsic::ID IntrID>
75
+ typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
76
+ IntrID == Intrinsic::x86_tdpbf16ps_internal,
77
+ Value *>::type
78
+ createTileDPLoops (BasicBlock *Start, BasicBlock *End, IRBuilderBase &B,
79
+ Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS,
80
+ Value *RHS);
77
81
template <bool IsTileLoad>
78
82
bool lowerTileLoadStore (Instruction *TileLoadStore);
79
- bool lowerTileDPBSSD (Instruction *TileDPBSSD);
83
+ template <Intrinsic::ID IntrID>
84
+ typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
85
+ IntrID == Intrinsic::x86_tdpbf16ps_internal,
86
+ bool >::type
87
+ lowerTileDP (Instruction *TileDP);
80
88
bool lowerTileZero (Instruction *TileZero);
81
89
};
82
90
@@ -213,9 +221,16 @@ Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops(
213
221
}
214
222
}
215
223
216
- Value *X86LowerAMXIntrinsics::createTileDPBSSDLoops (
217
- BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row,
218
- Value *Col, Value *K, Value *Acc, Value *LHS, Value *RHS) {
224
+ template <Intrinsic::ID IntrID>
225
+ typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
226
+ IntrID == Intrinsic::x86_tdpbf16ps_internal,
227
+ Value *>::type
228
+ X86LowerAMXIntrinsics::createTileDPLoops (BasicBlock *Start, BasicBlock *End,
229
+ IRBuilderBase &B, Value *Row,
230
+ Value *Col, Value *K, Value *Acc,
231
+ Value *LHS, Value *RHS) {
232
+ std::string IntrinName =
233
+ IntrID == Intrinsic::x86_tdpbssd_internal ? " tiledpbssd" : " tdpbf16ps" ;
219
234
Loop *RowLoop = nullptr ;
220
235
Loop *ColLoop = nullptr ;
221
236
Loop *InnerLoop = nullptr ;
@@ -232,17 +247,18 @@ Value *X86LowerAMXIntrinsics::createTileDPBSSDLoops(
232
247
}
233
248
234
249
BasicBlock *RowBody = createLoop (Start, End, Row, B.getInt16 (1 ),
235
- " tiledpbssd .scalarize.rows" , B, RowLoop);
250
+ IntrinName + " .scalarize.rows" , B, RowLoop);
236
251
BasicBlock *RowLatch = RowBody->getSingleSuccessor ();
237
252
238
253
BasicBlock *ColBody = createLoop (RowBody, RowLatch, Col, B.getInt16 (1 ),
239
- " tiledpbssd.scalarize.cols" , B, ColLoop);
254
+ IntrinName + " .scalarize.cols" , B, ColLoop);
255
+
240
256
BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor ();
241
257
242
258
B.SetInsertPoint (ColBody->getTerminator ());
243
259
BasicBlock *InnerBody =
244
260
createLoop (ColBody, ColLoopLatch, K, B.getInt16 (1 ),
245
- " tiledpbssd .scalarize.inner" , B, InnerLoop);
261
+ IntrinName + " .scalarize.inner" , B, InnerLoop);
246
262
247
263
BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor ();
248
264
BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor ();
@@ -306,39 +322,82 @@ Value *X86LowerAMXIntrinsics::createTileDPBSSDLoops(
306
322
PHINode *VecCPhi = B.CreatePHI (V256I32Ty, 2 , " vec.c.inner.phi" );
307
323
VecCPhi->addIncoming (VecCPhiColLoop, ColBody);
308
324
309
- // tiledpbssd.scalarize.inner.body:
310
- // calculate idxa, idxb
311
- // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
312
- // %elta = extractelement <256 x i32> %veca, i16 %idxa
313
- // %eltav4i8 = bitcast i32 %elta to <4 x i8>
314
- // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
315
- // %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
316
- // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32>
317
- // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32>
318
- // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32
319
- // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %131)
320
- // %neweltc = add i32 %elt, %acc
321
- // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
322
- // i16 %idxc
323
-
324
325
B.SetInsertPoint (InnerBody->getTerminator ());
325
326
Value *IdxA =
326
327
B.CreateAdd (B.CreateMul (CurrentRow, B.getInt16 (16 )), CurrentInner);
327
328
Value *IdxB =
328
329
B.CreateAdd (B.CreateMul (CurrentInner, B.getInt16 (16 )), CurrentCol);
329
-
330
- FixedVectorType *V4I8Ty = FixedVectorType::get (B.getInt8Ty (), 4 );
331
- FixedVectorType *V4I32Ty = FixedVectorType::get (B.getInt32Ty (), 4 );
332
- Value *EltC = B.CreateExtractElement (VecCPhi, IdxC);
333
- Value *EltA = B.CreateExtractElement (VecA, IdxA);
334
- Value *SubVecA = B.CreateBitCast (EltA, V4I8Ty);
335
- Value *EltB = B.CreateExtractElement (VecB, IdxB);
336
- Value *SubVecB = B.CreateBitCast (EltB, V4I8Ty);
337
- Value *SEXTSubVecB = B.CreateSExt (SubVecB, V4I32Ty);
338
- Value *SEXTSubVecA = B.CreateSExt (SubVecA, V4I32Ty);
339
- Value *SubVecR = B.CreateAddReduce (B.CreateMul (SEXTSubVecA, SEXTSubVecB));
340
- Value *ResElt = B.CreateAdd (EltC, SubVecR);
341
- Value *NewVecC = B.CreateInsertElement (VecCPhi, ResElt, IdxC);
330
+ Value *NewVecC = nullptr ;
331
+
332
+ if (IntrID == Intrinsic::x86_tdpbssd_internal) {
333
+ // tiledpbssd.scalarize.inner.body:
334
+ // calculate idxa, idxb
335
+ // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
336
+ // %elta = extractelement <256 x i32> %veca, i16 %idxa
337
+ // %eltav4i8 = bitcast i32 %elta to <4 x i8>
338
+ // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
339
+ // %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
340
+ // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32>
341
+ // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32>
342
+ // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32
343
+ // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %mulab)
344
+ // %neweltc = add i32 %eltc, %acc
345
+ // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
346
+ // i16 %idxc
347
+ FixedVectorType *V4I8Ty = FixedVectorType::get (B.getInt8Ty (), 4 );
348
+ FixedVectorType *V4I32Ty = FixedVectorType::get (B.getInt32Ty (), 4 );
349
+ Value *EltC = B.CreateExtractElement (VecCPhi, IdxC);
350
+ Value *EltA = B.CreateExtractElement (VecA, IdxA);
351
+ Value *SubVecA = B.CreateBitCast (EltA, V4I8Ty);
352
+ Value *EltB = B.CreateExtractElement (VecB, IdxB);
353
+ Value *SubVecB = B.CreateBitCast (EltB, V4I8Ty);
354
+ Value *SEXTSubVecB = B.CreateSExt (SubVecB, V4I32Ty);
355
+ Value *SEXTSubVecA = B.CreateSExt (SubVecA, V4I32Ty);
356
+ Value *SubVecR = B.CreateAddReduce (B.CreateMul (SEXTSubVecA, SEXTSubVecB));
357
+ Value *ResElt = B.CreateAdd (EltC, SubVecR);
358
+ NewVecC = B.CreateInsertElement (VecCPhi, ResElt, IdxC);
359
+ } else if (IntrID == Intrinsic::x86_tdpbf16ps_internal) {
360
+ // tdpbf16ps.scalarize.inner.body:
361
+ // calculate idxa, idxb, idxc
362
+ // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
363
+ // %eltcf32 = bitcast i32 %eltc to float
364
+ // %elta = extractelement <256 x i32> %veca, i16 %idxa
365
+ // %eltav2i16 = bitcast i32 %elta to <2 x i16>
366
+ // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
367
+ // %eltbv2i16 = bitcast i32 %eltb to <2 x i16>
368
+ // %shufflea = shufflevector <2 x i16> %eltav2i16, <2 x i16> zeroinitializer, <4
369
+ // x i32> <i32 2, i32 0, i32 3, i32 1>
370
+ // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float>
371
+ // %shuffleb = shufflevector <2 x i16> %eltbv2i16, <2 x i16> zeroinitializer, <4 x
372
+ // i32> <i32 2, i32 0, i32 3, i32 1>
373
+ // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float>
374
+ // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32
375
+ // %acc = call float
376
+ // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab)
377
+ // %neweltc = bitcast float %acc to i32
378
+ // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
379
+ // i16 %idxc
380
+ // %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc,
381
+ // i16 %idxc
382
+ FixedVectorType *V2I16Ty = FixedVectorType::get (B.getInt16Ty (), 2 );
383
+ FixedVectorType *V2F32Ty = FixedVectorType::get (B.getFloatTy (), 2 );
384
+ Value *EltC = B.CreateExtractElement (VecCPhi, IdxC);
385
+ Value *EltCF32 = B.CreateBitCast (EltC, B.getFloatTy ());
386
+ Value *EltA = B.CreateExtractElement (VecA, IdxA);
387
+ Value *SubVecA = B.CreateBitCast (EltA, V2I16Ty);
388
+ Value *EltB = B.CreateExtractElement (VecB, IdxB);
389
+ Value *SubVecB = B.CreateBitCast (EltB, V2I16Ty);
390
+ Value *ZeroV2I16 = Constant::getNullValue (V2I16Ty);
391
+ int ShuffleMask[4 ] = {2 , 0 , 3 , 1 };
392
+ auto ShuffleArray = makeArrayRef (ShuffleMask);
393
+ Value *AV2F32 = B.CreateBitCast (
394
+ B.CreateShuffleVector (SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty);
395
+ Value *BV2F32 = B.CreateBitCast (
396
+ B.CreateShuffleVector (SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty);
397
+ Value *SubVecR = B.CreateFAddReduce (EltCF32, B.CreateFMul (AV2F32, BV2F32));
398
+ Value *ResElt = B.CreateBitCast (SubVecR, B.getInt32Ty ());
399
+ NewVecC = B.CreateInsertElement (VecCPhi, ResElt, IdxC);
400
+ }
342
401
343
402
// tiledpbssd.scalarize.cols.latch:
344
403
// %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc
@@ -357,14 +416,17 @@ Value *X86LowerAMXIntrinsics::createTileDPBSSDLoops(
357
416
return NewVecD;
358
417
}
359
418
360
- bool X86LowerAMXIntrinsics::lowerTileDPBSSD (Instruction *TileDPBSSD) {
419
+ template <Intrinsic::ID IntrID>
420
+ typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
421
+ IntrID == Intrinsic::x86_tdpbf16ps_internal,
422
+ bool >::type
423
+ X86LowerAMXIntrinsics::lowerTileDP (Instruction *TileDP) {
361
424
Value *M, *N, *K, *C, *A, *B;
362
- match (TileDPBSSD, m_Intrinsic<Intrinsic::x86_tdpbssd_internal>(
363
- m_Value (M), m_Value (N), m_Value (K), m_Value (C),
364
- m_Value (A), m_Value (B)));
365
- Instruction *InsertI = TileDPBSSD;
366
- IRBuilder<> PreBuilder (TileDPBSSD);
367
- PreBuilder.SetInsertPoint (TileDPBSSD);
425
+ match (TileDP, m_Intrinsic<IntrID>(m_Value (M), m_Value (N), m_Value (K),
426
+ m_Value (C), m_Value (A), m_Value (B)));
427
+ Instruction *InsertI = TileDP;
428
+ IRBuilder<> PreBuilder (TileDP);
429
+ PreBuilder.SetInsertPoint (TileDP);
368
430
// We visit the loop with (m, n/4, k/4):
369
431
// %n_dword = lshr i16 %n, 2
370
432
// %k_dword = lshr i16 %k, 2
@@ -373,26 +435,25 @@ bool X86LowerAMXIntrinsics::lowerTileDPBSSD(Instruction *TileDPBSSD) {
373
435
BasicBlock *Start = InsertI->getParent ();
374
436
BasicBlock *End =
375
437
SplitBlock (InsertI->getParent (), InsertI, &DTU, LI, nullptr , " continue" );
376
- IRBuilder<> Builder (TileDPBSSD );
377
- Value *ResVec =
378
- createTileDPBSSDLoops (Start, End, Builder, M, NDWord, KDWord, C, A, B);
438
+ IRBuilder<> Builder (TileDP );
439
+ Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord,
440
+ KDWord, C, A, B);
379
441
// We cannot assume there is always a bitcast after the TileDP intrinsic, so
380
442
// we insert one bitcast as required.
381
443
Builder.SetInsertPoint (End->getFirstNonPHI ());
382
444
Value *ResAMX =
383
445
Builder.CreateBitCast (ResVec, Type::getX86_AMXTy (Builder.getContext ()));
384
- // Delete tiledpbssd intrinsic and do some clean-up.
385
- for (auto UI = TileDPBSSD->use_begin (), UE = TileDPBSSD->use_end ();
386
- UI != UE;) {
446
+ // Delete TileDP intrinsic and do some clean-up.
447
+ for (auto UI = TileDP->use_begin (), UE = TileDP->use_end (); UI != UE;) {
387
448
Instruction *I = cast<Instruction>((UI++)->getUser ());
388
449
Value *Vec;
389
450
if (match (I, m_BitCast (m_Value (Vec)))) {
390
451
I->replaceAllUsesWith (ResVec);
391
452
I->eraseFromParent ();
392
453
}
393
454
}
394
- TileDPBSSD ->replaceAllUsesWith (ResAMX);
395
- TileDPBSSD ->eraseFromParent ();
455
+ TileDP ->replaceAllUsesWith (ResAMX);
456
+ TileDP ->eraseFromParent ();
396
457
return true ;
397
458
}
398
459
@@ -469,6 +530,7 @@ bool X86LowerAMXIntrinsics::visit() {
469
530
case Intrinsic::x86_tileloadd64_internal:
470
531
case Intrinsic::x86_tilestored64_internal:
471
532
case Intrinsic::x86_tilezero_internal:
533
+ case Intrinsic::x86_tdpbf16ps_internal:
472
534
WorkList.push_back (Inst);
473
535
break ;
474
536
default :
@@ -481,7 +543,10 @@ bool X86LowerAMXIntrinsics::visit() {
481
543
for (auto *Inst : WorkList) {
482
544
switch (Inst->getIntrinsicID ()) {
483
545
case Intrinsic::x86_tdpbssd_internal:
484
- C = lowerTileDPBSSD (Inst) || C;
546
+ C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C;
547
+ break ;
548
+ case Intrinsic::x86_tdpbf16ps_internal:
549
+ C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C;
485
550
break ;
486
551
case Intrinsic::x86_tileloadd64_internal:
487
552
C = lowerTileLoadStore<true >(Inst) || C;
0 commit comments