Skip to content

Commit 216f546

Browse files
[SVE] Refactor lowering for fixed length MGATHER/MSCATTER.
Lower fixed length MGATHER/MSCATTER operations to scalable vector equivalents, which are then lowered to SVE specific nodes. This two stage process is in preparation for making scalable vector MGATHER/MSCATTER operations legal. Differential Revision: https://reviews.llvm.org/D125192
1 parent 86fd1c1 commit 216f546

File tree

1 file changed

+98
-69
lines changed

1 file changed

+98
-69
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 98 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -4696,33 +4696,65 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
46964696
MGT->getMemOperand(), IndexType, ExtType);
46974697
}
46984698

4699+
// Lower fixed length gather to a scalable equivalent.
4700+
if (VT.isFixedLengthVector()) {
4701+
assert(Subtarget->useSVEForFixedLengthVectors() &&
4702+
"Cannot lower when not using SVE for fixed vectors!");
4703+
4704+
// NOTE: Handle floating-point as if integer then bitcast the result.
4705+
EVT DataVT = VT.changeVectorElementTypeToInteger();
4706+
MemVT = MemVT.changeVectorElementTypeToInteger();
4707+
4708+
// Find the smallest integer fixed length vector we can use for the gather.
4709+
EVT PromotedVT = VT.changeVectorElementType(MVT::i32);
4710+
if (DataVT.getVectorElementType() == MVT::i64 ||
4711+
Index.getValueType().getVectorElementType() == MVT::i64 ||
4712+
Mask.getValueType().getVectorElementType() == MVT::i64)
4713+
PromotedVT = VT.changeVectorElementType(MVT::i64);
4714+
4715+
// Promote vector operands except for passthrough, which we know is either
4716+
// undef or zero, and thus best constructed directly.
4717+
unsigned ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
4718+
Index = DAG.getNode(ExtOpcode, DL, PromotedVT, Index);
4719+
Mask = DAG.getNode(ISD::SIGN_EXTEND, DL, PromotedVT, Mask);
4720+
4721+
// A promoted result type forces the need for an extending load.
4722+
if (PromotedVT != DataVT && ExtType == ISD::NON_EXTLOAD)
4723+
ExtType = ISD::EXTLOAD;
4724+
4725+
EVT ContainerVT = getContainerForFixedLengthVector(DAG, PromotedVT);
4726+
4727+
// Convert fixed length vector operands to scalable.
4728+
MemVT = ContainerVT.changeVectorElementType(MemVT.getVectorElementType());
4729+
Index = convertToScalableVector(DAG, ContainerVT, Index);
4730+
Mask = convertFixedMaskToScalableVector(Mask, DAG);
4731+
PassThru = PassThru->isUndef() ? DAG.getUNDEF(ContainerVT)
4732+
: DAG.getConstant(0, DL, ContainerVT);
4733+
4734+
// Emit equivalent scalable vector gather.
4735+
SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
4736+
SDValue Load =
4737+
DAG.getMaskedGather(DAG.getVTList(ContainerVT, MVT::Other), MemVT, DL,
4738+
Ops, MGT->getMemOperand(), IndexType, ExtType);
4739+
4740+
// Extract fixed length data then convert to the required result type.
4741+
SDValue Result = convertFromScalableVector(DAG, PromotedVT, Load);
4742+
Result = DAG.getNode(ISD::TRUNCATE, DL, DataVT, Result);
4743+
if (VT.isFloatingPoint())
4744+
Result = DAG.getNode(ISD::BITCAST, DL, VT, Result);
4745+
4746+
return DAG.getMergeValues({Result, Load.getValue(1)}, DL);
4747+
}
4748+
46994749
bool IdxNeedsExtend =
47004750
getGatherScatterIndexIsExtended(Index) ||
47014751
Index.getSimpleValueType().getVectorElementType() == MVT::i32;
47024752

47034753
EVT IndexVT = Index.getSimpleValueType();
47044754
SDValue InputVT = DAG.getValueType(MemVT);
47054755

4706-
bool IsFixedLength = MGT->getMemoryVT().isFixedLengthVector();
4707-
4708-
if (IsFixedLength) {
4709-
assert(Subtarget->useSVEForFixedLengthVectors() &&
4710-
"Cannot lower when not using SVE for fixed vectors");
4711-
if (MemVT.getScalarSizeInBits() <= IndexVT.getScalarSizeInBits()) {
4712-
IndexVT = getContainerForFixedLengthVector(DAG, IndexVT);
4713-
MemVT = IndexVT.changeVectorElementType(MemVT.getVectorElementType());
4714-
} else {
4715-
MemVT = getContainerForFixedLengthVector(DAG, MemVT);
4716-
IndexVT = MemVT.changeTypeToInteger();
4717-
}
4718-
InputVT = DAG.getValueType(MemVT.changeTypeToInteger());
4719-
Mask = DAG.getNode(
4720-
ISD::SIGN_EXTEND, DL,
4721-
VT.changeVectorElementType(IndexVT.getVectorElementType()), Mask);
4722-
}
4723-
47244756
// Handle FP data by using an integer gather and casting the result.
4725-
if (VT.isFloatingPoint() && !IsFixedLength)
4757+
if (VT.isFloatingPoint())
47264758
InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
47274759

47284760
SDVTList VTs = DAG.getVTList(IndexVT, MVT::Other);
@@ -4737,25 +4769,11 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
47374769
if (ExtType == ISD::SEXTLOAD)
47384770
Opcode = getSignExtendedGatherOpcode(Opcode);
47394771

4740-
if (IsFixedLength) {
4741-
if (Index.getSimpleValueType().isFixedLengthVector())
4742-
Index = convertToScalableVector(DAG, IndexVT, Index);
4743-
if (BasePtr.getSimpleValueType().isFixedLengthVector())
4744-
BasePtr = convertToScalableVector(DAG, IndexVT, BasePtr);
4745-
Mask = convertFixedMaskToScalableVector(Mask, DAG);
4746-
}
4747-
47484772
SDValue Ops[] = {Chain, Mask, BasePtr, Index, InputVT};
47494773
SDValue Result = DAG.getNode(Opcode, DL, VTs, Ops);
47504774
Chain = Result.getValue(1);
47514775

4752-
if (IsFixedLength) {
4753-
Result = convertFromScalableVector(
4754-
DAG, VT.changeVectorElementType(IndexVT.getVectorElementType()),
4755-
Result);
4756-
Result = DAG.getNode(ISD::TRUNCATE, DL, VT.changeTypeToInteger(), Result);
4757-
Result = DAG.getNode(ISD::BITCAST, DL, VT, Result);
4758-
} else if (VT.isFloatingPoint())
4776+
if (VT.isFloatingPoint())
47594777
Result = getSVESafeBitCast(VT, Result, DAG);
47604778

47614779
return DAG.getMergeValues({Result, Chain}, DL);
@@ -4775,6 +4793,7 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
47754793
EVT VT = StoreVal.getValueType();
47764794
EVT MemVT = MSC->getMemoryVT();
47774795
ISD::MemIndexType IndexType = MSC->getIndexType();
4796+
bool Truncating = MSC->isTruncatingStore();
47784797

47794798
bool IsScaled = MSC->isIndexScaled();
47804799
bool IsSigned = MSC->isIndexSigned();
@@ -4791,42 +4810,60 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
47914810

47924811
SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
47934812
return DAG.getMaskedScatter(MSC->getVTList(), MemVT, DL, Ops,
4794-
MSC->getMemOperand(), IndexType,
4795-
MSC->isTruncatingStore());
4813+
MSC->getMemOperand(), IndexType, Truncating);
4814+
}
4815+
4816+
// Lower fixed length scatter to a scalable equivalent.
4817+
if (VT.isFixedLengthVector()) {
4818+
assert(Subtarget->useSVEForFixedLengthVectors() &&
4819+
"Cannot lower when not using SVE for fixed vectors!");
4820+
4821+
// Once bitcast we treat floating-point scatters as if integer.
4822+
if (VT.isFloatingPoint()) {
4823+
VT = VT.changeVectorElementTypeToInteger();
4824+
MemVT = MemVT.changeVectorElementTypeToInteger();
4825+
StoreVal = DAG.getNode(ISD::BITCAST, DL, VT, StoreVal);
4826+
}
4827+
4828+
// Find the smallest integer fixed length vector we can use for the scatter.
4829+
EVT PromotedVT = VT.changeVectorElementType(MVT::i32);
4830+
if (VT.getVectorElementType() == MVT::i64 ||
4831+
Index.getValueType().getVectorElementType() == MVT::i64 ||
4832+
Mask.getValueType().getVectorElementType() == MVT::i64)
4833+
PromotedVT = VT.changeVectorElementType(MVT::i64);
4834+
4835+
// Promote vector operands.
4836+
unsigned ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
4837+
Index = DAG.getNode(ExtOpcode, DL, PromotedVT, Index);
4838+
Mask = DAG.getNode(ISD::SIGN_EXTEND, DL, PromotedVT, Mask);
4839+
StoreVal = DAG.getNode(ISD::ANY_EXTEND, DL, PromotedVT, StoreVal);
4840+
4841+
// A promoted value type forces the need for a truncating store.
4842+
if (PromotedVT != VT)
4843+
Truncating = true;
4844+
4845+
EVT ContainerVT = getContainerForFixedLengthVector(DAG, PromotedVT);
4846+
4847+
// Convert fixed length vector operands to scalable.
4848+
MemVT = ContainerVT.changeVectorElementType(MemVT.getVectorElementType());
4849+
Index = convertToScalableVector(DAG, ContainerVT, Index);
4850+
Mask = convertFixedMaskToScalableVector(Mask, DAG);
4851+
StoreVal = convertToScalableVector(DAG, ContainerVT, StoreVal);
4852+
4853+
// Emit equivalent scalable vector scatter.
4854+
SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
4855+
return DAG.getMaskedScatter(MSC->getVTList(), MemVT, DL, Ops,
4856+
MSC->getMemOperand(), IndexType, Truncating);
47964857
}
47974858

47984859
bool NeedsExtend =
47994860
getGatherScatterIndexIsExtended(Index) ||
48004861
Index.getSimpleValueType().getVectorElementType() == MVT::i32;
48014862

4802-
EVT IndexVT = Index.getSimpleValueType();
48034863
SDVTList VTs = DAG.getVTList(MVT::Other);
48044864
SDValue InputVT = DAG.getValueType(MemVT);
48054865

4806-
bool IsFixedLength = MSC->getMemoryVT().isFixedLengthVector();
4807-
4808-
if (IsFixedLength) {
4809-
assert(Subtarget->useSVEForFixedLengthVectors() &&
4810-
"Cannot lower when not using SVE for fixed vectors");
4811-
if (MemVT.getScalarSizeInBits() <= IndexVT.getScalarSizeInBits()) {
4812-
IndexVT = getContainerForFixedLengthVector(DAG, IndexVT);
4813-
MemVT = IndexVT.changeVectorElementType(MemVT.getVectorElementType());
4814-
} else {
4815-
MemVT = getContainerForFixedLengthVector(DAG, MemVT);
4816-
IndexVT = MemVT.changeTypeToInteger();
4817-
}
4818-
InputVT = DAG.getValueType(MemVT.changeTypeToInteger());
4819-
4820-
StoreVal =
4821-
DAG.getNode(ISD::BITCAST, DL, VT.changeTypeToInteger(), StoreVal);
4822-
StoreVal = DAG.getNode(
4823-
ISD::ANY_EXTEND, DL,
4824-
VT.changeVectorElementType(IndexVT.getVectorElementType()), StoreVal);
4825-
StoreVal = convertToScalableVector(DAG, IndexVT, StoreVal);
4826-
Mask = DAG.getNode(
4827-
ISD::SIGN_EXTEND, DL,
4828-
VT.changeVectorElementType(IndexVT.getVectorElementType()), Mask);
4829-
} else if (VT.isFloatingPoint()) {
4866+
if (VT.isFloatingPoint()) {
48304867
// Handle FP data by casting the data so an integer scatter can be used.
48314868
EVT StoreValVT = getPackedSVEVectorVT(VT.getVectorElementCount());
48324869
StoreVal = getSVESafeBitCast(StoreValVT, StoreVal, DAG);
@@ -4840,14 +4877,6 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
48404877
selectGatherScatterAddrMode(BasePtr, Index, IsScaled, MemVT, Opcode,
48414878
/*isGather=*/false, DAG);
48424879

4843-
if (IsFixedLength) {
4844-
if (Index.getSimpleValueType().isFixedLengthVector())
4845-
Index = convertToScalableVector(DAG, IndexVT, Index);
4846-
if (BasePtr.getSimpleValueType().isFixedLengthVector())
4847-
BasePtr = convertToScalableVector(DAG, IndexVT, BasePtr);
4848-
Mask = convertFixedMaskToScalableVector(Mask, DAG);
4849-
}
4850-
48514880
SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, InputVT};
48524881
return DAG.getNode(Opcode, DL, VTs, Ops);
48534882
}

0 commit comments

Comments
 (0)