Skip to content

Commit bd860f9

Browse files
authored
[NVPTX] Add intrinsics for redux.sync f32 instructions (llvm#126664)
Adds NVVM intrinsics, NVPTX codegen, and Clang builtins for the `redux.sync` f32 instructions introduced in PTX 8.6 for sm_100a. Tests are added in `CodeGen/NVPTX/redux-sync.ll` and `CodeGenCUDA/redux-builtins.cu` and were verified with ptxas 12.8.0. PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-redux-sync
1 parent 8a0914c commit bd860f9

File tree

5 files changed

+212
-0
lines changed

5 files changed

+212
-0
lines changed

clang/include/clang/Basic/BuiltinsNVPTX.td

+8
Original file line number | Diff line number | Diff line change
@@ -669,6 +669,14 @@ def __nvvm_redux_sync_umax : NVPTXBuiltinSMAndPTX<"unsigned int(unsigned int, in
669669
def __nvvm_redux_sync_and : NVPTXBuiltinSMAndPTX<"int(int, int)", SM_80, PTX70>;
670670
def __nvvm_redux_sync_xor : NVPTXBuiltinSMAndPTX<"int(int, int)", SM_80, PTX70>;
671671
def __nvvm_redux_sync_or : NVPTXBuiltinSMAndPTX<"int(int, int)", SM_80, PTX70>;
672+
// Clang builtins for the f32 redux.sync min/max reductions. Every variant has
// the prototype float(float, int) and is gated on SM_100a together with PTX86.
// The name suffixes (_abs, _NaN, _abs_NaN) distinguish the four variants that
// exist for each of fmin and fmax.
def __nvvm_redux_sync_fmin : NVPTXBuiltinSMAndPTX<"float(float, int)", SM_100a, PTX86>;
673+
def __nvvm_redux_sync_fmin_abs : NVPTXBuiltinSMAndPTX<"float(float, int)", SM_100a, PTX86>;
674+
def __nvvm_redux_sync_fmin_NaN : NVPTXBuiltinSMAndPTX<"float(float, int)", SM_100a, PTX86>;
675+
def __nvvm_redux_sync_fmin_abs_NaN : NVPTXBuiltinSMAndPTX<"float(float, int)", SM_100a, PTX86>;
676+
def __nvvm_redux_sync_fmax : NVPTXBuiltinSMAndPTX<"float(float, int)", SM_100a, PTX86>;
677+
def __nvvm_redux_sync_fmax_abs : NVPTXBuiltinSMAndPTX<"float(float, int)", SM_100a, PTX86>;
678+
def __nvvm_redux_sync_fmax_NaN : NVPTXBuiltinSMAndPTX<"float(float, int)", SM_100a, PTX86>;
679+
def __nvvm_redux_sync_fmax_abs_NaN : NVPTXBuiltinSMAndPTX<"float(float, int)", SM_100a, PTX86>;
672680

673681
// Membar
674682

Original file line number | Diff line number | Diff line change
@@ -0,0 +1,34 @@
1+
// RUN: %clang_cc1 "-triple" "nvptx-nvidia-cuda" "-target-feature" "+ptx86" "-target-cpu" "sm_100a" -emit-llvm -fcuda-is-device -o - %s | FileCheck %s
2+
// RUN: %clang_cc1 "-triple" "nvptx64-nvidia-cuda" "-target-feature" "+ptx86" "-target-cpu" "sm_100a" -emit-llvm -fcuda-is-device -o - %s | FileCheck %s
3+
4+
// CHECK: define{{.*}} void @_Z6kernelPf(ptr noundef %out_f)
5+
// Device kernel that calls each of the eight redux.sync f32 builtins once and
// stores every result into the next slot of out_f. The FileCheck directive
// following each call pins the NVVM intrinsic that the builtin is expected to
// lower to; the second argument is a different integer constant in several of
// the calls.
__attribute__((global)) void kernel(float* out_f) {
6+
float a = 3.0;
7+
int i = 0;
8+
9+
out_f[i++] = __nvvm_redux_sync_fmin(a, 0xFF);
10+
// CHECK: call contract float @llvm.nvvm.redux.sync.fmin
11+
12+
out_f[i++] = __nvvm_redux_sync_fmin_abs(a, 0xFF);
13+
// CHECK: call contract float @llvm.nvvm.redux.sync.fmin.abs
14+
15+
out_f[i++] = __nvvm_redux_sync_fmin_NaN(a, 0xF0);
16+
// CHECK: call contract float @llvm.nvvm.redux.sync.fmin.NaN
17+
18+
out_f[i++] = __nvvm_redux_sync_fmin_abs_NaN(a, 0x0F);
19+
// CHECK: call contract float @llvm.nvvm.redux.sync.fmin.abs.NaN
20+
21+
out_f[i++] = __nvvm_redux_sync_fmax(a, 0xFF);
22+
// CHECK: call contract float @llvm.nvvm.redux.sync.fmax
23+
24+
out_f[i++] = __nvvm_redux_sync_fmax_abs(a, 0x01);
25+
// CHECK: call contract float @llvm.nvvm.redux.sync.fmax.abs
26+
27+
out_f[i++] = __nvvm_redux_sync_fmax_NaN(a, 0xF1);
28+
// CHECK: call contract float @llvm.nvvm.redux.sync.fmax.NaN
29+
30+
out_f[i++] = __nvvm_redux_sync_fmax_abs_NaN(a, 0x10);
31+
// CHECK: call contract float @llvm.nvvm.redux.sync.fmax.abs.NaN
32+
33+
// CHECK: ret void
34+
}

llvm/include/llvm/IR/IntrinsicsNVVM.td

+12
Original file line numberDiff line numberDiff line change
@@ -4824,6 +4824,18 @@ def int_nvvm_redux_sync_xor : ClangBuiltin<"__nvvm_redux_sync_xor">,
48244824
def int_nvvm_redux_sync_or : ClangBuiltin<"__nvvm_redux_sync_or">,
48254825
Intrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty],
48264826
[IntrConvergent, IntrInaccessibleMemOnly, IntrNoCallback]>;
4827+
4828+
// redux.sync.op.{abs}.{NaN}.f32 dst, src, membermask;
4829+
// Generates eight intrinsic records by crossing {min, max} x {"", _abs} x
// {"", _NaN}; the record names run from int_nvvm_redux_sync_fmin through
// int_nvvm_redux_sync_fmax_abs_NaN. Each record:
//   - is tied to the Clang builtin with the same suffix
//     (__nvvm_redux_sync_f<binOp><abs><NaN>),
//   - returns float and takes a float followed by an i32,
//   - carries IntrConvergent, IntrInaccessibleMemOnly and IntrNoCallback.
foreach binOp = ["min", "max"] in {
4830+
foreach abs = ["", "_abs"] in {
4831+
foreach NaN = ["", "_NaN"] in {
4832+
def int_nvvm_redux_sync_f # binOp # abs # NaN :
4833+
ClangBuiltin<!strconcat("__nvvm_redux_sync_f", binOp, abs, NaN)>,
4834+
Intrinsic<[llvm_float_ty], [llvm_float_ty, llvm_i32_ty],
4835+
[IntrConvergent, IntrInaccessibleMemOnly, IntrNoCallback]>;
4836+
}
4837+
}
4838+
}
48274839

48284840
//
48294841
// WGMMA fence instructions

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

+19
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,25 @@ defm REDUX_SYNC_AND : REDUX_SYNC<"and", "b32", int_nvvm_redux_sync_and>;
328328
defm REDUX_SYNC_XOR : REDUX_SYNC<"xor", "b32", int_nvvm_redux_sync_xor>;
329329
defm REDUX_SYNC_OR : REDUX_SYNC<"or", "b32", int_nvvm_redux_sync_or>;
330330

331+
// Defines the instruction-selection record for one f32 redux.sync variant.
//   BinOp - operation mnemonic spliced into both the asm string and the
//           intrinsic name.
//   abs   - modifier text placed after the operation in the asm string
//           (may be empty).
//   NaN   - modifier text placed after `abs` in the asm string (may be empty).
// The modifier strings are emitted verbatim in the PTX text. For the intrinsic
// lookup, every "." in them is replaced by "_" so the result matches the
// int_nvvm_redux_sync_f* record naming.
multiclass REDUX_SYNC_F<string BinOp, string abs, string NaN> {
332+
// Name of the intrinsic record this instruction is selected from.
defvar intr_name = "int_nvvm_redux_sync_f" # BinOp # !subst(".", "_", abs) # !subst(".", "_", NaN);
333+
334+
// Anonymous instruction: f32 destination, f32 source and i32 mask registers.
def : NVPTXInst<(outs Float32Regs:$dst),
335+
(ins Float32Regs:$src, Int32Regs:$mask),
336+
"redux.sync." # BinOp # abs # NaN # ".f32 $dst, $src, $mask;",
337+
[(set f32:$dst, (!cast<Intrinsic>(intr_name) f32:$src, Int32Regs:$mask))]>,
338+
// Only selectable when both the PTX 8.6 and sm_100a predicates hold.
Requires<[hasPTX<86>, hasSM100a]>;
339+
}
340+
341+
// Instantiate REDUX_SYNC_F for all eight combinations:
// {min, max} x {"", ".abs"} x {"", ".NaN"}.
defm REDUX_SYNC_FMIN : REDUX_SYNC_F<"min", "", "">;
342+
defm REDUX_SYNC_FMIN_ABS : REDUX_SYNC_F<"min", ".abs", "">;
343+
defm REDUX_SYNC_FMIN_NAN: REDUX_SYNC_F<"min", "", ".NaN">;
344+
defm REDUX_SYNC_FMIN_ABS_NAN: REDUX_SYNC_F<"min", ".abs", ".NaN">;
345+
defm REDUX_SYNC_FMAX : REDUX_SYNC_F<"max", "", "">;
346+
defm REDUX_SYNC_FMAX_ABS : REDUX_SYNC_F<"max", ".abs", "">;
347+
defm REDUX_SYNC_FMAX_NAN: REDUX_SYNC_F<"max", "", ".NaN">;
348+
defm REDUX_SYNC_FMAX_ABS_NAN: REDUX_SYNC_F<"max", ".abs", ".NaN">;
349+
331350
} // isConvergent = true
332351

333352
//-----------------------------------
+139
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,139 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 | FileCheck %s
3+
; RUN: %if ptxas-12.8 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 | %ptxas-verify -arch=sm_100a %}
4+
5+
; Plain fmin variant: the intrinsic call is expected to become a single
; redux.sync.min.f32 on the two loaded parameters.
declare float @llvm.nvvm.redux.sync.fmin(float, i32)
6+
define float @redux_sync_fmin(float %src, i32 %mask) {
7+
; CHECK-LABEL: redux_sync_fmin(
8+
; CHECK: {
9+
; CHECK-NEXT: .reg .b32 %r<2>;
10+
; CHECK-NEXT: .reg .f32 %f<3>;
11+
; CHECK-EMPTY:
12+
; CHECK-NEXT: // %bb.0:
13+
; CHECK-NEXT: ld.param.f32 %f1, [redux_sync_fmin_param_0];
14+
; CHECK-NEXT: ld.param.u32 %r1, [redux_sync_fmin_param_1];
15+
; CHECK-NEXT: redux.sync.min.f32 %f2, %f1, %r1;
16+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
17+
; CHECK-NEXT: ret;
18+
%val = call float @llvm.nvvm.redux.sync.fmin(float %src, i32 %mask)
19+
ret float %val
20+
}
21+
22+
; fmin with the abs modifier: the intrinsic call is expected to become a single
; redux.sync.min.abs.f32 on the two loaded parameters.
declare float @llvm.nvvm.redux.sync.fmin.abs(float, i32)
23+
define float @redux_sync_fmin_abs(float %src, i32 %mask) {
24+
; CHECK-LABEL: redux_sync_fmin_abs(
25+
; CHECK: {
26+
; CHECK-NEXT: .reg .b32 %r<2>;
27+
; CHECK-NEXT: .reg .f32 %f<3>;
28+
; CHECK-EMPTY:
29+
; CHECK-NEXT: // %bb.0:
30+
; CHECK-NEXT: ld.param.f32 %f1, [redux_sync_fmin_abs_param_0];
31+
; CHECK-NEXT: ld.param.u32 %r1, [redux_sync_fmin_abs_param_1];
32+
; CHECK-NEXT: redux.sync.min.abs.f32 %f2, %f1, %r1;
33+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
34+
; CHECK-NEXT: ret;
35+
%val = call float @llvm.nvvm.redux.sync.fmin.abs(float %src, i32 %mask)
36+
ret float %val
37+
}
38+
39+
; fmin with the NaN modifier: the intrinsic call is expected to become a single
; redux.sync.min.NaN.f32 on the two loaded parameters.
declare float @llvm.nvvm.redux.sync.fmin.NaN(float, i32)
40+
define float @redux_sync_fmin_NaN(float %src, i32 %mask) {
41+
; CHECK-LABEL: redux_sync_fmin_NaN(
42+
; CHECK: {
43+
; CHECK-NEXT: .reg .b32 %r<2>;
44+
; CHECK-NEXT: .reg .f32 %f<3>;
45+
; CHECK-EMPTY:
46+
; CHECK-NEXT: // %bb.0:
47+
; CHECK-NEXT: ld.param.f32 %f1, [redux_sync_fmin_NaN_param_0];
48+
; CHECK-NEXT: ld.param.u32 %r1, [redux_sync_fmin_NaN_param_1];
49+
; CHECK-NEXT: redux.sync.min.NaN.f32 %f2, %f1, %r1;
50+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
51+
; CHECK-NEXT: ret;
52+
%val = call float @llvm.nvvm.redux.sync.fmin.NaN(float %src, i32 %mask)
53+
ret float %val
54+
}
55+
56+
; fmin with both abs and NaN modifiers: the intrinsic call is expected to become
; a single redux.sync.min.abs.NaN.f32 on the two loaded parameters.
declare float @llvm.nvvm.redux.sync.fmin.abs.NaN(float, i32)
57+
define float @redux_sync_fmin_abs_NaN(float %src, i32 %mask) {
58+
; CHECK-LABEL: redux_sync_fmin_abs_NaN(
59+
; CHECK: {
60+
; CHECK-NEXT: .reg .b32 %r<2>;
61+
; CHECK-NEXT: .reg .f32 %f<3>;
62+
; CHECK-EMPTY:
63+
; CHECK-NEXT: // %bb.0:
64+
; CHECK-NEXT: ld.param.f32 %f1, [redux_sync_fmin_abs_NaN_param_0];
65+
; CHECK-NEXT: ld.param.u32 %r1, [redux_sync_fmin_abs_NaN_param_1];
66+
; CHECK-NEXT: redux.sync.min.abs.NaN.f32 %f2, %f1, %r1;
67+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
68+
; CHECK-NEXT: ret;
69+
%val = call float @llvm.nvvm.redux.sync.fmin.abs.NaN(float %src, i32 %mask)
70+
ret float %val
71+
}
72+
73+
; Plain fmax variant: the intrinsic call is expected to become a single
; redux.sync.max.f32 on the two loaded parameters.
declare float @llvm.nvvm.redux.sync.fmax(float, i32)
74+
define float @redux_sync_fmax(float %src, i32 %mask) {
75+
; CHECK-LABEL: redux_sync_fmax(
76+
; CHECK: {
77+
; CHECK-NEXT: .reg .b32 %r<2>;
78+
; CHECK-NEXT: .reg .f32 %f<3>;
79+
; CHECK-EMPTY:
80+
; CHECK-NEXT: // %bb.0:
81+
; CHECK-NEXT: ld.param.f32 %f1, [redux_sync_fmax_param_0];
82+
; CHECK-NEXT: ld.param.u32 %r1, [redux_sync_fmax_param_1];
83+
; CHECK-NEXT: redux.sync.max.f32 %f2, %f1, %r1;
84+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
85+
; CHECK-NEXT: ret;
86+
%val = call float @llvm.nvvm.redux.sync.fmax(float %src, i32 %mask)
87+
ret float %val
88+
}
89+
90+
; fmax with the abs modifier: the intrinsic call is expected to become a single
; redux.sync.max.abs.f32 on the two loaded parameters.
declare float @llvm.nvvm.redux.sync.fmax.abs(float, i32)
91+
define float @redux_sync_fmax_abs(float %src, i32 %mask) {
92+
; CHECK-LABEL: redux_sync_fmax_abs(
93+
; CHECK: {
94+
; CHECK-NEXT: .reg .b32 %r<2>;
95+
; CHECK-NEXT: .reg .f32 %f<3>;
96+
; CHECK-EMPTY:
97+
; CHECK-NEXT: // %bb.0:
98+
; CHECK-NEXT: ld.param.f32 %f1, [redux_sync_fmax_abs_param_0];
99+
; CHECK-NEXT: ld.param.u32 %r1, [redux_sync_fmax_abs_param_1];
100+
; CHECK-NEXT: redux.sync.max.abs.f32 %f2, %f1, %r1;
101+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
102+
; CHECK-NEXT: ret;
103+
%val = call float @llvm.nvvm.redux.sync.fmax.abs(float %src, i32 %mask)
104+
ret float %val
105+
}
106+
107+
; fmax with the NaN modifier: the intrinsic call is expected to become a single
; redux.sync.max.NaN.f32 on the two loaded parameters.
declare float @llvm.nvvm.redux.sync.fmax.NaN(float, i32)
108+
define float @redux_sync_fmax_NaN(float %src, i32 %mask) {
109+
; CHECK-LABEL: redux_sync_fmax_NaN(
110+
; CHECK: {
111+
; CHECK-NEXT: .reg .b32 %r<2>;
112+
; CHECK-NEXT: .reg .f32 %f<3>;
113+
; CHECK-EMPTY:
114+
; CHECK-NEXT: // %bb.0:
115+
; CHECK-NEXT: ld.param.f32 %f1, [redux_sync_fmax_NaN_param_0];
116+
; CHECK-NEXT: ld.param.u32 %r1, [redux_sync_fmax_NaN_param_1];
117+
; CHECK-NEXT: redux.sync.max.NaN.f32 %f2, %f1, %r1;
118+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
119+
; CHECK-NEXT: ret;
120+
%val = call float @llvm.nvvm.redux.sync.fmax.NaN(float %src, i32 %mask)
121+
ret float %val
122+
}
123+
124+
; fmax with both abs and NaN modifiers: the intrinsic call is expected to become
; a single redux.sync.max.abs.NaN.f32 on the two loaded parameters.
declare float @llvm.nvvm.redux.sync.fmax.abs.NaN(float, i32)
125+
define float @redux_sync_fmax_abs_NaN(float %src, i32 %mask) {
126+
; CHECK-LABEL: redux_sync_fmax_abs_NaN(
127+
; CHECK: {
128+
; CHECK-NEXT: .reg .b32 %r<2>;
129+
; CHECK-NEXT: .reg .f32 %f<3>;
130+
; CHECK-EMPTY:
131+
; CHECK-NEXT: // %bb.0:
132+
; CHECK-NEXT: ld.param.f32 %f1, [redux_sync_fmax_abs_NaN_param_0];
133+
; CHECK-NEXT: ld.param.u32 %r1, [redux_sync_fmax_abs_NaN_param_1];
134+
; CHECK-NEXT: redux.sync.max.abs.NaN.f32 %f2, %f1, %r1;
135+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
136+
; CHECK-NEXT: ret;
137+
%val = call float @llvm.nvvm.redux.sync.fmax.abs.NaN(float %src, i32 %mask)
138+
ret float %val
139+
}

0 commit comments

Comments
 (0)