Skip to content

Commit 7897a94

Browse files
committed
[mlir][vector] Fold extract(shape_cast) for same element count
Reviewed By: hanchung Differential Revision: https://reviews.llvm.org/D157930
1 parent 9f37c21 commit 7897a94

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

+22
Original file line numberDiff line numberDiff line change
@@ -1808,12 +1808,34 @@ class ExtractOpNonSplatConstantFolder final
18081808
}
18091809
};
18101810

1811+
// Folds extract(shape_cast(..)) into shape_cast when the total element count
1812+
// does not change.
1813+
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
1814+
PatternRewriter &rewriter) {
1815+
auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
1816+
if (!castOp)
1817+
return failure();
1818+
1819+
VectorType sourceType = castOp.getSourceVectorType();
1820+
auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
1821+
if (!targetType)
1822+
return failure();
1823+
1824+
if (sourceType.getNumElements() != targetType.getNumElements())
1825+
return failure();
1826+
1827+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
1828+
castOp.getSource());
1829+
return success();
1830+
}
1831+
18111832
} // namespace
18121833

18131834
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
18141835
MLIRContext *context) {
18151836
results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
18161837
ExtractOpFromBroadcast>(context);
1838+
results.add(foldExtractFromShapeCastToShapeCast);
18171839
}
18181840

18191841
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,

mlir/test/Dialect/Vector/canonicalize.mlir

+12
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,18 @@ func.func @dont_fold_0d_extract_shapecast(%arg0 : vector<f32>) -> f32 {
669669

670670
// -----
671671

672+
// CHECK-LABEL: fold_extract_shapecast_to_shapecast
673+
// CHECK-SAME: (%[[ARG:.+]]: vector<3x4xf32>)
674+
// CHECK: %[[R:.+]] = vector.shape_cast %[[ARG]] : vector<3x4xf32> to vector<12xf32>
675+
// CHECK: return %[[R]]
676+
func.func @fold_extract_shapecast_to_shapecast(%arg0 : vector<3x4xf32>) -> vector<12xf32> {
677+
%0 = vector.shape_cast %arg0 : vector<3x4xf32> to vector<1x12xf32>
678+
%r = vector.extract %0[0] : vector<1x12xf32>
679+
return %r : vector<12xf32>
680+
}
681+
682+
// -----
683+
672684
// CHECK-LABEL: dont_fold_expand_collapse
673685
// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
674686
// CHECK: %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>

0 commit comments

Comments
 (0)