Skip to content

Commit 635ab51

Browse files
authored
[VectorCombine] Fold vector.interleave2 with two constant splats (llvm#125144)
If we're interleaving 2 constant splats, for instance `<vscale x 8 x i32> <splat of 666>` and `<vscale x 8 x i32> <splat of 777>`, we can create a larger splat `<vscale x 8 x i64> <splat of ((777 << 32) | 666)>` first before casting it back into `<vscale x 16 x i32>`.
1 parent d810c74 commit 635ab51

File tree

3 files changed

+81
-0
lines changed

3 files changed

+81
-0
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

+43
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ class VectorCombine {
126126
bool foldShuffleFromReductions(Instruction &I);
127127
bool foldCastFromReductions(Instruction &I);
128128
bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
129+
bool foldInterleaveIntrinsics(Instruction &I);
129130
bool shrinkType(Instruction &I);
130131

131132
void replaceValue(Value &Old, Value &New) {
@@ -3204,6 +3205,47 @@ bool VectorCombine::foldInsExtVectorToShuffle(Instruction &I) {
32043205
return true;
32053206
}
32063207

3208+
/// If we're interleaving 2 constant splats, for instance `<vscale x 8 x i32>
3209+
/// <splat of 666>` and `<vscale x 8 x i32> <splat of 777>`, we can create a
3210+
/// larger splat `<vscale x 8 x i64> <splat of ((777 << 32) | 666)>` first
3211+
/// before casting it back into `<vscale x 16 x i32>`.
3212+
bool VectorCombine::foldInterleaveIntrinsics(Instruction &I) {
3213+
const APInt *SplatVal0, *SplatVal1;
3214+
if (!match(&I, m_Intrinsic<Intrinsic::vector_interleave2>(
3215+
m_APInt(SplatVal0), m_APInt(SplatVal1))))
3216+
return false;
3217+
3218+
LLVM_DEBUG(dbgs() << "VC: Folding interleave2 with two splats: " << I
3219+
<< "\n");
3220+
3221+
auto *VTy =
3222+
cast<VectorType>(cast<IntrinsicInst>(I).getArgOperand(0)->getType());
3223+
auto *ExtVTy = VectorType::getExtendedElementVectorType(VTy);
3224+
unsigned Width = VTy->getElementType()->getIntegerBitWidth();
3225+
3226+
// Just in case the cost of interleave2 intrinsic and bitcast are both
3227+
// invalid, in which case we want to bail out, we use <= rather
3228+
// than < here. Even they both have valid and equal costs, it's probably
3229+
// not a good idea to emit a high-cost constant splat.
3230+
if (TTI.getInstructionCost(&I, CostKind) <=
3231+
TTI.getCastInstrCost(Instruction::BitCast, I.getType(), ExtVTy,
3232+
TTI::CastContextHint::None, CostKind)) {
3233+
LLVM_DEBUG(dbgs() << "VC: The cost to cast from " << *ExtVTy << " to "
3234+
<< *I.getType() << " is too high.\n");
3235+
return false;
3236+
}
3237+
3238+
APInt NewSplatVal = SplatVal1->zext(Width * 2);
3239+
NewSplatVal <<= Width;
3240+
NewSplatVal |= SplatVal0->zext(Width * 2);
3241+
auto *NewSplat = ConstantVector::getSplat(
3242+
ExtVTy->getElementCount(), ConstantInt::get(F.getContext(), NewSplatVal));
3243+
3244+
IRBuilder<> Builder(&I);
3245+
replaceValue(I, *Builder.CreateBitCast(NewSplat, I.getType()));
3246+
return true;
3247+
}
3248+
32073249
/// This is the entry point for all transforms. Pass manager differences are
32083250
/// handled in the callers of this function.
32093251
bool VectorCombine::run() {
@@ -3248,6 +3290,7 @@ bool VectorCombine::run() {
32483290
MadeChange |= scalarizeBinopOrCmp(I);
32493291
MadeChange |= scalarizeLoadExtract(I);
32503292
MadeChange |= scalarizeVPIntrinsic(I);
3293+
MadeChange |= foldInterleaveIntrinsics(I);
32513294
}
32523295

32533296
if (Opcode == Instruction::Store)
Original file line numberDiff line numberDiff line change
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -mtriple=riscv64 -mattr=+v %s -passes=vector-combine | FileCheck %s
; RUN: opt -S -mtriple=riscv32 -mattr=+v %s -passes=vector-combine | FileCheck %s

; Negative test: folding two nxv4i64 splats would require an i128 element
; type for the combined splat, which is not a legal vector element here.
; We should not form a i128 vector.

define void @interleave2_const_splat_nxv8i64(ptr %dst) {
; CHECK-LABEL: define void @interleave2_const_splat_nxv8i64(
; CHECK-SAME: ptr [[DST:%.*]]) #[[ATTR0:[0-9]+]] {
; CHECK-NEXT:    [[INTERLEAVE2:%.*]] = call <vscale x 8 x i64> @llvm.vector.interleave2.nxv8i64(<vscale x 4 x i64> splat (i64 666), <vscale x 4 x i64> splat (i64 777))
; CHECK-NEXT:    call void @llvm.vp.store.nxv8i64.p0(<vscale x 8 x i64> [[INTERLEAVE2]], ptr [[DST]], <vscale x 8 x i1> splat (i1 true), i32 88)
; CHECK-NEXT:    ret void
;
  %interleave2 = call <vscale x 8 x i64> @llvm.vector.interleave2.nxv8i64(<vscale x 4 x i64> splat (i64 666), <vscale x 4 x i64> splat (i64 777))
  call void @llvm.vp.store.nxv8i64.p0(<vscale x 8 x i64> %interleave2, ptr %dst, <vscale x 8 x i1> splat (i1 true), i32 88)
  ret void
}
Original file line numberDiff line numberDiff line change
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -mtriple=riscv64 -mattr=+v %s -passes=vector-combine | FileCheck %s
; RUN: opt -S -mtriple=riscv32 -mattr=+v %s -passes=vector-combine | FileCheck %s
; RUN: opt -S -mtriple=riscv64 -mattr=+zve32x %s -passes=vector-combine | FileCheck %s --check-prefix=ZVE32X

; Positive test with +v: the two i32 splats (666 and 777) fold into one
; nxv8i64 splat of (777 << 32) | 666 = 3337189589658, bitcast to nxv16i32.
; With +zve32x (no i64 vector elements) the fold must not happen.
define void @interleave2_const_splat_nxv16i32(ptr %dst) {
; CHECK-LABEL: define void @interleave2_const_splat_nxv16i32(
; CHECK-SAME: ptr [[DST:%.*]]) #[[ATTR0:[0-9]+]] {
; CHECK-NEXT:    call void @llvm.vp.store.nxv16i32.p0(<vscale x 16 x i32> bitcast (<vscale x 8 x i64> splat (i64 3337189589658) to <vscale x 16 x i32>), ptr [[DST]], <vscale x 16 x i1> splat (i1 true), i32 88)
; CHECK-NEXT:    ret void
;
; ZVE32X-LABEL: define void @interleave2_const_splat_nxv16i32(
; ZVE32X-SAME: ptr [[DST:%.*]]) #[[ATTR0:[0-9]+]] {
; ZVE32X-NEXT:    [[INTERLEAVE2:%.*]] = call <vscale x 16 x i32> @llvm.vector.interleave2.nxv16i32(<vscale x 8 x i32> splat (i32 666), <vscale x 8 x i32> splat (i32 777))
; ZVE32X-NEXT:    call void @llvm.vp.store.nxv16i32.p0(<vscale x 16 x i32> [[INTERLEAVE2]], ptr [[DST]], <vscale x 16 x i1> splat (i1 true), i32 88)
; ZVE32X-NEXT:    ret void
;
  %interleave2 = call <vscale x 16 x i32> @llvm.vector.interleave2.nxv16i32(<vscale x 8 x i32> splat (i32 666), <vscale x 8 x i32> splat (i32 777))
  call void @llvm.vp.store.nxv16i32.p0(<vscale x 16 x i32> %interleave2, ptr %dst, <vscale x 16 x i1> splat (i1 true), i32 88)
  ret void
}

0 commit comments

Comments
 (0)