Skip to content

Commit cdaebf6

Browse files
[NVPTX] Fix crash caused by ComputePTXValueVTs (#104524)
When [lowering return values](https://github.com/llvm/llvm-project/blob/99a10f1fe8a7e4b0fdb4c6dd5e7f24f87e0d3695/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp#L3422) from LLVM IR to SelectionDAG, we check that [the number of values `SelectionDAG` tells us to return is equal to the number of values that `ComputePTXValueVTs()` tells us to return](https://github.com/llvm/llvm-project/blob/99a10f1fe8a7e4b0fdb4c6dd5e7f24f87e0d3695/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp#L3441). However, this check can fail on valid IR. For example:

```
define <6 x half> @foo() {
  ret <6 x half> zeroinitializer
}
```

`ComputePTXValueVTs()` tells us to return ***3*** `v2f16` values, while `SelectionDAG` tells us to return ***6*** `f16` values. Because the two counts differ, the check fails and the compiler crashes.

`ComputePTXValueVTs()` [supports all `half` element vectors with an even number of elements](https://github.com/llvm/llvm-project/blob/99a10f1fe8a7e4b0fdb4c6dd5e7f24f87e0d3695/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp#L213), whereas `SelectionDAG` [only supports power-of-2 sized vectors](https://github.com/llvm/llvm-project/blob/4e078e3797098daa40d254447c499bcf61415308/llvm/lib/CodeGen/TargetLoweringBase.cpp#L1580). This is the root of the discrepancy.

Assuming that the developers who added the code to `ComputePTXValueVTs()` overlooked this, I've restricted `ComputePTXValueVTs()` to compute the same number of return values as `SelectionDAG`, instead of extending `SelectionDAG` to support non-power-of-2 sized vectors.
1 parent 46fe36a commit cdaebf6

File tree

3 files changed

+592
-5
lines changed

3 files changed

+592
-5
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -207,10 +207,15 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
207207
if (VT.isVector()) {
208208
unsigned NumElts = VT.getVectorNumElements();
209209
EVT EltVT = VT.getVectorElementType();
210-
// Vectors with an even number of f16 elements will be passed to
211-
// us as an array of v2f16/v2bf16 elements. We must match this so we
212-
// stay in sync with Ins/Outs.
213-
if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0) {
210+
// We require power-of-2 sized vectors because
211+
// TargetLoweringBase::getVectorTypeBreakdown() which is invoked in
212+
// ComputePTXValueVTs() cannot currently break down non-power-of-2 sized
213+
// vectors.
214+
if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0 &&
215+
isPowerOf2_32(NumElts)) {
216+
// Vectors with an even number of f16 elements will be passed to
217+
// us as an array of v2f16/v2bf16 elements. We must match this so we
218+
// stay in sync with Ins/Outs.
214219
switch (EltVT.getSimpleVT().SimpleTy) {
215220
case MVT::f16:
216221
EltVT = MVT::v2f16;
@@ -226,7 +231,8 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
226231
}
227232
NumElts /= 2;
228233
} else if (EltVT.getSimpleVT() == MVT::i8 &&
229-
(NumElts % 4 == 0 || NumElts == 3)) {
234+
((NumElts % 4 == 0 && isPowerOf2_32(NumElts)) ||
235+
NumElts == 3)) {
230236
// v*i8 are formally lowered as v4i8
231237
EltVT = MVT::v4i8;
232238
NumElts = (NumElts + 3) / 4;
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 | FileCheck %s
3+
4+
target triple = "nvptx-nvidia-cuda"
5+
6+
define <6 x half> @half6() {
7+
; CHECK-LABEL: half6(
8+
; CHECK: {
9+
; CHECK-NEXT: .reg .b16 %rs<2>;
10+
; CHECK-EMPTY:
11+
; CHECK-NEXT: // %bb.0:
12+
; CHECK-NEXT: mov.b16 %rs1, 0x0000;
13+
; CHECK-NEXT: st.param.v4.b16 [func_retval0+0], {%rs1, %rs1, %rs1, %rs1};
14+
; CHECK-NEXT: st.param.v2.b16 [func_retval0+8], {%rs1, %rs1};
15+
; CHECK-NEXT: ret;
16+
ret <6 x half> zeroinitializer
17+
}
18+
19+
define <10 x half> @half10() {
20+
; CHECK-LABEL: half10(
21+
; CHECK: {
22+
; CHECK-NEXT: .reg .b16 %rs<2>;
23+
; CHECK-EMPTY:
24+
; CHECK-NEXT: // %bb.0:
25+
; CHECK-NEXT: mov.b16 %rs1, 0x0000;
26+
; CHECK-NEXT: st.param.v4.b16 [func_retval0+0], {%rs1, %rs1, %rs1, %rs1};
27+
; CHECK-NEXT: st.param.v4.b16 [func_retval0+8], {%rs1, %rs1, %rs1, %rs1};
28+
; CHECK-NEXT: st.param.v2.b16 [func_retval0+16], {%rs1, %rs1};
29+
; CHECK-NEXT: ret;
30+
ret <10 x half> zeroinitializer
31+
}
32+
33+
define <12 x i8> @byte12() {
34+
; CHECK-LABEL: byte12(
35+
; CHECK: {
36+
; CHECK-NEXT: .reg .b16 %rs<2>;
37+
; CHECK-EMPTY:
38+
; CHECK-NEXT: // %bb.0:
39+
; CHECK-NEXT: mov.u16 %rs1, 0;
40+
; CHECK-NEXT: st.param.v4.b8 [func_retval0+0], {%rs1, %rs1, %rs1, %rs1};
41+
; CHECK-NEXT: st.param.v4.b8 [func_retval0+4], {%rs1, %rs1, %rs1, %rs1};
42+
; CHECK-NEXT: st.param.v4.b8 [func_retval0+8], {%rs1, %rs1, %rs1, %rs1};
43+
; CHECK-NEXT: ret;
44+
ret <12 x i8> zeroinitializer
45+
}
46+
47+
define <20 x i8> @byte20() {
48+
; CHECK-LABEL: byte20(
49+
; CHECK: {
50+
; CHECK-NEXT: .reg .b16 %rs<2>;
51+
; CHECK-EMPTY:
52+
; CHECK-NEXT: // %bb.0:
53+
; CHECK-NEXT: mov.u16 %rs1, 0;
54+
; CHECK-NEXT: st.param.v4.b8 [func_retval0+0], {%rs1, %rs1, %rs1, %rs1};
55+
; CHECK-NEXT: st.param.v4.b8 [func_retval0+4], {%rs1, %rs1, %rs1, %rs1};
56+
; CHECK-NEXT: st.param.v4.b8 [func_retval0+8], {%rs1, %rs1, %rs1, %rs1};
57+
; CHECK-NEXT: st.param.v4.b8 [func_retval0+12], {%rs1, %rs1, %rs1, %rs1};
58+
; CHECK-NEXT: st.param.v4.b8 [func_retval0+16], {%rs1, %rs1, %rs1, %rs1};
59+
; CHECK-NEXT: ret;
60+
ret <20 x i8> zeroinitializer
61+
}

0 commit comments

Comments
 (0)