Skip to content

Commit b0e9b00

Browse files
authored
[NVPTX] Make nvptx mma instructions convergent. (#96521)
We are running into NVPTX backend generating wrong code for an input: ``` %0 = llvm.nvvm.mma.m?n?k?.row.col.??? (...) if laneid == 0: ret else: store %0 ``` The backend reorder the instruction (as an effect of `MachineSink` pass) to ``` if laneid == 0: ret else: %0 = llvm.nvvm.mma.m?n?k?.row.col.??? (...) store %0 ``` This is incorrect because `mma` is a warp instruction which needs all threads to sync before performing the operation instead of being guarded by a specific thread id. It should be similar as the shuffle instruction `shfl` in terms of warp level sync, and `shfl` is marked as `isConvergent = true`. Apply `isConvergent = true` to `mma` instructions.
1 parent 7ea63b9 commit b0e9b00

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6725,6 +6725,7 @@ class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
67256725
# FragC.regstring # ";";
67266726
}
67276727

6728+
let isConvergent = true in {
67286729
defset list<WMMA_INSTR> WMMAs = {
67296730
foreach layout_a = ["row", "col"] in {
67306731
foreach layout_b = ["row", "col"] in {
@@ -6746,6 +6747,7 @@ defset list<WMMA_INSTR> WMMAs = {
67466747
} // layout_b
67476748
} // layout_a
67486749
} // defset
6750+
}
67496751

67506752
// MMA
67516753
class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
@@ -6775,6 +6777,7 @@ class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
67756777
# FragC.regstring # ";";
67766778
}
67776779

6780+
let isConvergent = true in {
67786781
defset list<WMMA_INSTR> MMAs = {
67796782
foreach layout_a = ["row", "col"] in {
67806783
foreach layout_b = ["row", "col"] in {
@@ -6794,6 +6797,7 @@ defset list<WMMA_INSTR> MMAs = {
67946797
} // layout_b
67956798
} // layout_a
67966799
} // defset
6800+
}
67976801

67986802
//
67996803
// ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx81 | FileCheck %s
2+
3+
declare { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32(i32, i32, i32, float, float, float, float) #1
4+
5+
declare noundef i32 @llvm.nvvm.read.ptx.sreg.laneid() #0
6+
7+
; COM: llvm.nvvm.mma should not sink to the next block and gets reordered to be after laneid check.
8+
; CHECK-LABEL: no_reorder_mma_and_laneid_check
9+
define dso_local void @no_reorder_mma_and_laneid_check(ptr %arg, ptr %arg1) {
10+
bb:
11+
; CHECK: mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32
12+
; CHECK: laneid
13+
%i = tail call { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32(i32 10, i32 10, i32 8, float 0.0, float 0.0, float 0.0, float 0.0)
14+
%i3 = tail call i32 @llvm.nvvm.read.ptx.sreg.laneid()
15+
%i4 = icmp eq i32 %i3, 0
16+
br i1 %i4, label %bb5, label %bb8
17+
18+
bb5: ; preds = %bb
19+
%i6 = extractvalue { float, float, float, float } %i, 0
20+
%i7 = getelementptr float, ptr %arg, i64 0
21+
store float %i6, ptr %i7, align 4
22+
br label %bb8
23+
24+
bb8: ; preds = %bb5, %bb
25+
ret void
26+
}

0 commit comments

Comments
 (0)