@@ -3795,6 +3795,95 @@ OpFoldResult SubTensorInsertOp::fold(ArrayRef<Attribute>) {
3795
3795
return OpFoldResult ();
3796
3796
}
3797
3797
3798
+ namespace {
3799
+ // / Pattern to rewrite a subtensor_insert op with constant arguments.
3800
+ class SubTensorInsertOpConstantArgumentFolder final
3801
+ : public OpRewritePattern<SubTensorInsertOp> {
3802
+ public:
3803
+ using OpRewritePattern<SubTensorInsertOp>::OpRewritePattern;
3804
+
3805
+ LogicalResult matchAndRewrite (SubTensorInsertOp subTensorInsertOp,
3806
+ PatternRewriter &rewriter) const override {
3807
+ // No constant operand, just return.
3808
+ if (llvm::none_of (subTensorInsertOp.getOperands (), [](Value operand) {
3809
+ return matchPattern (operand, m_ConstantIndex ());
3810
+ }))
3811
+ return failure ();
3812
+
3813
+ // At least one of offsets/sizes/strides is a new constant.
3814
+ // Form the new list of operands and constant attributes from the existing.
3815
+ SmallVector<OpFoldResult> mixedOffsets (subTensorInsertOp.getMixedOffsets ());
3816
+ SmallVector<OpFoldResult> mixedSizes (subTensorInsertOp.getMixedSizes ());
3817
+ SmallVector<OpFoldResult> mixedStrides (subTensorInsertOp.getMixedStrides ());
3818
+ canonicalizeSubViewPart (mixedOffsets, ShapedType::isDynamicStrideOrOffset);
3819
+ canonicalizeSubViewPart (mixedSizes, ShapedType::isDynamic);
3820
+ canonicalizeSubViewPart (mixedStrides, ShapedType::isDynamicStrideOrOffset);
3821
+
3822
+ // Create the new op in canonical form.
3823
+ Value source = subTensorInsertOp.source ();
3824
+ RankedTensorType sourceType = source.getType ().cast <RankedTensorType>();
3825
+ SmallVector<int64_t , 4 > shape = llvm::to_vector<4 >(
3826
+ llvm::map_range (mixedSizes, [](OpFoldResult valueOrAttr) -> int64_t {
3827
+ if (auto attr = valueOrAttr.dyn_cast <Attribute>())
3828
+ return attr.cast <IntegerAttr>().getInt ();
3829
+ return ShapedType::kDynamicSize ;
3830
+ }));
3831
+ RankedTensorType newSourceType =
3832
+ RankedTensorType::get (shape, sourceType.getElementType ());
3833
+ Location loc = subTensorInsertOp.getLoc ();
3834
+ if (sourceType != newSourceType)
3835
+ source = rewriter.create <tensor::CastOp>(loc, newSourceType, source);
3836
+ rewriter.replaceOpWithNewOp <SubTensorInsertOp>(
3837
+ subTensorInsertOp, source, subTensorInsertOp.dest (), mixedOffsets,
3838
+ mixedSizes, mixedStrides);
3839
+ return success ();
3840
+ }
3841
+ };
3842
+
3843
+ // / Fold tensor_casts with subtensor_insert operations.
3844
+ struct SubTensorInsertOpCastFolder final
3845
+ : public OpRewritePattern<SubTensorInsertOp> {
3846
+ using OpRewritePattern<SubTensorInsertOp>::OpRewritePattern;
3847
+
3848
+ LogicalResult matchAndRewrite (SubTensorInsertOp subTensorOp,
3849
+ PatternRewriter &rewriter) const override {
3850
+ if (llvm::any_of (subTensorOp.getOperands (), [](Value operand) {
3851
+ return matchPattern (operand, m_ConstantIndex ());
3852
+ }))
3853
+ return failure ();
3854
+
3855
+ auto getSourceOfCastOp = [](Value v) -> Optional<Value> {
3856
+ auto castOp = v.getDefiningOp <tensor::CastOp>();
3857
+ if (!castOp || !canFoldIntoConsumerOp (castOp))
3858
+ return llvm::None;
3859
+ return castOp.source ();
3860
+ };
3861
+ Optional<Value> sourceCastSource = getSourceOfCastOp (subTensorOp.source ());
3862
+ Optional<Value> destCastSource = getSourceOfCastOp (subTensorOp.dest ());
3863
+ if (!sourceCastSource && !destCastSource &&
3864
+ subTensorOp.dest ().getType () == subTensorOp.getResult ().getType ())
3865
+ return failure ();
3866
+
3867
+ auto newOp = rewriter.create <SubTensorInsertOp>(
3868
+ subTensorOp.getLoc (),
3869
+ (sourceCastSource ? *sourceCastSource : subTensorOp.source ()),
3870
+ (destCastSource ? *destCastSource : subTensorOp.dest ()),
3871
+ subTensorOp.getMixedOffsets (), subTensorOp.getMixedSizes (),
3872
+ subTensorOp.getMixedStrides ());
3873
+
3874
+ rewriter.replaceOpWithNewOp <tensor::CastOp>(subTensorOp,
3875
+ subTensorOp.getType (), newOp);
3876
+ return success ();
3877
+ }
3878
+ };
3879
+ } // namespace
3880
+
3881
+ void SubTensorInsertOp::getCanonicalizationPatterns (
3882
+ OwningRewritePatternList &results, MLIRContext *context) {
3883
+ results.insert <SubTensorInsertOpConstantArgumentFolder,
3884
+ SubTensorInsertOpCastFolder>(context);
3885
+ }
3886
+
3798
3887
// ===----------------------------------------------------------------------===//
3799
3888
// TensorLoadOp
3800
3889
// ===----------------------------------------------------------------------===//
0 commit comments