Skip to content

Commit c959357

Browse files
authored
[RISCV] Directly use pack* in build_vector lowering (llvm#98084)
In 03d4332, we extended build_vector lowering to pack elements into the largest size which doesn't exceed either ELEN or XLEN. The Zbkb extension - ratified under scalar crypto, but otherwise not really connected to crypto per se - adds the packh, packw, and pack instructions. These instructions are designed for exactly this pairwise packing. I ended up choosing to directly lower to machine nodes. A combination of the slightly non-uniform semantics of these instructions (packw *sign* extends the result, whereas packh *zero* extends it) and our generic DAG canonicalization (which sinks shl through or nodes) makes pattern matching these tricky and not particularly robust. Another alternative was to have an ISD node for them, but that didn't seem to add much in practice.
1 parent a004e50 commit c959357

File tree

2 files changed

+267
-330
lines changed

2 files changed

+267
-330
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3905,6 +3905,21 @@ static SDValue lowerBuildVectorOfConstants(SDValue Op, SelectionDAG &DAG,
39053905
return SDValue();
39063906
}
39073907

3908+
static unsigned getPACKOpcode(unsigned DestBW,
3909+
const RISCVSubtarget &Subtarget) {
3910+
switch (DestBW) {
3911+
default:
3912+
llvm_unreachable("Unsupported pack size");
3913+
case 16:
3914+
return RISCV::PACKH;
3915+
case 32:
3916+
return Subtarget.is64Bit() ? RISCV::PACKW : RISCV::PACK;
3917+
case 64:
3918+
assert(Subtarget.is64Bit());
3919+
return RISCV::PACK;
3920+
}
3921+
}
3922+
39083923
/// Double the element size of the build vector to reduce the number
39093924
/// of vslide1down in the build vector chain. In the worst case, this
39103925
/// trades three scalar operations for 1 vector operation. Scalar
@@ -3933,30 +3948,34 @@ static SDValue lowerBuildVectorViaPacking(SDValue Op, SelectionDAG &DAG,
39333948
// Produce [B,A] packed into a type twice as wide. Note that all
39343949
// scalars are XLenVT, possibly masked (see below).
39353950
MVT XLenVT = Subtarget.getXLenVT();
3951+
SDValue Mask = DAG.getConstant(
3952+
APInt::getLowBitsSet(XLenVT.getSizeInBits(), ElemSizeInBits), DL, XLenVT);
39363953
auto pack = [&](SDValue A, SDValue B) {
39373954
// Bias the scheduling of the inserted operations to near the
39383955
// definition of the element - this tends to reduce register
39393956
// pressure overall.
39403957
SDLoc ElemDL(B);
3958+
if (Subtarget.hasStdExtZbkb())
3959+
// Note that we're relying on the high bits of the result being
3960+
// don't care. For PACKW, the result is *sign* extended.
3961+
return SDValue(
3962+
DAG.getMachineNode(getPACKOpcode(ElemSizeInBits * 2, Subtarget),
3963+
ElemDL, XLenVT, A, B),
3964+
0);
3965+
3966+
A = DAG.getNode(ISD::AND, SDLoc(A), XLenVT, A, Mask);
3967+
B = DAG.getNode(ISD::AND, SDLoc(B), XLenVT, B, Mask);
39413968
SDValue ShtAmt = DAG.getConstant(ElemSizeInBits, ElemDL, XLenVT);
3969+
SDNodeFlags Flags;
3970+
Flags.setDisjoint(true);
39423971
return DAG.getNode(ISD::OR, ElemDL, XLenVT, A,
3943-
DAG.getNode(ISD::SHL, ElemDL, XLenVT, B, ShtAmt));
3972+
DAG.getNode(ISD::SHL, ElemDL, XLenVT, B, ShtAmt), Flags);
39443973
};
39453974

3946-
SDValue Mask = DAG.getConstant(
3947-
APInt::getLowBitsSet(XLenVT.getSizeInBits(), ElemSizeInBits), DL, XLenVT);
39483975
SmallVector<SDValue> NewOperands;
39493976
NewOperands.reserve(NumElts / 2);
3950-
for (unsigned i = 0; i < VT.getVectorNumElements(); i += 2) {
3951-
SDValue A = Op.getOperand(i);
3952-
SDValue B = Op.getOperand(i + 1);
3953-
// Bias the scheduling of the inserted operations to near the
3954-
// definition of the element - this tends to reduce register
3955-
// pressure overall.
3956-
A = DAG.getNode(ISD::AND, SDLoc(A), XLenVT, A, Mask);
3957-
B = DAG.getNode(ISD::AND, SDLoc(B), XLenVT, B, Mask);
3958-
NewOperands.push_back(pack(A, B));
3959-
}
3977+
for (unsigned i = 0; i < VT.getVectorNumElements(); i += 2)
3978+
NewOperands.push_back(pack(Op.getOperand(i), Op.getOperand(i + 1)));
39603979
assert(NumElts == NewOperands.size() * 2);
39613980
MVT WideVT = MVT::getIntegerVT(ElemSizeInBits * 2);
39623981
MVT WideVecVT = MVT::getVectorVT(WideVT, NumElts / 2);

0 commit comments

Comments
 (0)