Skip to content

Commit 2c4a56c

Browse files
author
Nicolas Vasilache
committed
[mlir][Linalg] NFC - Modernize padding pattern
Differential Revision: https://reviews.llvm.org/D116739
1 parent 43c5e61 commit 2c4a56c

File tree

2 files changed

+12
-11
lines changed

2 files changed

+12
-11
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,7 @@ struct LinalgGenericTilingPattern : public LinalgBaseTilingPattern {
688688
/// Apply the `padding` transformation as a pattern.
689689
/// `filter` controls LinalgTransformMarker matching and update when specified.
690690
/// See `padding` for more details.
691-
struct LinalgPaddingPattern : public RewritePattern {
691+
struct LinalgPaddingPattern : public OpInterfaceRewritePattern<LinalgOp> {
692692
// Entry point to match any LinalgOp OpInterface.
693693
LinalgPaddingPattern(
694694
MLIRContext *context,
@@ -701,7 +701,7 @@ struct LinalgPaddingPattern : public RewritePattern {
701701
LinalgPaddingOptions options = LinalgPaddingOptions(),
702702
LinalgTransformationFilter filter = LinalgTransformationFilter(),
703703
PatternBenefit benefit = 1);
704-
LogicalResult matchAndRewrite(Operation *op,
704+
LogicalResult matchAndRewrite(LinalgOp,
705705
PatternRewriter &rewriter) const override;
706706

707707
private:

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -489,23 +489,24 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
489489
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
490490
MLIRContext *context, LinalgPaddingOptions options,
491491
LinalgTransformationFilter filter, PatternBenefit benefit)
492-
: RewritePattern(MatchAnyOpTypeTag(), benefit, context),
492+
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
493493
filter(std::move(filter)), options(std::move(options)) {}
494494

495495
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
496496
StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
497497
LinalgTransformationFilter filter, PatternBenefit benefit)
498-
: RewritePattern(opName, benefit, context, {}), filter(std::move(filter)),
499-
options(std::move(options)) {}
498+
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
499+
filter(std::move(filter)), options(std::move(options)) {
500+
this->filter.addFilter([opName](Operation *op) {
501+
return success(op->getName().getStringRef() == opName);
502+
});
503+
}
500504

501505
LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
502-
Operation *op, PatternRewriter &rewriter) const {
503-
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
504-
if (!linalgOp)
505-
return failure();
506+
LinalgOp linalgOp, PatternRewriter &rewriter) const {
506507
if (!linalgOp.hasTensorSemantics())
507508
return failure();
508-
if (failed(filter.checkAndNotify(rewriter, op)))
509+
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
509510
return failure();
510511

511512
// Pad the operation.
@@ -538,7 +539,7 @@ LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
538539
}
539540

540541
// Replace the original operation to pad.
541-
rewriter.replaceOp(op, newResults.getValue());
542+
rewriter.replaceOp(linalgOp, newResults.getValue());
542543
filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
543544
return success();
544545
}

0 commit comments

Comments
 (0)