Skip to content

Commit a916e81

Browse files
committed
[X86] Various improvements to our vector splitting helpers for lowering. NFC
- Consistently name the functions as split*.
- Add a helper for doing the two extractSubVector calls and determining the size of the split.
- Use GetSplitDestVTs to get the result type for the split node.
- Move the binary and unary helpers to one place in the file near the extractSubVector functions. Left the VSETCC one near LowerVSETCC since that's its only caller.
- Remove the 256/512 wrappers that just had asserts. I don't think they provided a lot of value, and now that the routines are called split* it is more obvious at the call sites what they do.
- Make the unary routine support different source and dest types to support D76212.
- Add some weaker asserts into the helpers to make up for losing the very specific asserts from the 256/512 wrappers.

Differential Revision: https://reviews.llvm.org/D78176
1 parent 7ce1a93 commit a916e81

File tree

1 file changed

+100
-116
lines changed

1 file changed

+100
-116
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 100 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -5802,6 +5802,71 @@ static bool collectConcatOps(SDNode *N, SmallVectorImpl<SDValue> &Ops) {
58025802
return false;
58035803
}
58045804

5805+
static std::pair<SDValue, SDValue> splitVector(SDValue Op, SelectionDAG &DAG,
5806+
const SDLoc &dl) {
5807+
MVT VT = Op.getSimpleValueType();
5808+
unsigned NumElems = VT.getVectorNumElements();
5809+
unsigned SizeInBits = VT.getSizeInBits();
5810+
5811+
SDValue Lo = extractSubVector(Op, 0, DAG, dl, SizeInBits / 2);
5812+
SDValue Hi = extractSubVector(Op, NumElems / 2, DAG, dl, SizeInBits / 2);
5813+
5814+
return std::make_pair(Lo, Hi);
5815+
}
5816+
5817+
// Split a unary integer op into 2 half-sized ops.
5818+
static SDValue splitVectorIntUnary(SDValue Op, SelectionDAG &DAG) {
5819+
EVT VT = Op.getValueType();
5820+
5821+
// Make sure we only try to split 256/512-bit types to avoid creating
5822+
// narrow vectors.
5823+
assert((Op.getOperand(0).getValueType().is256BitVector() ||
5824+
Op.getOperand(0).getValueType().is512BitVector()) &&
5825+
(VT.is256BitVector() || VT.is512BitVector()) && "Unsupported VT!");
5826+
assert(Op.getOperand(0).getValueType().getVectorNumElements() ==
5827+
VT.getVectorNumElements() &&
5828+
"Unexpected VTs!");
5829+
5830+
SDLoc dl(Op);
5831+
5832+
// Extract the Lo/Hi vectors
5833+
SDValue Lo, Hi;
5834+
std::tie(Lo, Hi) = splitVector(Op.getOperand(0), DAG, dl);
5835+
5836+
EVT LoVT, HiVT;
5837+
std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
5838+
return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT,
5839+
DAG.getNode(Op.getOpcode(), dl, LoVT, Lo),
5840+
DAG.getNode(Op.getOpcode(), dl, HiVT, Hi));
5841+
}
5842+
5843+
/// Break a binary integer operation into 2 half sized ops and then
5844+
/// concatenate the result back.
5845+
static SDValue splitVectorIntBinary(SDValue Op, SelectionDAG &DAG) {
5846+
EVT VT = Op.getValueType();
5847+
5848+
// Sanity check that all the types match.
5849+
assert(Op.getOperand(0).getValueType() == VT &&
5850+
Op.getOperand(1).getValueType() == VT && "Unexpected VTs!");
5851+
assert((VT.is256BitVector() || VT.is512BitVector()) && "Unsupported VT!");
5852+
5853+
SDLoc dl(Op);
5854+
5855+
// Extract the LHS Lo/Hi vectors
5856+
SDValue LHS1, LHS2;
5857+
std::tie(LHS1, LHS2) = splitVector(Op.getOperand(0), DAG, dl);
5858+
5859+
// Extract the RHS Lo/Hi vectors
5860+
SDValue RHS1, RHS2;
5861+
std::tie(RHS1, RHS2) = splitVector(Op.getOperand(1), DAG, dl);
5862+
5863+
EVT LoVT, HiVT;
5864+
std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
5865+
return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT,
5866+
DAG.getNode(Op.getOpcode(), dl, LoVT, LHS1, RHS1),
5867+
DAG.getNode(Op.getOpcode(), dl, HiVT, LHS2, RHS2));
5868+
}
5869+
58055870
// Helper for splitting operands of an operation to legal target size and
58065871
// apply a function on each part.
58075872
// Useful for operations that are available on SSE2 in 128-bit, on AVX2 in
@@ -21820,32 +21885,30 @@ static unsigned translateX86FSETCC(ISD::CondCode SetCCOpcode, SDValue &Op0,
2182021885

2182121886
/// Break a 256-bit integer VSETCC into two new 128-bit ones and then
2182221887
/// concatenate the result back.
21823-
static SDValue Lower256IntVSETCC(SDValue Op, SelectionDAG &DAG) {
21824-
MVT VT = Op.getSimpleValueType();
21888+
static SDValue splitIntVSETCC(SDValue Op, SelectionDAG &DAG) {
21889+
EVT VT = Op.getValueType();
2182521890

21826-
assert(VT.is256BitVector() && Op.getOpcode() == ISD::SETCC &&
21827-
"Unsupported value type for operation");
21891+
assert(Op.getOpcode() == ISD::SETCC && "Unsupported operation");
21892+
assert(Op.getOperand(0).getValueType().isInteger() &&
21893+
VT == Op.getOperand(0).getValueType() && "Unsupported VTs!");
2182821894

21829-
unsigned NumElems = VT.getVectorNumElements();
2183021895
SDLoc dl(Op);
2183121896
SDValue CC = Op.getOperand(2);
2183221897

21833-
// Extract the LHS vectors
21834-
SDValue LHS = Op.getOperand(0);
21835-
SDValue LHS1 = extract128BitVector(LHS, 0, DAG, dl);
21836-
SDValue LHS2 = extract128BitVector(LHS, NumElems / 2, DAG, dl);
21898+
// Extract the LHS Lo/Hi vectors
21899+
SDValue LHS1, LHS2;
21900+
std::tie(LHS1, LHS2) = splitVector(Op.getOperand(0), DAG, dl);
2183721901

21838-
// Extract the RHS vectors
21839-
SDValue RHS = Op.getOperand(1);
21840-
SDValue RHS1 = extract128BitVector(RHS, 0, DAG, dl);
21841-
SDValue RHS2 = extract128BitVector(RHS, NumElems / 2, DAG, dl);
21902+
// Extract the RHS Lo/Hi vectors
21903+
SDValue RHS1, RHS2;
21904+
std::tie(RHS1, RHS2) = splitVector(Op.getOperand(1), DAG, dl);
2184221905

2184321906
// Issue the operation on the smaller types and concatenate the result back
21844-
MVT EltVT = VT.getVectorElementType();
21845-
MVT NewVT = MVT::getVectorVT(EltVT, NumElems/2);
21907+
EVT LoVT, HiVT;
21908+
std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
2184621909
return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT,
21847-
DAG.getNode(Op.getOpcode(), dl, NewVT, LHS1, RHS1, CC),
21848-
DAG.getNode(Op.getOpcode(), dl, NewVT, LHS2, RHS2, CC));
21910+
DAG.getNode(ISD::SETCC, dl, LoVT, LHS1, RHS1, CC),
21911+
DAG.getNode(ISD::SETCC, dl, HiVT, LHS2, RHS2, CC));
2184921912
}
2185021913

2185121914
static SDValue LowerIntVSETCC_AVX512(SDValue Op, SelectionDAG &DAG) {
@@ -22187,7 +22250,7 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget,
2218722250

2218822251
// Break 256-bit integer vector compare into smaller ones.
2218922252
if (VT.is256BitVector() && !Subtarget.hasInt256())
22190-
return Lower256IntVSETCC(Op, DAG);
22253+
return splitIntVSETCC(Op, DAG);
2219122254

2219222255
// If this is a SETNE against the signed minimum value, change it to SETGT.
2219322256
// If this is a SETNE against the signed maximum value, change it to SETLT.
@@ -25922,43 +25985,6 @@ SDValue X86TargetLowering::LowerFLT_ROUNDS_(SDValue Op,
2592225985
return DAG.getMergeValues({RetVal, Chain}, DL);
2592325986
}
2592425987

25925-
// Split an unary integer op into 2 half sized ops.
25926-
static SDValue LowerVectorIntUnary(SDValue Op, SelectionDAG &DAG) {
25927-
MVT VT = Op.getSimpleValueType();
25928-
unsigned NumElems = VT.getVectorNumElements();
25929-
unsigned SizeInBits = VT.getSizeInBits();
25930-
MVT EltVT = VT.getVectorElementType();
25931-
SDValue Src = Op.getOperand(0);
25932-
assert(EltVT == Src.getSimpleValueType().getVectorElementType() &&
25933-
"Src and Op should have the same element type!");
25934-
25935-
// Extract the Lo/Hi vectors
25936-
SDLoc dl(Op);
25937-
SDValue Lo = extractSubVector(Src, 0, DAG, dl, SizeInBits / 2);
25938-
SDValue Hi = extractSubVector(Src, NumElems / 2, DAG, dl, SizeInBits / 2);
25939-
25940-
MVT NewVT = MVT::getVectorVT(EltVT, NumElems / 2);
25941-
return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT,
25942-
DAG.getNode(Op.getOpcode(), dl, NewVT, Lo),
25943-
DAG.getNode(Op.getOpcode(), dl, NewVT, Hi));
25944-
}
25945-
25946-
// Decompose 256-bit ops into smaller 128-bit ops.
25947-
static SDValue Lower256IntUnary(SDValue Op, SelectionDAG &DAG) {
25948-
assert(Op.getSimpleValueType().is256BitVector() &&
25949-
Op.getSimpleValueType().isInteger() &&
25950-
"Only handle AVX 256-bit vector integer operation");
25951-
return LowerVectorIntUnary(Op, DAG);
25952-
}
25953-
25954-
// Decompose 512-bit ops into smaller 256-bit ops.
25955-
static SDValue Lower512IntUnary(SDValue Op, SelectionDAG &DAG) {
25956-
assert(Op.getSimpleValueType().is512BitVector() &&
25957-
Op.getSimpleValueType().isInteger() &&
25958-
"Only handle AVX 512-bit vector integer operation");
25959-
return LowerVectorIntUnary(Op, DAG);
25960-
}
25961-
2596225988
/// Lower a vector CTLZ using native supported vector CTLZ instruction.
2596325989
//
2596425990
// i8/i16 vector implemented using dword LZCNT vector instruction
@@ -25979,7 +26005,7 @@ static SDValue LowerVectorCTLZ_AVX512CDI(SDValue Op, SelectionDAG &DAG,
2597926005
// Split vector, its Lo and Hi parts will be handled in next iteration.
2598026006
if (NumElems > 16 ||
2598126007
(NumElems == 16 && !Subtarget.canExtendTo512DQ()))
25982-
return LowerVectorIntUnary(Op, DAG);
26008+
return splitVectorIntUnary(Op, DAG);
2598326009

2598426010
MVT NewVT = MVT::getVectorVT(MVT::i32, NumElems);
2598526011
assert((NewVT.is256BitVector() || NewVT.is512BitVector()) &&
@@ -26089,11 +26115,11 @@ static SDValue LowerVectorCTLZ(SDValue Op, const SDLoc &DL,
2608926115

2609026116
// Decompose 256-bit ops into smaller 128-bit ops.
2609126117
if (VT.is256BitVector() && !Subtarget.hasInt256())
26092-
return Lower256IntUnary(Op, DAG);
26118+
return splitVectorIntUnary(Op, DAG);
2609326119

2609426120
// Decompose 512-bit ops into smaller 256-bit ops.
2609526121
if (VT.is512BitVector() && !Subtarget.hasBWI())
26096-
return Lower512IntUnary(Op, DAG);
26122+
return splitVectorIntUnary(Op, DAG);
2609726123

2609826124
assert(Subtarget.hasSSSE3() && "Expected SSSE3 support for PSHUFB");
2609926125
return LowerVectorCTLZInRegLUT(Op, DL, Subtarget, DAG);
@@ -26159,48 +26185,6 @@ static SDValue LowerCTTZ(SDValue Op, const X86Subtarget &Subtarget,
2615926185
return DAG.getNode(X86ISD::CMOV, dl, VT, Ops);
2616026186
}
2616126187

26162-
/// Break a binary integer operation into 2 half sized ops and then
26163-
/// concatenate the result back.
26164-
static SDValue splitVectorIntBinary(SDValue Op, SelectionDAG &DAG) {
26165-
MVT VT = Op.getSimpleValueType();
26166-
unsigned NumElems = VT.getVectorNumElements();
26167-
unsigned SizeInBits = VT.getSizeInBits();
26168-
SDLoc dl(Op);
26169-
26170-
// Extract the LHS Lo/Hi vectors
26171-
SDValue LHS = Op.getOperand(0);
26172-
SDValue LHS1 = extractSubVector(LHS, 0, DAG, dl, SizeInBits / 2);
26173-
SDValue LHS2 = extractSubVector(LHS, NumElems / 2, DAG, dl, SizeInBits / 2);
26174-
26175-
// Extract the RHS Lo/Hi vectors
26176-
SDValue RHS = Op.getOperand(1);
26177-
SDValue RHS1 = extractSubVector(RHS, 0, DAG, dl, SizeInBits / 2);
26178-
SDValue RHS2 = extractSubVector(RHS, NumElems / 2, DAG, dl, SizeInBits / 2);
26179-
26180-
MVT NewVT = MVT::getVectorVT(VT.getVectorElementType(), NumElems / 2);
26181-
return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT,
26182-
DAG.getNode(Op.getOpcode(), dl, NewVT, LHS1, RHS1),
26183-
DAG.getNode(Op.getOpcode(), dl, NewVT, LHS2, RHS2));
26184-
}
26185-
26186-
/// Break a 256-bit integer operation into two new 128-bit ones and then
26187-
/// concatenate the result back.
26188-
static SDValue split256IntArith(SDValue Op, SelectionDAG &DAG) {
26189-
assert(Op.getSimpleValueType().is256BitVector() &&
26190-
Op.getSimpleValueType().isInteger() &&
26191-
"Unsupported value type for operation");
26192-
return splitVectorIntBinary(Op, DAG);
26193-
}
26194-
26195-
/// Break a 512-bit integer operation into two new 256-bit ones and then
26196-
/// concatenate the result back.
26197-
static SDValue split512IntArith(SDValue Op, SelectionDAG &DAG) {
26198-
assert(Op.getSimpleValueType().is512BitVector() &&
26199-
Op.getSimpleValueType().isInteger() &&
26200-
"Unsupported value type for operation");
26201-
return splitVectorIntBinary(Op, DAG);
26202-
}
26203-
2620426188
static SDValue lowerAddSub(SDValue Op, SelectionDAG &DAG,
2620526189
const X86Subtarget &Subtarget) {
2620626190
MVT VT = Op.getSimpleValueType();
@@ -26214,7 +26198,7 @@ static SDValue lowerAddSub(SDValue Op, SelectionDAG &DAG,
2621426198
assert(Op.getSimpleValueType().is256BitVector() &&
2621526199
Op.getSimpleValueType().isInteger() &&
2621626200
"Only handle AVX 256-bit vector integer operation");
26217-
return split256IntArith(Op, DAG);
26201+
return splitVectorIntBinary(Op, DAG);
2621826202
}
2621926203

2622026204
static SDValue LowerADDSAT_SUBSAT(SDValue Op, SelectionDAG &DAG,
@@ -26262,7 +26246,7 @@ static SDValue LowerADDSAT_SUBSAT(SDValue Op, SelectionDAG &DAG,
2626226246
assert(Op.getSimpleValueType().is256BitVector() &&
2626326247
Op.getSimpleValueType().isInteger() &&
2626426248
"Only handle AVX 256-bit vector integer operation");
26265-
return split256IntArith(Op, DAG);
26249+
return splitVectorIntBinary(Op, DAG);
2626626250
}
2626726251

2626826252
static SDValue LowerABS(SDValue Op, const X86Subtarget &Subtarget,
@@ -26292,7 +26276,7 @@ static SDValue LowerABS(SDValue Op, const X86Subtarget &Subtarget,
2629226276
if (VT.is256BitVector() && !Subtarget.hasInt256()) {
2629326277
assert(VT.isInteger() &&
2629426278
"Only handle AVX 256-bit vector integer operation");
26295-
return Lower256IntUnary(Op, DAG);
26279+
return splitVectorIntUnary(Op, DAG);
2629626280
}
2629726281

2629826282
// Default to expand.
@@ -26304,7 +26288,7 @@ static SDValue LowerMINMAX(SDValue Op, SelectionDAG &DAG) {
2630426288

2630526289
// For AVX1 cases, split to use legal ops (everything but v4i64).
2630626290
if (VT.getScalarType() != MVT::i64 && VT.is256BitVector())
26307-
return split256IntArith(Op, DAG);
26291+
return splitVectorIntBinary(Op, DAG);
2630826292

2630926293
SDLoc DL(Op);
2631026294
unsigned Opcode = Op.getOpcode();
@@ -26348,7 +26332,7 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget,
2634826332

2634926333
// Decompose 256-bit ops into 128-bit ops.
2635026334
if (VT.is256BitVector() && !Subtarget.hasInt256())
26351-
return split256IntArith(Op, DAG);
26335+
return splitVectorIntBinary(Op, DAG);
2635226336

2635326337
SDValue A = Op.getOperand(0);
2635426338
SDValue B = Op.getOperand(1);
@@ -26494,7 +26478,7 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget,
2649426478

2649526479
// Decompose 256-bit ops into 128-bit ops.
2649626480
if (VT.is256BitVector() && !Subtarget.hasInt256())
26497-
return split256IntArith(Op, DAG);
26481+
return splitVectorIntBinary(Op, DAG);
2649826482

2649926483
if (VT == MVT::v4i32 || VT == MVT::v8i32 || VT == MVT::v16i32) {
2650026484
assert((VT == MVT::v4i32 && Subtarget.hasSSE2()) ||
@@ -26586,7 +26570,7 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget,
2658626570
// For signed 512-bit vectors, split into 256-bit vectors to allow the
2658726571
// sign-extension to occur.
2658826572
if (VT == MVT::v64i8 && IsSigned)
26589-
return split512IntArith(Op, DAG);
26573+
return splitVectorIntBinary(Op, DAG);
2659026574

2659126575
// Signed AVX2 implementation - extend xmm subvectors to ymm.
2659226576
if (VT == MVT::v32i8 && IsSigned) {
@@ -27560,7 +27544,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
2756027544

2756127545
// Decompose 256-bit shifts into 128-bit shifts.
2756227546
if (VT.is256BitVector())
27563-
return split256IntArith(Op, DAG);
27547+
return splitVectorIntBinary(Op, DAG);
2756427548

2756527549
return SDValue();
2756627550
}
@@ -27606,7 +27590,7 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget,
2760627590
// XOP implicitly uses modulo rotation amounts.
2760727591
if (Subtarget.hasXOP()) {
2760827592
if (VT.is256BitVector())
27609-
return split256IntArith(Op, DAG);
27593+
return splitVectorIntBinary(Op, DAG);
2761027594
assert(VT.is128BitVector() && "Only rotate 128-bit vectors!");
2761127595

2761227596
// Attempt to rotate by immediate.
@@ -27622,7 +27606,7 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget,
2762227606

2762327607
// Split 256-bit integers on pre-AVX2 targets.
2762427608
if (VT.is256BitVector() && !Subtarget.hasAVX2())
27625-
return split256IntArith(Op, DAG);
27609+
return splitVectorIntBinary(Op, DAG);
2762627610

2762727611
assert((VT == MVT::v4i32 || VT == MVT::v8i16 || VT == MVT::v16i8 ||
2762827612
((VT == MVT::v8i32 || VT == MVT::v16i16 || VT == MVT::v32i8) &&
@@ -28287,11 +28271,11 @@ static SDValue LowerVectorCTPOP(SDValue Op, const X86Subtarget &Subtarget,
2828728271

2828828272
// Decompose 256-bit ops into smaller 128-bit ops.
2828928273
if (VT.is256BitVector() && !Subtarget.hasInt256())
28290-
return Lower256IntUnary(Op, DAG);
28274+
return splitVectorIntUnary(Op, DAG);
2829128275

2829228276
// Decompose 512-bit ops into smaller 256-bit ops.
2829328277
if (VT.is512BitVector() && !Subtarget.hasBWI())
28294-
return Lower512IntUnary(Op, DAG);
28278+
return splitVectorIntUnary(Op, DAG);
2829528279

2829628280
// For element types greater than i8, do vXi8 pop counts and a bytesum.
2829728281
if (VT.getScalarType() != MVT::i8) {
@@ -28335,7 +28319,7 @@ static SDValue LowerBITREVERSE_XOP(SDValue Op, SelectionDAG &DAG) {
2833528319

2833628320
// Decompose 256-bit ops into smaller 128-bit ops.
2833728321
if (VT.is256BitVector())
28338-
return Lower256IntUnary(Op, DAG);
28322+
return splitVectorIntUnary(Op, DAG);
2833928323

2834028324
assert(VT.is128BitVector() &&
2834128325
"Only 128-bit vector bitreverse lowering supported.");
@@ -28376,7 +28360,7 @@ static SDValue LowerBITREVERSE(SDValue Op, const X86Subtarget &Subtarget,
2837628360
// lowering.
2837728361
if (VT == MVT::v8i64 || VT == MVT::v16i32) {
2837828362
assert(!Subtarget.hasBWI() && "BWI should Expand BITREVERSE");
28379-
return Lower512IntUnary(Op, DAG);
28363+
return splitVectorIntUnary(Op, DAG);
2838028364
}
2838128365

2838228366
unsigned NumElts = VT.getVectorNumElements();
@@ -28385,7 +28369,7 @@ static SDValue LowerBITREVERSE(SDValue Op, const X86Subtarget &Subtarget,
2838528369

2838628370
// Decompose 256-bit ops into smaller 128-bit ops on pre-AVX2.
2838728371
if (VT.is256BitVector() && !Subtarget.hasInt256())
28388-
return Lower256IntUnary(Op, DAG);
28372+
return splitVectorIntUnary(Op, DAG);
2838928373

2839028374
// Perform BITREVERSE using PSHUFB lookups. Each byte is split into
2839128375
// two nibbles and a PSHUFB lookup to find the bitreverse of each
@@ -47137,7 +47121,7 @@ static SDValue combineExtractSubvector(SDNode *N, SelectionDAG &DAG,
4713747121
if (isConcatenatedNot(InVecBC.getOperand(0)) ||
4713847122
isConcatenatedNot(InVecBC.getOperand(1))) {
4713947123
// extract (and v4i64 X, (not (concat Y1, Y2))), n -> andnp v2i64 X(n), Y1
47140-
SDValue Concat = split256IntArith(InVecBC, DAG);
47124+
SDValue Concat = splitVectorIntBinary(InVecBC, DAG);
4714147125
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), VT,
4714247126
DAG.getBitcast(InVecVT, Concat), N->getOperand(1));
4714347127
}

0 commit comments

Comments
 (0)