Skip to content

Commit d84450b

Browse files
suryajasperantiagainst
authored andcommitted
[mlir][linalg] Canonicalize tensor.extract(linalg.fill)
Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D156008
1 parent 558ab65 commit d84450b

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -730,12 +730,34 @@ struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
730730
}
731731
};
732732

733+
/// Fold tensor.extract(linalg.fill(<input>)) into <input>
734+
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
735+
public:
736+
using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
737+
738+
LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
739+
PatternRewriter &rewriter) const override {
740+
// See if tensor input of tensor.extract op is the result of a linalg.fill op.
741+
auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
742+
if (!fillOp)
743+
return failure();
744+
745+
// Get scalar input operand of linalg.fill op.
746+
Value extractedScalar = fillOp.getInputs()[0];
747+
748+
// Replace tensor.extract op with scalar value used to fill the tensor.
749+
rewriter.replaceOp(extractOp, extractedScalar);
750+
return success();
751+
}
752+
};
753+
733754
} // namespace
734755

735756
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
736757
MLIRContext *context) {
737758
results
738-
.add<FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
759+
.add<FoldFillWithTensorExtract,
760+
FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
739761
FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
740762
FoldInsertPadIntoFill>(context);
741763
}

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,22 @@ func.func @fold_fill_reshape_dynamic(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x
335335
return %1 : tensor<?x?xf32>
336336
}
337337

338+
// -----
339+
// CHECK: func @fold_fill_extract
340+
// CHECK-SAME: %[[ARG0:.+]]: i1
341+
func.func @fold_fill_extract(%arg0 : i1) -> i1 {
342+
%c0 = arith.constant 0 : index
343+
%c1 = arith.constant 1 : index
344+
345+
%empty_dynamic = tensor.empty(%c1) : tensor<1x2x3x?xi1>
346+
%filled = linalg.fill ins(%arg0 : i1) outs(%empty_dynamic : tensor<1x2x3x?xi1>) -> tensor<1x2x3x?xi1>
347+
348+
%extracted = tensor.extract %filled[%c0, %c0, %c0, %c0] : tensor<1x2x3x?xi1>
349+
350+
// CHECK: return %[[ARG0]]
351+
return %extracted : i1
352+
}
353+
338354
// -----
339355

340356
// CHECK: func @fold_self_copy

0 commit comments

Comments
 (0)