Skip to content

Commit 7ae78a6

Browse files
[mlir][vector]add extractInsertFoldConstantOp fold function and apply it to extractOp and insertOp. (llvm#124399)
add extractInsertFoldConstantOp fold function and apply it to extractOp and insertOp.
1 parent 6b52fb2 commit 7ae78a6

File tree

4 files changed

+114
-2
lines changed

4 files changed

+114
-2
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1989,6 +1989,45 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
19891989
return fromElementsOp.getElements()[flatIndex];
19901990
}
19911991

1992+
/// If the dynamic indices of `extractOp` or `insertOp` are in fact constants,
1993+
/// then fold it.
1994+
template <typename OpType, typename AdaptorType>
1995+
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
1996+
SmallVectorImpl<Value> &operands) {
1997+
std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
1998+
OperandRange dynamicPosition = op.getDynamicPosition();
1999+
ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
2000+
2001+
// If the dynamic operands is empty, it is returned directly.
2002+
if (!dynamicPosition.size())
2003+
return {};
2004+
2005+
// `index` is used to iterate over the `dynamicPosition`.
2006+
unsigned index = 0;
2007+
2008+
// `opChange` is a flag. If it is true, it means to update `op` in place.
2009+
bool opChange = false;
2010+
for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
2011+
if (!ShapedType::isDynamic(staticPosition[i]))
2012+
continue;
2013+
Attribute positionAttr = dynamicPositionAttr[index];
2014+
Value position = dynamicPosition[index++];
2015+
if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2016+
staticPosition[i] = attr.getInt();
2017+
opChange = true;
2018+
continue;
2019+
}
2020+
operands.push_back(position);
2021+
}
2022+
2023+
if (opChange) {
2024+
op.setStaticPosition(staticPosition);
2025+
op.getOperation()->setOperands(operands);
2026+
return op.getResult();
2027+
}
2028+
return {};
2029+
}
2030+
19922031
/// Fold an insert or extract operation into an poison value when a poison index
19932032
/// is found at any dimension of the static position.
19942033
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
@@ -2035,6 +2074,9 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
20352074
return val;
20362075
if (auto val = foldScalarExtractFromFromElements(*this))
20372076
return val;
2077+
SmallVector<Value> operands = {getVector()};
2078+
if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
2079+
return val;
20382080
return OpFoldResult();
20392081
}
20402082

@@ -3094,6 +3136,9 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
30943136
// (type mismatch).
30953137
if (getNumIndices() == 0 && getSourceType() == getType())
30963138
return getSource();
3139+
SmallVector<Value> operands = {getSource(), getDest()};
3140+
if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
3141+
return val;
30973142
if (auto res = foldPoisonIndexInsertExtractOp(
30983143
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
30993144
return res;

mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,25 @@ func.func @extract_scalar_from_vec_0d_index(%arg0: vector<index>) -> index {
530530

531531
// -----
532532

533+
func.func @extract_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const(%arg : vector<32x1xi32>) -> i32 {
534+
%0 = arith.constant 0 : index
535+
%1 = vector.extract %arg[%0, %0] : i32 from vector<32x1xi32>
536+
return %1 : i32
537+
}
538+
539+
// At compile time, since the indices of extractOp are constants,
540+
// they will be collapsed and folded away; therefore, the lowering works.
541+
542+
// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const
543+
// CHECK-SAME: %[[ARG:.*]]: vector<32x1xi32>) -> i32 {
544+
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
545+
// CHECK: %[[VEC_0:.*]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<32 x vector<1xi32>>
546+
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
547+
// CHECK: %[[RES:.*]] = llvm.extractelement %[[VEC_0]]{{\[}}%[[C0]] : i64] : vector<1xi32>
548+
// CHECK: return %[[RES]] : i32
549+
550+
// -----
551+
533552
//===----------------------------------------------------------------------===//
534553
// vector.insertelement
535554
//===----------------------------------------------------------------------===//
@@ -781,6 +800,29 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[1
781800

782801
// -----
783802

803+
func.func @insert_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const(%arg : vector<4x1xi32>) -> vector<4x1xi32> {
804+
%0 = arith.constant 0 : index
805+
%1 = arith.constant 1 : i32
806+
%res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
807+
return %res : vector<4x1xi32>
808+
}
809+
810+
// At compile time, since the indices of insertOp are constants,
811+
// they will be collapsed and folded away; therefore, the lowering works.
812+
813+
// CHECK-LABEL: @insert_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const
814+
// CHECK-SAME: %[[ARG:.*]]: vector<4x1xi32>) -> vector<4x1xi32> {
815+
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x1xi32> to !llvm.array<4 x vector<1xi32>>
816+
// CHECK: %[[C1:.*]] = arith.constant 1 : i32
817+
// CHECK: %[[VEC_0:.*]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<4 x vector<1xi32>>
818+
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
819+
// CHECK: %[[VEC_1:.*]] = llvm.insertelement %[[C1]], %[[VEC_0]]{{\[}}%[[C0]] : i64] : vector<1xi32>
820+
// CHECK: %[[VEC_2:.*]] = llvm.insertvalue %[[VEC_1]], %[[CAST]][0] : !llvm.array<4 x vector<1xi32>>
821+
// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[VEC_2]] : !llvm.array<4 x vector<1xi32>> to vector<4x1xi32>
822+
// CHECK: return %[[RES]] : vector<4x1xi32>
823+
824+
// -----
825+
784826
//===----------------------------------------------------------------------===//
785827
// vector.type_cast
786828
//

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3171,3 +3171,29 @@ func.func @contiguous_scatter_step(%base: memref<?xf32>,
31713171
memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
31723172
return
31733173
}
3174+
3175+
// -----
3176+
3177+
// CHECK-LABEL: @fold_extract_constant_indices
3178+
// CHECK-SAME: %[[ARG:.*]]: vector<32x1xi32>) -> i32 {
3179+
// CHECK: %[[RES:.*]] = vector.extract %[[ARG]][0, 0] : i32 from vector<32x1xi32>
3180+
// CHECK: return %[[RES]] : i32
3181+
func.func @fold_extract_constant_indices(%arg : vector<32x1xi32>) -> i32 {
3182+
%0 = arith.constant 0 : index
3183+
%1 = vector.extract %arg[%0, %0] : i32 from vector<32x1xi32>
3184+
return %1 : i32
3185+
}
3186+
3187+
// -----
3188+
3189+
// CHECK-LABEL: @fold_insert_constant_indices
3190+
// CHECK-SAME: %[[ARG:.*]]: vector<4x1xi32>) -> vector<4x1xi32> {
3191+
// CHECK: %[[VAL:.*]] = arith.constant 1 : i32
3192+
// CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] [0, 0] : i32 into vector<4x1xi32>
3193+
// CHECK: return %[[RES]] : vector<4x1xi32>
3194+
func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi32> {
3195+
%0 = arith.constant 0 : index
3196+
%1 = arith.constant 1 : i32
3197+
%res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
3198+
return %res : vector<4x1xi32>
3199+
}

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -778,12 +778,11 @@ func.func @warp_constant(%laneid: index) -> (vector<1xf32>) {
778778

779779
// CHECK-PROP-LABEL: func.func @vector_extract_1d(
780780
// CHECK-PROP-DAG: %[[C5_I32:.*]] = arith.constant 5 : i32
781-
// CHECK-PROP-DAG: %[[C1:.*]] = arith.constant 1 : index
782781
// CHECK-PROP: %[[R:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2xf32>) {
783782
// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<64xf32>
784783
// CHECK-PROP: gpu.yield %[[V]] : vector<64xf32>
785784
// CHECK-PROP: }
786-
// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][%[[C1]]] : f32 from vector<2xf32>
785+
// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][1] : f32 from vector<2xf32>
787786
// CHECK-PROP: %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle idx %[[E]], %[[C5_I32]]
788787
// CHECK-PROP: return %[[SHUFFLED]] : f32
789788
func.func @vector_extract_1d(%laneid: index) -> (f32) {

0 commit comments

Comments
 (0)