@@ -4696,33 +4696,65 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
4696
4696
MGT->getMemOperand(), IndexType, ExtType);
4697
4697
}
4698
4698
4699
+ // Lower fixed length gather to a scalable equivalent.
4700
+ if (VT.isFixedLengthVector()) {
4701
+ assert(Subtarget->useSVEForFixedLengthVectors() &&
4702
+ "Cannot lower when not using SVE for fixed vectors!");
4703
+
4704
+ // NOTE: Handle floating-point as if integer then bitcast the result.
4705
+ EVT DataVT = VT.changeVectorElementTypeToInteger();
4706
+ MemVT = MemVT.changeVectorElementTypeToInteger();
4707
+
4708
+ // Find the smallest integer fixed length vector we can use for the gather.
4709
+ EVT PromotedVT = VT.changeVectorElementType(MVT::i32);
4710
+ if (DataVT.getVectorElementType() == MVT::i64 ||
4711
+ Index.getValueType().getVectorElementType() == MVT::i64 ||
4712
+ Mask.getValueType().getVectorElementType() == MVT::i64)
4713
+ PromotedVT = VT.changeVectorElementType(MVT::i64);
4714
+
4715
+ // Promote vector operands except for passthrough, which we know is either
4716
+ // undef or zero, and thus best constructed directly.
4717
+ unsigned ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
4718
+ Index = DAG.getNode(ExtOpcode, DL, PromotedVT, Index);
4719
+ Mask = DAG.getNode(ISD::SIGN_EXTEND, DL, PromotedVT, Mask);
4720
+
4721
+ // A promoted result type forces the need for an extending load.
4722
+ if (PromotedVT != DataVT && ExtType == ISD::NON_EXTLOAD)
4723
+ ExtType = ISD::EXTLOAD;
4724
+
4725
+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, PromotedVT);
4726
+
4727
+ // Convert fixed length vector operands to scalable.
4728
+ MemVT = ContainerVT.changeVectorElementType(MemVT.getVectorElementType());
4729
+ Index = convertToScalableVector(DAG, ContainerVT, Index);
4730
+ Mask = convertFixedMaskToScalableVector(Mask, DAG);
4731
+ PassThru = PassThru->isUndef() ? DAG.getUNDEF(ContainerVT)
4732
+ : DAG.getConstant(0, DL, ContainerVT);
4733
+
4734
+ // Emit equivalent scalable vector gather.
4735
+ SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
4736
+ SDValue Load =
4737
+ DAG.getMaskedGather(DAG.getVTList(ContainerVT, MVT::Other), MemVT, DL,
4738
+ Ops, MGT->getMemOperand(), IndexType, ExtType);
4739
+
4740
+ // Extract fixed length data then convert to the required result type.
4741
+ SDValue Result = convertFromScalableVector(DAG, PromotedVT, Load);
4742
+ Result = DAG.getNode(ISD::TRUNCATE, DL, DataVT, Result);
4743
+ if (VT.isFloatingPoint())
4744
+ Result = DAG.getNode(ISD::BITCAST, DL, VT, Result);
4745
+
4746
+ return DAG.getMergeValues({Result, Load.getValue(1)}, DL);
4747
+ }
4748
+
4699
4749
bool IdxNeedsExtend =
4700
4750
getGatherScatterIndexIsExtended(Index) ||
4701
4751
Index.getSimpleValueType().getVectorElementType() == MVT::i32;
4702
4752
4703
4753
EVT IndexVT = Index.getSimpleValueType();
4704
4754
SDValue InputVT = DAG.getValueType(MemVT);
4705
4755
4706
- bool IsFixedLength = MGT->getMemoryVT().isFixedLengthVector();
4707
-
4708
- if (IsFixedLength) {
4709
- assert(Subtarget->useSVEForFixedLengthVectors() &&
4710
- "Cannot lower when not using SVE for fixed vectors");
4711
- if (MemVT.getScalarSizeInBits() <= IndexVT.getScalarSizeInBits()) {
4712
- IndexVT = getContainerForFixedLengthVector(DAG, IndexVT);
4713
- MemVT = IndexVT.changeVectorElementType(MemVT.getVectorElementType());
4714
- } else {
4715
- MemVT = getContainerForFixedLengthVector(DAG, MemVT);
4716
- IndexVT = MemVT.changeTypeToInteger();
4717
- }
4718
- InputVT = DAG.getValueType(MemVT.changeTypeToInteger());
4719
- Mask = DAG.getNode(
4720
- ISD::SIGN_EXTEND, DL,
4721
- VT.changeVectorElementType(IndexVT.getVectorElementType()), Mask);
4722
- }
4723
-
4724
4756
// Handle FP data by using an integer gather and casting the result.
4725
- if (VT.isFloatingPoint() && !IsFixedLength )
4757
+ if (VT.isFloatingPoint())
4726
4758
InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
4727
4759
4728
4760
SDVTList VTs = DAG.getVTList(IndexVT, MVT::Other);
@@ -4737,25 +4769,11 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
4737
4769
if (ExtType == ISD::SEXTLOAD)
4738
4770
Opcode = getSignExtendedGatherOpcode(Opcode);
4739
4771
4740
- if (IsFixedLength) {
4741
- if (Index.getSimpleValueType().isFixedLengthVector())
4742
- Index = convertToScalableVector(DAG, IndexVT, Index);
4743
- if (BasePtr.getSimpleValueType().isFixedLengthVector())
4744
- BasePtr = convertToScalableVector(DAG, IndexVT, BasePtr);
4745
- Mask = convertFixedMaskToScalableVector(Mask, DAG);
4746
- }
4747
-
4748
4772
SDValue Ops[] = {Chain, Mask, BasePtr, Index, InputVT};
4749
4773
SDValue Result = DAG.getNode(Opcode, DL, VTs, Ops);
4750
4774
Chain = Result.getValue(1);
4751
4775
4752
- if (IsFixedLength) {
4753
- Result = convertFromScalableVector(
4754
- DAG, VT.changeVectorElementType(IndexVT.getVectorElementType()),
4755
- Result);
4756
- Result = DAG.getNode(ISD::TRUNCATE, DL, VT.changeTypeToInteger(), Result);
4757
- Result = DAG.getNode(ISD::BITCAST, DL, VT, Result);
4758
- } else if (VT.isFloatingPoint())
4776
+ if (VT.isFloatingPoint())
4759
4777
Result = getSVESafeBitCast(VT, Result, DAG);
4760
4778
4761
4779
return DAG.getMergeValues({Result, Chain}, DL);
@@ -4775,6 +4793,7 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
4775
4793
EVT VT = StoreVal.getValueType();
4776
4794
EVT MemVT = MSC->getMemoryVT();
4777
4795
ISD::MemIndexType IndexType = MSC->getIndexType();
4796
+ bool Truncating = MSC->isTruncatingStore();
4778
4797
4779
4798
bool IsScaled = MSC->isIndexScaled();
4780
4799
bool IsSigned = MSC->isIndexSigned();
@@ -4791,42 +4810,60 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
4791
4810
4792
4811
SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
4793
4812
return DAG.getMaskedScatter(MSC->getVTList(), MemVT, DL, Ops,
4794
- MSC->getMemOperand(), IndexType,
4795
- MSC->isTruncatingStore());
4813
+ MSC->getMemOperand(), IndexType, Truncating);
4814
+ }
4815
+
4816
+ // Lower fixed length scatter to a scalable equivalent.
4817
+ if (VT.isFixedLengthVector()) {
4818
+ assert(Subtarget->useSVEForFixedLengthVectors() &&
4819
+ "Cannot lower when not using SVE for fixed vectors!");
4820
+
4821
+ // Once bitcast we treat floating-point scatters as if integer.
4822
+ if (VT.isFloatingPoint()) {
4823
+ VT = VT.changeVectorElementTypeToInteger();
4824
+ MemVT = MemVT.changeVectorElementTypeToInteger();
4825
+ StoreVal = DAG.getNode(ISD::BITCAST, DL, VT, StoreVal);
4826
+ }
4827
+
4828
+ // Find the smallest integer fixed length vector we can use for the scatter.
4829
+ EVT PromotedVT = VT.changeVectorElementType(MVT::i32);
4830
+ if (VT.getVectorElementType() == MVT::i64 ||
4831
+ Index.getValueType().getVectorElementType() == MVT::i64 ||
4832
+ Mask.getValueType().getVectorElementType() == MVT::i64)
4833
+ PromotedVT = VT.changeVectorElementType(MVT::i64);
4834
+
4835
+ // Promote vector operands.
4836
+ unsigned ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
4837
+ Index = DAG.getNode(ExtOpcode, DL, PromotedVT, Index);
4838
+ Mask = DAG.getNode(ISD::SIGN_EXTEND, DL, PromotedVT, Mask);
4839
+ StoreVal = DAG.getNode(ISD::ANY_EXTEND, DL, PromotedVT, StoreVal);
4840
+
4841
+ // A promoted value type forces the need for a truncating store.
4842
+ if (PromotedVT != VT)
4843
+ Truncating = true;
4844
+
4845
+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, PromotedVT);
4846
+
4847
+ // Convert fixed length vector operands to scalable.
4848
+ MemVT = ContainerVT.changeVectorElementType(MemVT.getVectorElementType());
4849
+ Index = convertToScalableVector(DAG, ContainerVT, Index);
4850
+ Mask = convertFixedMaskToScalableVector(Mask, DAG);
4851
+ StoreVal = convertToScalableVector(DAG, ContainerVT, StoreVal);
4852
+
4853
+ // Emit equivalent scalable vector scatter.
4854
+ SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
4855
+ return DAG.getMaskedScatter(MSC->getVTList(), MemVT, DL, Ops,
4856
+ MSC->getMemOperand(), IndexType, Truncating);
4796
4857
}
4797
4858
4798
4859
bool NeedsExtend =
4799
4860
getGatherScatterIndexIsExtended(Index) ||
4800
4861
Index.getSimpleValueType().getVectorElementType() == MVT::i32;
4801
4862
4802
- EVT IndexVT = Index.getSimpleValueType();
4803
4863
SDVTList VTs = DAG.getVTList(MVT::Other);
4804
4864
SDValue InputVT = DAG.getValueType(MemVT);
4805
4865
4806
- bool IsFixedLength = MSC->getMemoryVT().isFixedLengthVector();
4807
-
4808
- if (IsFixedLength) {
4809
- assert(Subtarget->useSVEForFixedLengthVectors() &&
4810
- "Cannot lower when not using SVE for fixed vectors");
4811
- if (MemVT.getScalarSizeInBits() <= IndexVT.getScalarSizeInBits()) {
4812
- IndexVT = getContainerForFixedLengthVector(DAG, IndexVT);
4813
- MemVT = IndexVT.changeVectorElementType(MemVT.getVectorElementType());
4814
- } else {
4815
- MemVT = getContainerForFixedLengthVector(DAG, MemVT);
4816
- IndexVT = MemVT.changeTypeToInteger();
4817
- }
4818
- InputVT = DAG.getValueType(MemVT.changeTypeToInteger());
4819
-
4820
- StoreVal =
4821
- DAG.getNode(ISD::BITCAST, DL, VT.changeTypeToInteger(), StoreVal);
4822
- StoreVal = DAG.getNode(
4823
- ISD::ANY_EXTEND, DL,
4824
- VT.changeVectorElementType(IndexVT.getVectorElementType()), StoreVal);
4825
- StoreVal = convertToScalableVector(DAG, IndexVT, StoreVal);
4826
- Mask = DAG.getNode(
4827
- ISD::SIGN_EXTEND, DL,
4828
- VT.changeVectorElementType(IndexVT.getVectorElementType()), Mask);
4829
- } else if (VT.isFloatingPoint()) {
4866
+ if (VT.isFloatingPoint()) {
4830
4867
// Handle FP data by casting the data so an integer scatter can be used.
4831
4868
EVT StoreValVT = getPackedSVEVectorVT(VT.getVectorElementCount());
4832
4869
StoreVal = getSVESafeBitCast(StoreValVT, StoreVal, DAG);
@@ -4840,14 +4877,6 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
4840
4877
selectGatherScatterAddrMode(BasePtr, Index, IsScaled, MemVT, Opcode,
4841
4878
/*isGather=*/false, DAG);
4842
4879
4843
- if (IsFixedLength) {
4844
- if (Index.getSimpleValueType().isFixedLengthVector())
4845
- Index = convertToScalableVector(DAG, IndexVT, Index);
4846
- if (BasePtr.getSimpleValueType().isFixedLengthVector())
4847
- BasePtr = convertToScalableVector(DAG, IndexVT, BasePtr);
4848
- Mask = convertFixedMaskToScalableVector(Mask, DAG);
4849
- }
4850
-
4851
4880
SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, InputVT};
4852
4881
return DAG.getNode(Opcode, DL, VTs, Ops);
4853
4882
}
0 commit comments