@@ -6357,8 +6357,10 @@ static SDValue IsNOT(SDValue V, SelectionDAG &DAG, bool OneUse = false) {
6357
6357
return SDValue();
6358
6358
}
6359
6359
6360
- void llvm::createUnpackShuffleMask(MVT VT, SmallVectorImpl<int> &Mask,
6360
+ void llvm::createUnpackShuffleMask(EVT VT, SmallVectorImpl<int> &Mask,
6361
6361
bool Lo, bool Unary) {
6362
+ assert(VT.getScalarType().isSimple() && (VT.getSizeInBits() % 128) == 0 &&
6363
+ "Illegal vector type to unpack");
6362
6364
assert(Mask.empty() && "Expected an empty shuffle mask vector");
6363
6365
int NumElts = VT.getVectorNumElements();
6364
6366
int NumEltsInLane = 128 / VT.getScalarSizeInBits();
@@ -6387,15 +6389,15 @@ void llvm::createSplat2ShuffleMask(MVT VT, SmallVectorImpl<int> &Mask,
6387
6389
}
6388
6390
6389
6391
/// Returns a vector_shuffle node for an unpackl operation.
6390
- static SDValue getUnpackl(SelectionDAG &DAG, const SDLoc &dl, MVT VT,
6392
+ static SDValue getUnpackl(SelectionDAG &DAG, const SDLoc &dl, EVT VT,
6391
6393
SDValue V1, SDValue V2) {
6392
6394
SmallVector<int, 8> Mask;
6393
6395
createUnpackShuffleMask(VT, Mask, /* Lo = */ true, /* Unary = */ false);
6394
6396
return DAG.getVectorShuffle(VT, dl, V1, V2, Mask);
6395
6397
}
6396
6398
6397
6399
/// Returns a vector_shuffle node for an unpackh operation.
6398
- static SDValue getUnpackh(SelectionDAG &DAG, const SDLoc &dl, MVT VT,
6400
+ static SDValue getUnpackh(SelectionDAG &DAG, const SDLoc &dl, EVT VT,
6399
6401
SDValue V1, SDValue V2) {
6400
6402
SmallVector<int, 8> Mask;
6401
6403
createUnpackShuffleMask(VT, Mask, /* Lo = */ false, /* Unary = */ false);
@@ -40026,8 +40028,8 @@ static SDValue combineReductionToHorizontal(SDNode *ExtElt, SelectionDAG &DAG,
40026
40028
return SDValue();
40027
40029
40028
40030
ISD::NodeType Opc;
40029
- SDValue Rdx =
40030
- DAG.matchBinOpReduction(ExtElt, Opc, {ISD::ADD, ISD::FADD}, true);
40031
+ SDValue Rdx = DAG.matchBinOpReduction(ExtElt, Opc,
40032
+ {ISD::ADD, ISD::MUL , ISD::FADD}, true);
40031
40033
if (!Rdx)
40032
40034
return SDValue();
40033
40035
@@ -40042,7 +40044,42 @@ static SDValue combineReductionToHorizontal(SDNode *ExtElt, SelectionDAG &DAG,
40042
40044
40043
40045
SDLoc DL(ExtElt);
40044
40046
40045
- // vXi8 reduction - sub 128-bit vector.
40047
+ // vXi8 mul reduction - promote to vXi16 mul reduction.
40048
+ if (Opc == ISD::MUL) {
40049
+ unsigned NumElts = VecVT.getVectorNumElements();
40050
+ if (VT != MVT::i8 || NumElts < 4 || !isPowerOf2_32(NumElts))
40051
+ return SDValue();
40052
+ if (VecVT.getSizeInBits() >= 128) {
40053
+ EVT WideVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16, NumElts / 2);
40054
+ SDValue Lo = getUnpackl(DAG, DL, VecVT, Rdx, DAG.getUNDEF(VecVT));
40055
+ SDValue Hi = getUnpackh(DAG, DL, VecVT, Rdx, DAG.getUNDEF(VecVT));
40056
+ Lo = DAG.getBitcast(WideVT, Lo);
40057
+ Hi = DAG.getBitcast(WideVT, Hi);
40058
+ Rdx = DAG.getNode(Opc, DL, WideVT, Lo, Hi);
40059
+ while (Rdx.getValueSizeInBits() > 128) {
40060
+ std::tie(Lo, Hi) = splitVector(Rdx, DAG, DL);
40061
+ Rdx = DAG.getNode(Opc, DL, Lo.getValueType(), Lo, Hi);
40062
+ }
40063
+ } else {
40064
+ Rdx = widenSubVector(Rdx, false, Subtarget, DAG, DL, 128);
40065
+ Rdx = getUnpackl(DAG, DL, MVT::v16i8, Rdx, DAG.getUNDEF(MVT::v16i8));
40066
+ Rdx = DAG.getBitcast(MVT::v8i16, Rdx);
40067
+ }
40068
+ if (NumElts >= 8)
40069
+ Rdx = DAG.getNode(Opc, DL, MVT::v8i16, Rdx,
40070
+ DAG.getVectorShuffle(MVT::v8i16, DL, Rdx, Rdx,
40071
+ {4, 5, 6, 7, -1, -1, -1, -1}));
40072
+ Rdx = DAG.getNode(Opc, DL, MVT::v8i16, Rdx,
40073
+ DAG.getVectorShuffle(MVT::v8i16, DL, Rdx, Rdx,
40074
+ {2, 3, -1, -1, -1, -1, -1, -1}));
40075
+ Rdx = DAG.getNode(Opc, DL, MVT::v8i16, Rdx,
40076
+ DAG.getVectorShuffle(MVT::v8i16, DL, Rdx, Rdx,
40077
+ {1, -1, -1, -1, -1, -1, -1, -1}));
40078
+ Rdx = DAG.getBitcast(MVT::v16i8, Rdx);
40079
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Rdx, Index);
40080
+ }
40081
+
40082
+ // vXi8 add reduction - sub 128-bit vector.
40046
40083
if (VecVT == MVT::v4i8 || VecVT == MVT::v8i8) {
40047
40084
if (VecVT == MVT::v4i8) {
40048
40085
// Pad with zero.
@@ -40073,7 +40110,7 @@ static SDValue combineReductionToHorizontal(SDNode *ExtElt, SelectionDAG &DAG,
40073
40110
!isPowerOf2_32(VecVT.getVectorNumElements()))
40074
40111
return SDValue();
40075
40112
40076
- // vXi8 reduction - sum lo/hi halves then use PSADBW.
40113
+ // vXi8 add reduction - sum lo/hi halves then use PSADBW.
40077
40114
if (VT == MVT::i8) {
40078
40115
while (Rdx.getValueSizeInBits() > 128) {
40079
40116
SDValue Lo, Hi;
0 commit comments