File tree 2 files changed +34
-0
lines changed
2 files changed +34
-0
lines changed Original file line number Diff line number Diff line change @@ -1808,12 +1808,34 @@ class ExtractOpNonSplatConstantFolder final
1808
1808
}
1809
1809
};
1810
1810
1811
+ // Folds extract(shape_cast(..)) into shape_cast when the total element count
1812
+ // does not change.
1813
+ LogicalResult foldExtractFromShapeCastToShapeCast (ExtractOp extractOp,
1814
+ PatternRewriter &rewriter) {
1815
+ auto castOp = extractOp.getVector ().getDefiningOp <ShapeCastOp>();
1816
+ if (!castOp)
1817
+ return failure ();
1818
+
1819
+ VectorType sourceType = castOp.getSourceVectorType ();
1820
+ auto targetType = dyn_cast<VectorType>(extractOp.getResult ().getType ());
1821
+ if (!targetType)
1822
+ return failure ();
1823
+
1824
+ if (sourceType.getNumElements () != targetType.getNumElements ())
1825
+ return failure ();
1826
+
1827
+ rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(extractOp, targetType,
1828
+ castOp.getSource ());
1829
+ return success ();
1830
+ }
1831
+
1811
1832
} // namespace
1812
1833
1813
1834
void ExtractOp::getCanonicalizationPatterns (RewritePatternSet &results,
1814
1835
MLIRContext *context) {
1815
1836
results.add <ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
1816
1837
ExtractOpFromBroadcast>(context);
1838
+ results.add (foldExtractFromShapeCastToShapeCast);
1817
1839
}
1818
1840
1819
1841
static void populateFromInt64AttrArray (ArrayAttr arrayAttr,
Original file line number Diff line number Diff line change @@ -669,6 +669,18 @@ func.func @dont_fold_0d_extract_shapecast(%arg0 : vector<f32>) -> f32 {
669
669
670
670
// -----
671
671
672
+ // CHECK-LABEL: fold_extract_shapecast_to_shapecast
673
+ // CHECK-SAME: (%[[ARG:.+]]: vector<3x4xf32>)
674
+ // CHECK: %[[R:.+]] = vector.shape_cast %[[ARG]] : vector<3x4xf32> to vector<12xf32>
675
+ // CHECK: return %[[R]]
676
+ func.func @fold_extract_shapecast_to_shapecast (%arg0 : vector <3 x4 xf32 >) -> vector <12 xf32 > {
677
+ %0 = vector.shape_cast %arg0 : vector <3 x4 xf32 > to vector <1 x12 xf32 >
678
+ %r = vector.extract %0 [0 ] : vector <1 x12 xf32 >
679
+ return %r : vector <12 xf32 >
680
+ }
681
+
682
+ // -----
683
+
672
684
// CHECK-LABEL: dont_fold_expand_collapse
673
685
// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
674
686
// CHECK: %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>
You can’t perform that action at this time.
0 commit comments