Skip to content

Commit 8392bf6

Browse files
calebzulawski authored and tlively committed
Improve WebAssembly vector bitmask, mask reduction, and extending
This is inspired by a recently filed Rust issue noting poor codegen for vector masks (rust-lang/portable-simd#351). Reviewed By: tlively Differential Revision: https://reviews.llvm.org/D151782
1 parent 867ee3b commit 8392bf6

File tree

5 files changed

+259
-99
lines changed

5 files changed

+259
-99
lines changed

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp

+140
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,12 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
157157

158158
// SIMD-specific configuration
159159
if (Subtarget->hasSIMD128()) {
160+
// Combine vector mask reductions into alltrue/anytrue
161+
setTargetDAGCombine(ISD::SETCC);
162+
163+
// Convert vector to integer bitcasts to bitmask
164+
setTargetDAGCombine(ISD::BITCAST);
165+
160166
// Hoist bitcasts out of shuffles
161167
setTargetDAGCombine(ISD::VECTOR_SHUFFLE);
162168

@@ -258,6 +264,12 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
258264
// But saturating fp_to_int conversions are
259265
for (auto Op : {ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT})
260266
setOperationAction(Op, MVT::v4i32, Custom);
267+
268+
// Support vector extending
269+
for (auto T : MVT::integer_fixedlen_vector_valuetypes()) {
270+
setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, T, Custom);
271+
setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, T, Custom);
272+
}
261273
}
262274

263275
// As a special case, these operators use the type to mean the type to
@@ -1374,6 +1386,11 @@ void WebAssemblyTargetLowering::ReplaceNodeResults(
13741386
// SIGN_EXTEND_INREG, but for non-vector sign extends the result might be an
13751387
// illegal type.
13761388
break;
1389+
case ISD::SIGN_EXTEND_VECTOR_INREG:
1390+
case ISD::ZERO_EXTEND_VECTOR_INREG:
1391+
// Do not add any results, signifying that N should not be custom lowered.
1392+
// EXTEND_VECTOR_INREG is implemented for some vectors, but not all.
1393+
break;
13771394
default:
13781395
llvm_unreachable(
13791396
"ReplaceNodeResults not implemented for this op for WebAssembly!");
@@ -1424,6 +1441,9 @@ SDValue WebAssemblyTargetLowering::LowerOperation(SDValue Op,
14241441
return LowerIntrinsic(Op, DAG);
14251442
case ISD::SIGN_EXTEND_INREG:
14261443
return LowerSIGN_EXTEND_INREG(Op, DAG);
1444+
case ISD::ZERO_EXTEND_VECTOR_INREG:
1445+
case ISD::SIGN_EXTEND_VECTOR_INREG:
1446+
return LowerEXTEND_VECTOR_INREG(Op, DAG);
14271447
case ISD::BUILD_VECTOR:
14281448
return LowerBUILD_VECTOR(Op, DAG);
14291449
case ISD::VECTOR_SHUFFLE:
@@ -1877,6 +1897,48 @@ WebAssemblyTargetLowering::LowerSIGN_EXTEND_INREG(SDValue Op,
18771897
Op.getOperand(1));
18781898
}
18791899

1900+
/// Custom-lower ISD::SIGN_EXTEND_VECTOR_INREG and
/// ISD::ZERO_EXTEND_VECTOR_INREG into a chain of
/// WebAssemblyISD::EXTEND_LOW_S / WebAssemblyISD::EXTEND_LOW_U nodes.
///
/// Every EXTEND_LOW step doubles the element width and halves the element
/// count, so an extension by a factor of 2, 4 or 8 is emitted as one, two or
/// three chained steps respectively.
///
/// \param Op  The *_EXTEND_VECTOR_INREG node to lower.
/// \param DAG The selection DAG used to create the replacement nodes.
/// \returns The extended value, or an empty SDValue (meaning "not custom
///          lowered") when the source elements are i1 or i64, or when the
///          extension factor is not 2, 4 or 8.
SDValue
WebAssemblyTargetLowering::LowerEXTEND_VECTOR_INREG(SDValue Op,
                                                    SelectionDAG &DAG) const {
  SDLoc DL(Op);
  EVT VT = Op.getValueType();
  SDValue Src = Op.getOperand(0);
  EVT SrcVT = Src.getValueType();

  // i1 and i64 source elements are not handled by this lowering.
  if (SrcVT.getVectorElementType() == MVT::i1 ||
      SrcVT.getVectorElementType() == MVT::i64)
    return SDValue();

  assert(VT.getScalarSizeInBits() % SrcVT.getScalarSizeInBits() == 0 &&
         "Unexpected extension factor.");
  unsigned Scale = VT.getScalarSizeInBits() / SrcVT.getScalarSizeInBits();

  if (Scale != 2 && Scale != 4 && Scale != 8)
    return SDValue();

  unsigned Ext;
  switch (Op.getOpcode()) {
  case ISD::ZERO_EXTEND_VECTOR_INREG:
    Ext = WebAssemblyISD::EXTEND_LOW_U;
    break;
  case ISD::SIGN_EXTEND_VECTOR_INREG:
    Ext = WebAssemblyISD::EXTEND_LOW_S;
    break;
  default:
    // Without this, Ext would be read uninitialized for any other opcode.
    llvm_unreachable("Unexpected opcode for EXTEND_VECTOR_INREG lowering");
  }

  // Repeatedly widen the low half until the requested element width is
  // reached; each iteration accounts for one factor of two.
  SDValue Ret = Src;
  while (Scale != 1) {
    Ret = DAG.getNode(Ext, DL,
                      Ret.getValueType()
                          .widenIntegerVectorElementType(*DAG.getContext())
                          .getHalfNumVectorElementsVT(*DAG.getContext()),
                      Ret);
    Scale /= 2;
  }
  assert(Ret.getValueType() == VT);
  return Ret;
}
1941+
18801942
static SDValue LowerConvertLow(SDValue Op, SelectionDAG &DAG) {
18811943
SDLoc DL(Op);
18821944
if (Op.getValueType() != MVT::v2f64)
@@ -2692,12 +2754,90 @@ static SDValue performTruncateCombine(SDNode *N,
26922754
return truncateVectorWithNARROW(OutVT, In, DL, DAG);
26932755
}
26942756

2757+
/// DAG combine for ISD::BITCAST.
///
/// Before legalization, rewrites a bitcast from a fixed-length vector of i1
/// to a scalar integer into the wasm_bitmask intrinsic:
///
///   bitcast <N x i1> to iN   ==>   bitmask
///
/// The i1 lanes are first sign-extended to (128 / N)-bit lanes, and the i32
/// intrinsic result is zero-extended or truncated to the bitcast's type.
///
/// \returns The replacement value, or an empty SDValue when the pattern does
///          not apply (including lane counts other than 2, 4, 8 or 16).
static SDValue performBitcastCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI) {
  auto &DAG = DCI.DAG;
  SDLoc DL(N);
  SDValue Src = N->getOperand(0);
  EVT VT = N->getValueType(0);
  EVT SrcVT = Src.getValueType();

  // bitcast <N x i1> to iN
  //   ==> bitmask
  if (DCI.isBeforeLegalize() && VT.isScalarInteger() &&
      SrcVT.isFixedLengthVector() && SrcVT.getScalarType() == MVT::i1) {
    unsigned NumElts = SrcVT.getVectorNumElements();
    // Only lane counts that evenly fill a 128-bit vector can be widened to a
    // type usable with bitmask. Leave anything else (e.g. <32 x i1>) to the
    // generic legalizer instead of asserting on otherwise valid input.
    if (NumElts != 2 && NumElts != 4 && NumElts != 8 && NumElts != 16)
      return SDValue();

    EVT Width = MVT::getIntegerVT(128 / NumElts);
    SDValue Widened =
        DAG.getSExtOrTrunc(Src, DL, SrcVT.changeVectorElementType(Width));
    SDValue Mask =
        DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::i32,
                    {DAG.getConstant(Intrinsic::wasm_bitmask, DL, MVT::i32),
                     Widened});
    return DAG.getZExtOrTrunc(Mask, DL, VT);
  }

  return SDValue();
}
2782+
2783+
/// DAG combine for ISD::SETCC.
///
/// Before legalization, rewrites equality comparisons of a bitcast i1 vector
/// against 0 or -1 into the wasm_anytrue / wasm_alltrue intrinsics:
///
///   setcc (iN (bitcast (vNi1 X))), 0, ne   ==> any_true (vNi1 X)
///   setcc (iN (bitcast (vNi1 X))), 0, eq   ==> xor (any_true (vNi1 X)), -1
///   setcc (iN (bitcast (vNi1 X))), -1, eq  ==> all_true (vNi1 X)
///   setcc (iN (bitcast (vNi1 X))), -1, ne  ==> xor (all_true (vNi1 X)), -1
///
/// \returns The replacement value, or an empty SDValue when the pattern does
///          not apply (including lane counts other than 2, 4, 8 or 16).
static SDValue performSETCCCombine(SDNode *N,
                                   TargetLowering::DAGCombinerInfo &DCI) {
  auto &DAG = DCI.DAG;

  SDValue LHS = N->getOperand(0);
  SDValue RHS = N->getOperand(1);
  ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
  SDLoc DL(N);
  EVT VT = N->getValueType(0);

  const bool RHSIsZero = isNullConstant(RHS);
  const bool RHSIsAllOnes = isAllOnesConstant(RHS);

  if (!DCI.isBeforeLegalize() || !VT.isScalarInteger() ||
      (Cond != ISD::SETEQ && Cond != ISD::SETNE) ||
      (!RHSIsZero && !RHSIsAllOnes) || LHS->getOpcode() != ISD::BITCAST)
    return SDValue();

  EVT FromVT = LHS->getOperand(0).getValueType();
  if (!FromVT.isFixedLengthVector() ||
      FromVT.getVectorElementType() != MVT::i1)
    return SDValue();

  unsigned NumElts = FromVT.getVectorNumElements();
  // Only lane counts that evenly fill a 128-bit vector can be widened for
  // any_true / all_true. Leave anything else to the generic legalizer instead
  // of asserting on otherwise valid input.
  if (NumElts != 2 && NumElts != 4 && NumElts != 8 && NumElts != 16)
    return SDValue();

  // Comparing against 0 asks "is any lane set"; against -1, "are all set".
  int Intrin = RHSIsZero ? Intrinsic::wasm_anytrue : Intrinsic::wasm_alltrue;
  EVT Width = MVT::getIntegerVT(128 / NumElts);
  SDValue Ret = DAG.getZExtOrTrunc(
      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::i32,
                  {DAG.getConstant(Intrin, DL, MVT::i32),
                   DAG.getSExtOrTrunc(LHS->getOperand(0), DL,
                                      FromVT.changeVectorElementType(Width))}),
      DL, MVT::i1);

  // "== 0" is the negation of any_true; "!= -1" is the negation of all_true.
  if ((RHSIsZero && Cond == ISD::SETEQ) ||
      (RHSIsAllOnes && Cond == ISD::SETNE))
    Ret = DAG.getNOT(DL, Ret, MVT::i1);

  return DAG.getZExtOrTrunc(Ret, DL, VT);
}
2830+
26952831
SDValue
26962832
WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
26972833
DAGCombinerInfo &DCI) const {
26982834
switch (N->getOpcode()) {
26992835
default:
27002836
return SDValue();
2837+
case ISD::BITCAST:
2838+
return performBitcastCombine(N, DCI);
2839+
case ISD::SETCC:
2840+
return performSETCCCombine(N, DCI);
27012841
case ISD::VECTOR_SHUFFLE:
27022842
return performVECTOR_SHUFFLECombine(N, DCI);
27032843
case ISD::SIGN_EXTEND:

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h

+1
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ class WebAssemblyTargetLowering final : public TargetLowering {
131131
SDValue LowerCopyToReg(SDValue Op, SelectionDAG &DAG) const;
132132
SDValue LowerIntrinsic(SDValue Op, SelectionDAG &DAG) const;
133133
SDValue LowerSIGN_EXTEND_INREG(SDValue Op, SelectionDAG &DAG) const;
134+
SDValue LowerEXTEND_VECTOR_INREG(SDValue Op, SelectionDAG &DAG) const;
134135
SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
135136
SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;
136137
SDValue LowerSETCC(SDValue Op, SelectionDAG &DAG) const;

llvm/test/CodeGen/WebAssembly/simd-extending-convert.ll

+12-17
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ define <4 x float> @extend_to_float_low_i8x16_u(<8 x i8> %x) {
3636
; CHECK-LABEL: extend_to_float_low_i8x16_u:
3737
; CHECK: .functype extend_to_float_low_i8x16_u (v128) -> (v128)
3838
; CHECK-NEXT: # %bb.0:
39-
; CHECK-NEXT: v128.const 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
4039
; CHECK-NEXT: local.get 0
41-
; CHECK-NEXT: i8x16.shuffle 16, 1, 2, 3, 17, 5, 6, 7, 18, 9, 10, 11, 19, 13, 14, 15
40+
; CHECK-NEXT: i16x8.extend_low_i8x16_u
41+
; CHECK-NEXT: i32x4.extend_low_i16x8_u
4242
; CHECK-NEXT: f32x4.convert_i32x4_u
4343
; CHECK-NEXT: # fallthrough-return
4444
%low = shufflevector <8 x i8> %x, <8 x i8> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
@@ -51,8 +51,10 @@ define <4 x float> @extend_to_float_high_i8x16_u(<8 x i8> %x) {
5151
; CHECK: .functype extend_to_float_high_i8x16_u (v128) -> (v128)
5252
; CHECK-NEXT: # %bb.0:
5353
; CHECK-NEXT: local.get 0
54-
; CHECK-NEXT: v128.const 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
55-
; CHECK-NEXT: i8x16.shuffle 4, 17, 18, 19, 5, 21, 22, 23, 6, 25, 26, 27, 7, 29, 30, 31
54+
; CHECK-NEXT: local.get 0
55+
; CHECK-NEXT: i8x16.shuffle 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
56+
; CHECK-NEXT: i16x8.extend_low_i8x16_u
57+
; CHECK-NEXT: i32x4.extend_low_i16x8_u
5658
; CHECK-NEXT: f32x4.convert_i32x4_u
5759
; CHECK-NEXT: # fallthrough-return
5860
%high = shufflevector <8 x i8> %x, <8 x i8> undef, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
@@ -91,12 +93,8 @@ define <4 x float> @extend_to_float_low_i8x16_s(<8 x i8> %x) {
9193
; CHECK: .functype extend_to_float_low_i8x16_s (v128) -> (v128)
9294
; CHECK-NEXT: # %bb.0:
9395
; CHECK-NEXT: local.get 0
94-
; CHECK-NEXT: local.get 0
95-
; CHECK-NEXT: i8x16.shuffle 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0
96-
; CHECK-NEXT: i32.const 24
97-
; CHECK-NEXT: i32x4.shl
98-
; CHECK-NEXT: i32.const 24
99-
; CHECK-NEXT: i32x4.shr_s
96+
; CHECK-NEXT: i16x8.extend_low_i8x16_s
97+
; CHECK-NEXT: i32x4.extend_low_i16x8_s
10098
; CHECK-NEXT: f32x4.convert_i32x4_s
10199
; CHECK-NEXT: # fallthrough-return
102100
%low = shufflevector <8 x i8> %x, <8 x i8> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
@@ -110,11 +108,9 @@ define <4 x float> @extend_to_float_high_i8x16_s(<8 x i8> %x) {
110108
; CHECK-NEXT: # %bb.0:
111109
; CHECK-NEXT: local.get 0
112110
; CHECK-NEXT: local.get 0
113-
; CHECK-NEXT: i8x16.shuffle 4, 0, 0, 0, 5, 0, 0, 0, 6, 0, 0, 0, 7, 0, 0, 0
114-
; CHECK-NEXT: i32.const 24
115-
; CHECK-NEXT: i32x4.shl
116-
; CHECK-NEXT: i32.const 24
117-
; CHECK-NEXT: i32x4.shr_s
111+
; CHECK-NEXT: i8x16.shuffle 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
112+
; CHECK-NEXT: i16x8.extend_low_i8x16_s
113+
; CHECK-NEXT: i32x4.extend_low_i16x8_s
118114
; CHECK-NEXT: f32x4.convert_i32x4_s
119115
; CHECK-NEXT: # fallthrough-return
120116
%high = shufflevector <8 x i8> %x, <8 x i8> undef, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
@@ -138,9 +134,8 @@ define <2 x double> @extend_to_double_low_i16x4_u(<4 x i16> %x) {
138134
; CHECK-LABEL: extend_to_double_low_i16x4_u:
139135
; CHECK: .functype extend_to_double_low_i16x4_u (v128) -> (v128)
140136
; CHECK-NEXT: # %bb.0:
141-
; CHECK-NEXT: v128.const 0, 0, 0, 0, 0, 0, 0, 0
142137
; CHECK-NEXT: local.get 0
143-
; CHECK-NEXT: i8x16.shuffle 16, 17, 2, 3, 18, 19, 6, 7, 20, 21, 10, 11, 22, 23, 14, 15
138+
; CHECK-NEXT: i32x4.extend_low_i16x8_u
144139
; CHECK-NEXT: f64x2.convert_low_i32x4_u
145140
; CHECK-NEXT: # fallthrough-return
146141
%low = shufflevector <4 x i16> %x, <4 x i16> undef, <2 x i32> <i32 0, i32 1>

llvm/test/CodeGen/WebAssembly/simd-extending.ll

+74-10
Original file line numberDiff line numberDiff line change
@@ -170,11 +170,8 @@ define <8 x i16> @extend_lowish_i8x16_s(<16 x i8> %v) {
170170
; CHECK-NEXT: # %bb.0:
171171
; CHECK-NEXT: local.get 0
172172
; CHECK-NEXT: local.get 0
173-
; CHECK-NEXT: i8x16.shuffle 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, 0
174-
; CHECK-NEXT: i32.const 8
175-
; CHECK-NEXT: i16x8.shl
176-
; CHECK-NEXT: i32.const 8
177-
; CHECK-NEXT: i16x8.shr_s
173+
; CHECK-NEXT: i8x16.shuffle 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0
174+
; CHECK-NEXT: i16x8.extend_low_i8x16_s
178175
; CHECK-NEXT: # fallthrough-return
179176
%lowish = shufflevector <16 x i8> %v, <16 x i8> undef,
180177
<8 x i32> <i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>
@@ -188,14 +185,81 @@ define <4 x i32> @extend_lowish_i16x8_s(<8 x i16> %v) {
188185
; CHECK-NEXT: # %bb.0:
189186
; CHECK-NEXT: local.get 0
190187
; CHECK-NEXT: local.get 0
191-
; CHECK-NEXT: i8x16.shuffle 2, 3, 0, 1, 4, 5, 0, 1, 6, 7, 0, 1, 8, 9, 0, 1
192-
; CHECK-NEXT: i32.const 16
193-
; CHECK-NEXT: i32x4.shl
194-
; CHECK-NEXT: i32.const 16
195-
; CHECK-NEXT: i32x4.shr_s
188+
; CHECK-NEXT: i8x16.shuffle 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 0, 1, 0, 1, 0, 1
189+
; CHECK-NEXT: i32x4.extend_low_i16x8_s
196190
; CHECK-NEXT: # fallthrough-return
197191
%lowish = shufflevector <8 x i16> %v, <8 x i16> undef,
198192
<4 x i32> <i32 1, i32 2, i32 3, i32 4>
199193
%extended = sext <4 x i16> %lowish to <4 x i32>
200194
ret <4 x i32> %extended
201195
}
196+
197+
;; Also test vectors that aren't full 128 bits, or might require
198+
;; multiple extensions
199+
200+
define <16 x i8> @extend_i1x16_i8(<16 x i1> %v) {
201+
; CHECK-LABEL: extend_i1x16_i8:
202+
; CHECK: .functype extend_i1x16_i8 (v128) -> (v128)
203+
; CHECK-NEXT: # %bb.0:
204+
; CHECK-NEXT: local.get 0
205+
; CHECK-NEXT: i32.const 7
206+
; CHECK-NEXT: i8x16.shl
207+
; CHECK-NEXT: i32.const 7
208+
; CHECK-NEXT: i8x16.shr_s
209+
; CHECK-NEXT: # fallthrough-return
210+
%extended = sext <16 x i1> %v to <16 x i8>
211+
ret <16 x i8> %extended
212+
}
213+
214+
define <8 x i8> @extend_i1x8_i8(<8 x i1> %v) {
215+
; CHECK-LABEL: extend_i1x8_i8:
216+
; CHECK: .functype extend_i1x8_i8 (v128) -> (v128)
217+
; CHECK-NEXT: # %bb.0:
218+
; CHECK-NEXT: local.get 0
219+
; CHECK-NEXT: local.get 0
220+
; CHECK-NEXT: i8x16.shuffle 0, 2, 4, 6, 8, 10, 12, 14, 0, 0, 0, 0, 0, 0, 0, 0
221+
; CHECK-NEXT: i32.const 7
222+
; CHECK-NEXT: i8x16.shl
223+
; CHECK-NEXT: i32.const 7
224+
; CHECK-NEXT: i8x16.shr_s
225+
; CHECK-NEXT: # fallthrough-return
226+
%extended = sext <8 x i1> %v to <8 x i8>
227+
ret <8 x i8> %extended
228+
}
229+
230+
define <8 x i16> @extend_i1x8_i16(<8 x i1> %v) {
231+
; CHECK-LABEL: extend_i1x8_i16:
232+
; CHECK: .functype extend_i1x8_i16 (v128) -> (v128)
233+
; CHECK-NEXT: # %bb.0:
234+
; CHECK-NEXT: local.get 0
235+
; CHECK-NEXT: v128.const 1, 1, 1, 1, 1, 1, 1, 1
236+
; CHECK-NEXT: v128.and
237+
; CHECK-NEXT: # fallthrough-return
238+
%extended = zext <8 x i1> %v to <8 x i16>
239+
ret <8 x i16> %extended
240+
}
241+
242+
define <4 x i32> @extend_i8x4_i32(<4 x i8> %v) {
243+
; CHECK-LABEL: extend_i8x4_i32:
244+
; CHECK: .functype extend_i8x4_i32 (v128) -> (v128)
245+
; CHECK-NEXT: # %bb.0:
246+
; CHECK-NEXT: local.get 0
247+
; CHECK-NEXT: i16x8.extend_low_i8x16_u
248+
; CHECK-NEXT: i32x4.extend_low_i16x8_u
249+
; CHECK-NEXT: # fallthrough-return
250+
%extended = zext <4 x i8> %v to <4 x i32>
251+
ret <4 x i32> %extended
252+
}
253+
254+
define <2 x i64> @extend_i8x2_i64(<2 x i8> %v) {
255+
; CHECK-LABEL: extend_i8x2_i64:
256+
; CHECK: .functype extend_i8x2_i64 (v128) -> (v128)
257+
; CHECK-NEXT: # %bb.0:
258+
; CHECK-NEXT: local.get 0
259+
; CHECK-NEXT: i16x8.extend_low_i8x16_s
260+
; CHECK-NEXT: i32x4.extend_low_i16x8_s
261+
; CHECK-NEXT: i64x2.extend_low_i32x4_s
262+
; CHECK-NEXT: # fallthrough-return
263+
%extended = sext <2 x i8> %v to <2 x i64>
264+
ret <2 x i64> %extended
265+
}

0 commit comments

Comments
 (0)