Commit 196ee23

[Clang] Correctly enable the f16 type for offloading (llvm#98331)
Summary: There's an extra argument that's required to *actually* enable f16 usage. For whatever reason there's a difference between fp16 and f16, where fp16 is some weird version that converts between the two. Long story short, without this the math builtins are blatantly broken.
1 parent a972b2e commit 196ee23
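
As a rough, hedged illustration of what that extra hook affects (a minimal sketch, not part of the commit; the helper name and compile line are illustrative), consider a trivial device function doing __fp16 arithmetic. Per the hook's description in clang/include/clang/Basic/TargetInfo.h, useFP16ConversionIntrinsics() decides whether __fp16 values are shuttled through llvm.convert.to.fp16/llvm.convert.from.fp16 or kept as native half values in the IR, which is what the NVVM math builtins in the test below expect.

// Illustrative sketch, not from the commit. Compile roughly like the RUN
// lines in the test below, e.g. -triple nvptx-unknown-unknown
// -fcuda-is-device -x cuda -emit-llvm.
#define __device__ __attribute__((device))

__device__ void half_add(__fp16 *a, __fp16 *b, __fp16 *out) {
  // With the default useFP16ConversionIntrinsics() == true, the expectation
  // is that these __fp16 values go through the fp16 conversion intrinsics;
  // with the NVPTX override returning false, they stay as IR 'half' values.
  *out = *a + *b;
}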

File tree

3 files changed: +119, -119 lines changed


clang/lib/Basic/Targets/NVPTX.h (+2)

@@ -75,6 +75,8 @@ class LLVM_LIBRARY_VISIBILITY NVPTXTargetInfo : public TargetInfo {
 
   ArrayRef<Builtin::Info> getTargetBuiltins() const override;
 
+  bool useFP16ConversionIntrinsics() const override { return false; }
+
   bool
   initFeatureMap(llvm::StringMap<bool> &Features, DiagnosticsEngine &Diags,
                  StringRef CPU,
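
For context, a hedged C++ sketch of the pattern in play (toy stand-ins, not the real Clang classes): the base TargetInfo hook defaults to returning true, i.e. to the conversion-intrinsic path, and a target opts into native f16 by overriding it to return false, exactly as the hunk above does.

#include <cstdio>

// Toy stand-ins for clang's TargetInfo hierarchy, for illustration only.
struct ToyTargetInfo {
  // In clang, TargetInfo::useFP16ConversionIntrinsics() defaults to true,
  // meaning __fp16 is handled via the fp16 conversion intrinsics.
  virtual bool useFP16ConversionIntrinsics() const { return true; }
  virtual ~ToyTargetInfo() = default;
};

struct ToyNVPTXTargetInfo : ToyTargetInfo {
  // Mirrors the override added by this commit: treat f16 as a native type.
  bool useFP16ConversionIntrinsics() const override { return false; }
};

int main() {
  ToyNVPTXTargetInfo NVPTX;
  std::printf("use conversion intrinsics: %d\n",
              NVPTX.useFP16ConversionIntrinsics()); // prints 0
}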

clang/test/CodeGen/builtins-nvptx-native-half-type-err.c (-119)

This file was deleted.

@@ -0,0 +1,117 @@
// REQUIRES: nvptx-registered-target
//
// RUN: %clang_cc1 -ffp-contract=off -triple nvptx-unknown-unknown -target-cpu \
// RUN: sm_86 -target-feature +ptx72 -fcuda-is-device -x cuda -emit-llvm -o - %s \
// RUN: | FileCheck %s

#define __device__ __attribute__((device))
typedef __fp16 __fp16v2 __attribute__((ext_vector_type(2)));

// CHECK: call half @llvm.nvvm.ex2.approx.f16(half {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.ex2.approx.f16x2(<2 x half> {{.*}})
// CHECK: call half @llvm.nvvm.fma.rn.relu.f16(half {{.*}}, half {{.*}}, half {{.*}})
// CHECK: call half @llvm.nvvm.fma.rn.ftz.relu.f16(half {{.*}}, half {{.*}}, half {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.fma.rn.relu.f16x2(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.fma.rn.ftz.relu.f16x2(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}})
// CHECK: call half @llvm.nvvm.fma.rn.ftz.f16(half {{.*}}, half {{.*}}, half {{.*}})
// CHECK: call half @llvm.nvvm.fma.rn.sat.f16(half {{.*}}, half {{.*}}, half {{.*}})
// CHECK: call half @llvm.nvvm.fma.rn.ftz.sat.f16(half {{.*}}, half {{.*}}, half {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.fma.rn.f16x2(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.fma.rn.ftz.f16x2(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.fma.rn.sat.f16x2(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.fma.rn.ftz.sat.f16x2(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}})
// CHECK: call half @llvm.nvvm.fmin.f16(half {{.*}}, half {{.*}})
// CHECK: call half @llvm.nvvm.fmin.ftz.f16(half {{.*}}, half {{.*}})
// CHECK: call half @llvm.nvvm.fmin.nan.f16(half {{.*}}, half {{.*}})
// CHECK: call half @llvm.nvvm.fmin.ftz.nan.f16(half {{.*}}, half {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.fmin.f16x2(<2 x half> {{.*}}, <2 x half> {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.fmin.ftz.f16x2(<2 x half> {{.*}}, <2 x half> {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.fmin.nan.f16x2(<2 x half> {{.*}}, <2 x half> {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.fmin.ftz.nan.f16x2(<2 x half> {{.*}}, <2 x half> {{.*}})
// CHECK: call half @llvm.nvvm.fmin.xorsign.abs.f16(half {{.*}}, half {{.*}})
// CHECK: call half @llvm.nvvm.fmin.ftz.xorsign.abs.f16(half {{.*}}, half {{.*}})
// CHECK: call half @llvm.nvvm.fmin.nan.xorsign.abs.f16(half {{.*}}, half {{.*}})
// CHECK: call half @llvm.nvvm.fmin.ftz.nan.xorsign.abs.f16(half {{.*}}, half {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.fmin.xorsign.abs.f16x2(<2 x half> {{.*}}, <2 x half> {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.fmin.ftz.xorsign.abs.f16x2(<2 x half> {{.*}}, <2 x half> {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.fmin.nan.xorsign.abs.f16x2(<2 x half> {{.*}}, <2 x half> {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.fmin.ftz.nan.xorsign.abs.f16x2(<2 x half> {{.*}}, <2 x half> {{.*}})
// CHECK: call half @llvm.nvvm.fmax.f16(half {{.*}}, half {{.*}})
// CHECK: call half @llvm.nvvm.fmax.ftz.f16(half {{.*}}, half {{.*}})
// CHECK: call half @llvm.nvvm.fmax.nan.f16(half {{.*}}, half {{.*}})
// CHECK: call half @llvm.nvvm.fmax.ftz.nan.f16(half {{.*}}, half {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.fmax.f16x2(<2 x half> {{.*}}, <2 x half> {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.fmax.ftz.f16x2(<2 x half> {{.*}}, <2 x half> {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.fmax.nan.f16x2(<2 x half> {{.*}}, <2 x half> {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.fmax.ftz.nan.f16x2(<2 x half> {{.*}}, <2 x half> {{.*}})
// CHECK: call half @llvm.nvvm.fmax.xorsign.abs.f16(half {{.*}}, half {{.*}})
// CHECK: call half @llvm.nvvm.fmax.ftz.xorsign.abs.f16(half {{.*}}, half {{.*}})
// CHECK: call half @llvm.nvvm.fmax.nan.xorsign.abs.f16(half {{.*}}, half {{.*}})
// CHECK: call half @llvm.nvvm.fmax.ftz.nan.xorsign.abs.f16(half {{.*}}, half {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.fmax.xorsign.abs.f16x2(<2 x half> {{.*}}, <2 x half> {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.fmax.ftz.xorsign.abs.f16x2(<2 x half> {{.*}}, <2 x half> {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.fmax.nan.xorsign.abs.f16x2(<2 x half> {{.*}}, <2 x half> {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.fmax.ftz.nan.xorsign.abs.f16x2(<2 x half> {{.*}}, <2 x half> {{.*}})
// CHECK: call half @llvm.nvvm.ldg.global.f.f16.p0(ptr {{.*}}, i32 2)
// CHECK: call <2 x half> @llvm.nvvm.ldg.global.f.v2f16.p0(ptr {{.*}}, i32 4)
// CHECK: call half @llvm.nvvm.ldu.global.f.f16.p0(ptr {{.*}}, i32 2)
// CHECK: call <2 x half> @llvm.nvvm.ldu.global.f.v2f16.p0(ptr {{.*}}, i32 4)
__device__ void nvvm_native_half_types(void *a, void *b, void *c, __fp16 *out) {
  __fp16v2 resv2 = {0, 0};
  *out += __nvvm_ex2_approx_f16(*(__fp16 *)a);
  resv2 = __nvvm_ex2_approx_f16x2(*(__fp16v2 *)a);

  *out += __nvvm_fma_rn_relu_f16(*(__fp16 *)a, *(__fp16 *)b, *(__fp16 *)c);
  *out += __nvvm_fma_rn_ftz_relu_f16(*(__fp16 *)a, *(__fp16 *)b, *(__fp16 *)c);
  resv2 += __nvvm_fma_rn_relu_f16x2(*(__fp16v2 *)a, *(__fp16v2 *)b, *(__fp16v2 *)c);
  resv2 += __nvvm_fma_rn_ftz_relu_f16x2(*(__fp16v2 *)a, *(__fp16v2 *)b, *(__fp16v2 *)c);
  *out += __nvvm_fma_rn_ftz_f16(*(__fp16 *)a, *(__fp16 *)b, *(__fp16 *)c);
  *out += __nvvm_fma_rn_sat_f16(*(__fp16 *)a, *(__fp16 *)b, *(__fp16 *)c);
  *out += __nvvm_fma_rn_ftz_sat_f16(*(__fp16 *)a, *(__fp16 *)b, *(__fp16 *)c);
  resv2 += __nvvm_fma_rn_f16x2(*(__fp16v2 *)a, *(__fp16v2 *)b, *(__fp16v2 *)c);
  resv2 += __nvvm_fma_rn_ftz_f16x2(*(__fp16v2 *)a, *(__fp16v2 *)b, *(__fp16v2 *)c);
  resv2 += __nvvm_fma_rn_sat_f16x2(*(__fp16v2 *)a, *(__fp16v2 *)b, *(__fp16v2 *)c);
  resv2 += __nvvm_fma_rn_ftz_sat_f16x2(*(__fp16v2 *)a, *(__fp16v2 *)b, *(__fp16v2 *)c);

  *out += __nvvm_fmin_f16(*(__fp16 *)a, *(__fp16 *)b);
  *out += __nvvm_fmin_ftz_f16(*(__fp16 *)a, *(__fp16 *)b);
  *out += __nvvm_fmin_nan_f16(*(__fp16 *)a, *(__fp16 *)b);
  *out += __nvvm_fmin_ftz_nan_f16(*(__fp16 *)a, *(__fp16 *)b);
  resv2 += __nvvm_fmin_f16x2(*(__fp16v2 *)a, *(__fp16v2 *)b);
  resv2 += __nvvm_fmin_ftz_f16x2(*(__fp16v2 *)a, *(__fp16v2 *)b);
  resv2 += __nvvm_fmin_nan_f16x2(*(__fp16v2 *)a, *(__fp16v2 *)b);
  resv2 += __nvvm_fmin_ftz_nan_f16x2(*(__fp16v2 *)a, *(__fp16v2 *)b);
  *out += __nvvm_fmin_xorsign_abs_f16(*(__fp16 *)a, *(__fp16 *)b);
  *out += __nvvm_fmin_ftz_xorsign_abs_f16(*(__fp16 *)a, *(__fp16 *)b);
  *out += __nvvm_fmin_nan_xorsign_abs_f16(*(__fp16 *)a, *(__fp16 *)b);
  *out += __nvvm_fmin_ftz_nan_xorsign_abs_f16(*(__fp16 *)a, *(__fp16 *)b);
  resv2 += __nvvm_fmin_xorsign_abs_f16x2(*(__fp16v2 *)a, *(__fp16v2 *)b);
  resv2 += __nvvm_fmin_ftz_xorsign_abs_f16x2(*(__fp16v2 *)a, *(__fp16v2 *)b);
  resv2 += __nvvm_fmin_nan_xorsign_abs_f16x2(*(__fp16v2 *)a, *(__fp16v2 *)b);
  resv2 += __nvvm_fmin_ftz_nan_xorsign_abs_f16x2(*(__fp16v2 *)a, *(__fp16v2 *)b);

  *out += __nvvm_fmax_f16(*(__fp16 *)a, *(__fp16 *)b);
  *out += __nvvm_fmax_ftz_f16(*(__fp16 *)a, *(__fp16 *)b);
  *out += __nvvm_fmax_nan_f16(*(__fp16 *)a, *(__fp16 *)b);
  *out += __nvvm_fmax_ftz_nan_f16(*(__fp16 *)a, *(__fp16 *)b);
  resv2 += __nvvm_fmax_f16x2(*(__fp16v2 *)a, *(__fp16v2 *)b);
  resv2 += __nvvm_fmax_ftz_f16x2(*(__fp16v2 *)a, *(__fp16v2 *)b);
  resv2 += __nvvm_fmax_nan_f16x2(*(__fp16v2 *)a, *(__fp16v2 *)b);
  resv2 += __nvvm_fmax_ftz_nan_f16x2(*(__fp16v2 *)a, *(__fp16v2 *)b);
  *out += __nvvm_fmax_xorsign_abs_f16(*(__fp16 *)a, *(__fp16 *)b);
  *out += __nvvm_fmax_ftz_xorsign_abs_f16(*(__fp16 *)a, *(__fp16 *)b);
  *out += __nvvm_fmax_nan_xorsign_abs_f16(*(__fp16 *)a, *(__fp16 *)b);
  *out += __nvvm_fmax_ftz_nan_xorsign_abs_f16(*(__fp16 *)a, *(__fp16 *)b);
  resv2 += __nvvm_fmax_xorsign_abs_f16x2(*(__fp16v2 *)a, *(__fp16v2 *)b);
  resv2 += __nvvm_fmax_ftz_xorsign_abs_f16x2(*(__fp16v2 *)a, *(__fp16v2 *)b);
  resv2 += __nvvm_fmax_nan_xorsign_abs_f16x2(*(__fp16v2 *)a, *(__fp16v2 *)b);
  resv2 += __nvvm_fmax_ftz_nan_xorsign_abs_f16x2(*(__fp16v2 *)a, *(__fp16v2 *)b);

  *out += __nvvm_ldg_h((__fp16 *)a);
  resv2 += __nvvm_ldg_h2((__fp16v2 *)a);

  *out += __nvvm_ldu_h((__fp16 *)a);
  resv2 += __nvvm_ldu_h2((__fp16v2 *)a);

  *out += resv2[0] + resv2[1];
}
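
Distilled from the test above, a hedged usage sketch (the helper name is illustrative, not part of the commit): once the native f16 type is enabled, device code built like the RUN line (sm_86, +ptx72) can hand __fp16 values straight to the NVVM builtins.

#define __device__ __attribute__((device))

// Illustrative helper, not from the commit: expected to lower to a call of
// @llvm.nvvm.fmin.f16, as the CHECK lines above verify for the test body.
__device__ void smaller_half(__fp16 *a, __fp16 *b, __fp16 *out) {
  *out = __nvvm_fmin_f16(*a, *b);
}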

0 commit comments
