Skip to content

Commit da17ced

Browse files
authored
[DirectX] Use scalar arguments for @llvm.dx.dot intrinsics (llvm#134570)
The `dx.dot2`, `dot3`, and `dot4` intrinsics exist purely to lower `dx.fdot`, and they map exactly to the DXIL ops of the same name. Using vectors for their arguments adds unnecessary complexity and causes us to have vector operations that are not trivial to lower post-scalarizer. Similarly, the `dx.dot2add` intrinsic is overly generic for something that only needs to lower to a single `dot2AddHalf` DXIL op. Update its signature to match the operation it lowers to. Fixes llvm#134569.
1 parent 3e64485 commit da17ced

File tree

13 files changed

+170
-151
lines changed

13 files changed

+170
-151
lines changed

clang/lib/CodeGen/TargetBuiltins/DirectX.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,17 @@ Value *CodeGenFunction::EmitDirectXBuiltinExpr(unsigned BuiltinID,
2525
case DirectX::BI__builtin_dx_dot2add: {
2626
Value *A = EmitScalarExpr(E->getArg(0));
2727
Value *B = EmitScalarExpr(E->getArg(1));
28-
Value *C = EmitScalarExpr(E->getArg(2));
28+
Value *Acc = EmitScalarExpr(E->getArg(2));
29+
30+
Value *AX = Builder.CreateExtractElement(A, Builder.getSize(0));
31+
Value *AY = Builder.CreateExtractElement(A, Builder.getSize(1));
32+
Value *BX = Builder.CreateExtractElement(B, Builder.getSize(0));
33+
Value *BY = Builder.CreateExtractElement(B, Builder.getSize(1));
2934

3035
Intrinsic::ID ID = llvm ::Intrinsic::dx_dot2add;
3136
return Builder.CreateIntrinsic(
32-
/*ReturnType=*/C->getType(), ID, ArrayRef<Value *>{A, B, C}, nullptr,
33-
"dx.dot2add");
37+
/*ReturnType=*/Acc->getType(), ID,
38+
ArrayRef<Value *>{Acc, AX, AY, BX, BY}, nullptr, "dx.dot2add");
3439
}
3540
}
3641
return nullptr;

clang/test/CodeGenDirectX/Builtins/dot2add.c

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,13 @@ typedef half half2 __attribute__((ext_vector_type(2)));
1717
// CHECK-NEXT: [[TMP0:%.*]] = load <2 x half>, ptr [[X_ADDR]], align 4
1818
// CHECK-NEXT: [[TMP1:%.*]] = load <2 x half>, ptr [[Y_ADDR]], align 4
1919
// CHECK-NEXT: [[TMP2:%.*]] = load float, ptr [[Z_ADDR]], align 4
20-
// CHECK-NEXT: [[DX_DOT2ADD:%.*]] = call float @llvm.dx.dot2add.v2f16(<2 x half> [[TMP0]], <2 x half> [[TMP1]], float [[TMP2]])
20+
// CHECK-NEXT: [[TMP3:%.*]] = extractelement <2 x half> [[TMP0]], i32 0
21+
// CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x half> [[TMP0]], i32 1
22+
// CHECK-NEXT: [[TMP5:%.*]] = extractelement <2 x half> [[TMP1]], i32 0
23+
// CHECK-NEXT: [[TMP6:%.*]] = extractelement <2 x half> [[TMP1]], i32 1
24+
// CHECK-NEXT: [[DX_DOT2ADD:%.*]] = call float @llvm.dx.dot2add(float [[TMP2]], half [[TMP3]], half [[TMP4]], half [[TMP5]], half [[TMP6]])
2125
// CHECK-NEXT: ret float [[DX_DOT2ADD]]
2226
//
23-
float test_dot2add(half2 X, half2 Y, float Z) { return __builtin_dx_dot2add(X, Y, Z); }
27+
float test_dot2add(half2 X, half2 Y, float Z) {
28+
return __builtin_dx_dot2add(X, Y, Z);
29+
}

clang/test/CodeGenHLSL/builtins/dot2add.hlsl

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@ float test_default_parameter_type(half2 p1, half2 p2, float p3) {
1313
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
1414
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
1515
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
16-
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
16+
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
17+
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
18+
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
19+
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
20+
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
1721
// CHECK: ret float %[[RES]]
1822
return dot2add(p1, p2, p3);
1923
}
@@ -25,7 +29,11 @@ float test_float_arg2_type(half2 p1, float2 p2, float p3) {
2529
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
2630
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
2731
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
28-
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
32+
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
33+
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
34+
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
35+
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
36+
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
2937
// CHECK: ret float %[[RES]]
3038
return dot2add(p1, p2, p3);
3139
}
@@ -37,7 +45,11 @@ float test_float_arg1_type(float2 p1, half2 p2, float p3) {
3745
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
3846
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
3947
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
40-
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
48+
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
49+
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
50+
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
51+
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
52+
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
4153
// CHECK: ret float %[[RES]]
4254
return dot2add(p1, p2, p3);
4355
}
@@ -49,7 +61,11 @@ float test_double_arg3_type(half2 p1, half2 p2, double p3) {
4961
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
5062
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
5163
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
52-
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
64+
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
65+
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
66+
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
67+
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
68+
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
5369
// CHECK: ret float %[[RES]]
5470
return dot2add(p1, p2, p3);
5571
}
@@ -62,7 +78,11 @@ float test_float_arg1_arg2_type(float2 p1, float2 p2, float p3) {
6278
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
6379
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
6480
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
65-
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
81+
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
82+
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
83+
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
84+
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
85+
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
6686
// CHECK: ret float %[[RES]]
6787
return dot2add(p1, p2, p3);
6888
}
@@ -75,7 +95,11 @@ float test_double_arg1_arg2_type(double2 p1, double2 p2, float p3) {
7595
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
7696
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
7797
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
78-
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
98+
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
99+
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
100+
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
101+
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
102+
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
79103
// CHECK: ret float %[[RES]]
80104
return dot2add(p1, p2, p3);
81105
}
@@ -88,7 +112,11 @@ float test_int16_arg1_arg2_type(int16_t2 p1, int16_t2 p2, float p3) {
88112
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
89113
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
90114
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
91-
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
115+
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
116+
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
117+
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
118+
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
119+
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
92120
// CHECK: ret float %[[RES]]
93121
return dot2add(p1, p2, p3);
94122
}
@@ -101,7 +129,11 @@ float test_int32_arg1_arg2_type(int32_t2 p1, int32_t2 p2, float p3) {
101129
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
102130
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
103131
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
104-
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
132+
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
133+
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
134+
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
135+
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
136+
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
105137
// CHECK: ret float %[[RES]]
106138
return dot2add(p1, p2, p3);
107139
}
@@ -114,7 +146,11 @@ float test_int64_arg1_arg2_type(int64_t2 p1, int64_t2 p2, float p3) {
114146
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
115147
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
116148
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
117-
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
149+
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
150+
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
151+
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
152+
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
153+
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
118154
// CHECK: ret float %[[RES]]
119155
return dot2add(p1, p2, p3);
120156
}
@@ -129,7 +165,11 @@ float test_bool_arg1_arg2_type(bool2 p1, bool2 p2, float p3) {
129165
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
130166
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
131167
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
132-
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
168+
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
169+
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
170+
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
171+
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
172+
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
133173
// CHECK: ret float %[[RES]]
134174
return dot2add(p1, p2, p3);
135175
}

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -76,18 +76,27 @@ def int_dx_nclamp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>,
7676
def int_dx_cross : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
7777
def int_dx_saturate : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
7878

79-
def int_dx_dot2 :
80-
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
81-
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
82-
[IntrNoMem, Commutative] >;
83-
def int_dx_dot3 :
84-
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
85-
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
86-
[IntrNoMem, Commutative] >;
87-
def int_dx_dot4 :
88-
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
89-
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
90-
[IntrNoMem, Commutative] >;
79+
def int_dx_dot2 : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
80+
[
81+
llvm_anyfloat_ty, LLVMMatchType<0>,
82+
LLVMMatchType<0>, LLVMMatchType<0>
83+
],
84+
[IntrNoMem, Commutative]>;
85+
def int_dx_dot3 : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
86+
[
87+
llvm_anyfloat_ty, LLVMMatchType<0>,
88+
LLVMMatchType<0>, LLVMMatchType<0>,
89+
LLVMMatchType<0>, LLVMMatchType<0>
90+
],
91+
[IntrNoMem, Commutative]>;
92+
def int_dx_dot4 : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
93+
[
94+
llvm_anyfloat_ty, LLVMMatchType<0>,
95+
LLVMMatchType<0>, LLVMMatchType<0>,
96+
LLVMMatchType<0>, LLVMMatchType<0>,
97+
LLVMMatchType<0>, LLVMMatchType<0>
98+
],
99+
[IntrNoMem, Commutative]>;
91100
def int_dx_fdot :
92101
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
93102
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
@@ -100,9 +109,9 @@ def int_dx_udot :
100109
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
101110
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
102111
[IntrNoMem, Commutative] >;
103-
def int_dx_dot2add :
104-
DefaultAttrsIntrinsic<[llvm_float_ty],
105-
[llvm_anyfloat_ty, LLVMMatchType<0>, llvm_float_ty],
112+
def int_dx_dot2add :
113+
DefaultAttrsIntrinsic<[llvm_float_ty],
114+
[llvm_float_ty, llvm_half_ty, llvm_half_ty, llvm_half_ty, llvm_half_ty],
106115
[IntrNoMem, Commutative]>;
107116
def int_dx_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
108117
def int_dx_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,8 +1078,7 @@ def RawBufferStore : DXILOp<140, rawBufferStore> {
10781078
}
10791079

10801080
def Dot2AddHalf : DXILOp<162, dot2AddHalf> {
1081-
let Doc = "dot product of 2 vectors of half having size = 2, returns "
1082-
"float";
1081+
let Doc = "2D half dot product with accumulate to float";
10831082
let intrinsics = [IntrinSelect<int_dx_dot2add>];
10841083
let arguments = [FloatTy, HalfTy, HalfTy, HalfTy, HalfTy];
10851084
let result = FloatTy;

llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,8 @@ static Value *expandFloatDotIntrinsic(CallInst *Orig, Value *A, Value *B) {
185185
assert(ATy->getScalarType()->isFloatingPointTy());
186186

187187
Intrinsic::ID DotIntrinsic = Intrinsic::dx_dot4;
188-
switch (AVec->getNumElements()) {
188+
int NumElts = AVec->getNumElements();
189+
switch (NumElts) {
189190
case 2:
190191
DotIntrinsic = Intrinsic::dx_dot2;
191192
break;
@@ -201,8 +202,14 @@ static Value *expandFloatDotIntrinsic(CallInst *Orig, Value *A, Value *B) {
201202
/* gen_crash_diag=*/false);
202203
return nullptr;
203204
}
204-
return Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic,
205-
ArrayRef<Value *>{A, B}, nullptr, "dot");
205+
206+
SmallVector<Value *> Args;
207+
for (int I = 0; I < NumElts; ++I)
208+
Args.push_back(Builder.CreateExtractElement(A, Builder.getInt32(I)));
209+
for (int I = 0; I < NumElts; ++I)
210+
Args.push_back(Builder.CreateExtractElement(B, Builder.getInt32(I)));
211+
return Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic, Args,
212+
nullptr, "dot");
206213
}
207214

208215
// Create the appropriate DXIL float dot intrinsic for the operands of Orig

llvm/lib/Target/DirectX/DXILOpLowering.cpp

Lines changed: 0 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -33,52 +33,6 @@
3333
using namespace llvm;
3434
using namespace llvm::dxil;
3535

36-
static bool isVectorArgExpansion(Function &F) {
37-
switch (F.getIntrinsicID()) {
38-
case Intrinsic::dx_dot2:
39-
case Intrinsic::dx_dot3:
40-
case Intrinsic::dx_dot4:
41-
return true;
42-
}
43-
return false;
44-
}
45-
46-
static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) {
47-
SmallVector<Value *> ExtractedElements;
48-
auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
49-
for (unsigned I = 0; I < VecArg->getNumElements(); ++I) {
50-
Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I);
51-
Value *ExtractedElement = Builder.CreateExtractElement(Arg, Index);
52-
ExtractedElements.push_back(ExtractedElement);
53-
}
54-
return ExtractedElements;
55-
}
56-
57-
static SmallVector<Value *>
58-
argVectorFlatten(CallInst *Orig, IRBuilder<> &Builder, unsigned NumOperands) {
59-
assert(NumOperands > 0);
60-
Value *Arg0 = Orig->getOperand(0);
61-
[[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType());
62-
assert(VecArg0);
63-
SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder);
64-
for (unsigned I = 1; I < NumOperands; ++I) {
65-
Value *Arg = Orig->getOperand(I);
66-
[[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
67-
assert(VecArg);
68-
assert(VecArg0->getElementType() == VecArg->getElementType());
69-
assert(VecArg0->getNumElements() == VecArg->getNumElements());
70-
auto NextOperandList = populateOperands(Arg, Builder);
71-
NewOperands.append(NextOperandList.begin(), NextOperandList.end());
72-
}
73-
return NewOperands;
74-
}
75-
76-
static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
77-
IRBuilder<> &Builder) {
78-
// Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
79-
return argVectorFlatten(Orig, Builder, Orig->getNumOperands() - 1);
80-
}
81-
8236
namespace {
8337
class OpLowerer {
8438
Module &M;
@@ -150,9 +104,6 @@ class OpLowerer {
150104
[[nodiscard]] bool
151105
replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp,
152106
ArrayRef<IntrinArgSelect> ArgSelects) {
153-
bool IsVectorArgExpansion = isVectorArgExpansion(F);
154-
assert(!(IsVectorArgExpansion && ArgSelects.size()) &&
155-
"Cann't do vector arg expansion when using arg selects.");
156107
return replaceFunction(F, [&](CallInst *CI) -> Error {
157108
OpBuilder.getIRB().SetInsertPoint(CI);
158109
SmallVector<Value *> Args;
@@ -170,15 +121,6 @@ class OpLowerer {
170121
break;
171122
}
172123
}
173-
} else if (IsVectorArgExpansion) {
174-
Args = argVectorFlatten(CI, OpBuilder.getIRB());
175-
} else if (F.getIntrinsicID() == Intrinsic::dx_dot2add) {
176-
// arg[NumOperands-1] is a pointer and is not needed by our flattening.
177-
// arg[NumOperands-2] also does not need to be flattened because it is a
178-
// scalar.
179-
unsigned NumOperands = CI->getNumOperands() - 2;
180-
Args.push_back(CI->getArgOperand(NumOperands));
181-
Args.append(argVectorFlatten(CI, OpBuilder.getIRB(), NumOperands));
182124
} else {
183125
Args.append(CI->arg_begin(), CI->arg_end());
184126
}

llvm/test/CodeGen/DirectX/dot2_error.ll

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
; CHECK: in function dot_double2
55
; CHECK-SAME: Cannot create Dot2 operation: Invalid overload type
66

7-
define noundef double @dot_double2(<2 x double> noundef %a, <2 x double> noundef %b) {
7+
define noundef double @dot_double2(double noundef %a1, double noundef %a2,
8+
double noundef %b1, double noundef %b2) {
89
entry:
9-
%dx.dot = call double @llvm.dx.dot2.v2f64(<2 x double> %a, <2 x double> %b)
10+
%dx.dot = call double @llvm.dx.dot2(double %a1, double %a2, double %b1, double %b2)
1011
ret double %dx.dot
1112
}

0 commit comments

Comments
 (0)