Skip to content

Commit 47321c3

Browse files
committed
[X86][SSE] combineReductionToHorizontal - add vXi8 ISD::MUL reduction handling (PR39709)
Default expansion leads to repeated extensions/truncations to/from vXi16 which shuffle combining and demanded elts can't completely unravel. Better just to promote (any_extend) the input and perform a vXi16 reduction. We'll be able to remove a lot of this if we ever get decent legalization support for reduction intrinsics in SelectionDAG.
1 parent 9c3fa3d commit 47321c3

File tree

3 files changed

+360
-698
lines changed

3 files changed

+360
-698
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6357,8 +6357,10 @@ static SDValue IsNOT(SDValue V, SelectionDAG &DAG, bool OneUse = false) {
63576357
return SDValue();
63586358
}
63596359

6360-
void llvm::createUnpackShuffleMask(MVT VT, SmallVectorImpl<int> &Mask,
6360+
void llvm::createUnpackShuffleMask(EVT VT, SmallVectorImpl<int> &Mask,
63616361
bool Lo, bool Unary) {
6362+
assert(VT.getScalarType().isSimple() && (VT.getSizeInBits() % 128) == 0 &&
6363+
"Illegal vector type to unpack");
63626364
assert(Mask.empty() && "Expected an empty shuffle mask vector");
63636365
int NumElts = VT.getVectorNumElements();
63646366
int NumEltsInLane = 128 / VT.getScalarSizeInBits();
@@ -6387,15 +6389,15 @@ void llvm::createSplat2ShuffleMask(MVT VT, SmallVectorImpl<int> &Mask,
63876389
}
63886390

63896391
/// Returns a vector_shuffle node for an unpackl operation.
6390-
static SDValue getUnpackl(SelectionDAG &DAG, const SDLoc &dl, MVT VT,
6392+
static SDValue getUnpackl(SelectionDAG &DAG, const SDLoc &dl, EVT VT,
63916393
SDValue V1, SDValue V2) {
63926394
SmallVector<int, 8> Mask;
63936395
createUnpackShuffleMask(VT, Mask, /* Lo = */ true, /* Unary = */ false);
63946396
return DAG.getVectorShuffle(VT, dl, V1, V2, Mask);
63956397
}
63966398

63976399
/// Returns a vector_shuffle node for an unpackh operation.
6398-
static SDValue getUnpackh(SelectionDAG &DAG, const SDLoc &dl, MVT VT,
6400+
static SDValue getUnpackh(SelectionDAG &DAG, const SDLoc &dl, EVT VT,
63996401
SDValue V1, SDValue V2) {
64006402
SmallVector<int, 8> Mask;
64016403
createUnpackShuffleMask(VT, Mask, /* Lo = */ false, /* Unary = */ false);
@@ -40026,8 +40028,8 @@ static SDValue combineReductionToHorizontal(SDNode *ExtElt, SelectionDAG &DAG,
4002640028
return SDValue();
4002740029

4002840030
ISD::NodeType Opc;
40029-
SDValue Rdx =
40030-
DAG.matchBinOpReduction(ExtElt, Opc, {ISD::ADD, ISD::FADD}, true);
40031+
SDValue Rdx = DAG.matchBinOpReduction(ExtElt, Opc,
40032+
{ISD::ADD, ISD::MUL, ISD::FADD}, true);
4003140033
if (!Rdx)
4003240034
return SDValue();
4003340035

@@ -40042,7 +40044,42 @@ static SDValue combineReductionToHorizontal(SDNode *ExtElt, SelectionDAG &DAG,
4004240044

4004340045
SDLoc DL(ExtElt);
4004440046

40045-
// vXi8 reduction - sub 128-bit vector.
40047+
// vXi8 mul reduction - promote to vXi16 mul reduction.
40048+
if (Opc == ISD::MUL) {
40049+
unsigned NumElts = VecVT.getVectorNumElements();
40050+
if (VT != MVT::i8 || NumElts < 4 || !isPowerOf2_32(NumElts))
40051+
return SDValue();
40052+
if (VecVT.getSizeInBits() >= 128) {
40053+
EVT WideVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16, NumElts / 2);
40054+
SDValue Lo = getUnpackl(DAG, DL, VecVT, Rdx, DAG.getUNDEF(VecVT));
40055+
SDValue Hi = getUnpackh(DAG, DL, VecVT, Rdx, DAG.getUNDEF(VecVT));
40056+
Lo = DAG.getBitcast(WideVT, Lo);
40057+
Hi = DAG.getBitcast(WideVT, Hi);
40058+
Rdx = DAG.getNode(Opc, DL, WideVT, Lo, Hi);
40059+
while (Rdx.getValueSizeInBits() > 128) {
40060+
std::tie(Lo, Hi) = splitVector(Rdx, DAG, DL);
40061+
Rdx = DAG.getNode(Opc, DL, Lo.getValueType(), Lo, Hi);
40062+
}
40063+
} else {
40064+
Rdx = widenSubVector(Rdx, false, Subtarget, DAG, DL, 128);
40065+
Rdx = getUnpackl(DAG, DL, MVT::v16i8, Rdx, DAG.getUNDEF(MVT::v16i8));
40066+
Rdx = DAG.getBitcast(MVT::v8i16, Rdx);
40067+
}
40068+
if (NumElts >= 8)
40069+
Rdx = DAG.getNode(Opc, DL, MVT::v8i16, Rdx,
40070+
DAG.getVectorShuffle(MVT::v8i16, DL, Rdx, Rdx,
40071+
{4, 5, 6, 7, -1, -1, -1, -1}));
40072+
Rdx = DAG.getNode(Opc, DL, MVT::v8i16, Rdx,
40073+
DAG.getVectorShuffle(MVT::v8i16, DL, Rdx, Rdx,
40074+
{2, 3, -1, -1, -1, -1, -1, -1}));
40075+
Rdx = DAG.getNode(Opc, DL, MVT::v8i16, Rdx,
40076+
DAG.getVectorShuffle(MVT::v8i16, DL, Rdx, Rdx,
40077+
{1, -1, -1, -1, -1, -1, -1, -1}));
40078+
Rdx = DAG.getBitcast(MVT::v16i8, Rdx);
40079+
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Rdx, Index);
40080+
}
40081+
40082+
// vXi8 add reduction - sub 128-bit vector.
4004640083
if (VecVT == MVT::v4i8 || VecVT == MVT::v8i8) {
4004740084
if (VecVT == MVT::v4i8) {
4004840085
// Pad with zero.
@@ -40073,7 +40110,7 @@ static SDValue combineReductionToHorizontal(SDNode *ExtElt, SelectionDAG &DAG,
4007340110
!isPowerOf2_32(VecVT.getVectorNumElements()))
4007440111
return SDValue();
4007540112

40076-
// vXi8 reduction - sum lo/hi halves then use PSADBW.
40113+
// vXi8 add reduction - sum lo/hi halves then use PSADBW.
4007740114
if (VT == MVT::i8) {
4007840115
while (Rdx.getValueSizeInBits() > 128) {
4007940116
SDValue Lo, Hi;

llvm/lib/Target/X86/X86ISelLowering.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1698,7 +1698,7 @@ namespace llvm {
16981698
};
16991699

17001700
/// Generate unpacklo/unpackhi shuffle mask.
1701-
void createUnpackShuffleMask(MVT VT, SmallVectorImpl<int> &Mask, bool Lo,
1701+
void createUnpackShuffleMask(EVT VT, SmallVectorImpl<int> &Mask, bool Lo,
17021702
bool Unary);
17031703

17041704
/// Similar to unpacklo/unpackhi, but without the 128-bit lane limitation

0 commit comments

Comments
 (0)