@@ -489,23 +489,24 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
489
489
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern (
490
490
MLIRContext *context, LinalgPaddingOptions options,
491
491
LinalgTransformationFilter filter, PatternBenefit benefit)
492
- : RewritePattern(MatchAnyOpTypeTag() , benefit, context ),
492
+ : OpInterfaceRewritePattern<LinalgOp>(context , benefit),
493
493
filter(std::move(filter)), options(std::move(options)) {}
494
494
495
495
mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern (
496
496
StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
497
497
LinalgTransformationFilter filter, PatternBenefit benefit)
498
- : RewritePattern(opName, benefit, context, {}), filter(std::move(filter)),
499
- options(std::move(options)) {}
498
+ : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
499
+ filter(std::move(filter)), options(std::move(options)) {
500
+ this ->filter .addFilter ([opName](Operation *op) {
501
+ return success (op->getName ().getStringRef () == opName);
502
+ });
503
+ }
500
504
501
505
LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite (
502
- Operation *op, PatternRewriter &rewriter) const {
503
- LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
504
- if (!linalgOp)
505
- return failure ();
506
+ LinalgOp linalgOp, PatternRewriter &rewriter) const {
506
507
if (!linalgOp.hasTensorSemantics ())
507
508
return failure ();
508
- if (failed (filter.checkAndNotify (rewriter, op )))
509
+ if (failed (filter.checkAndNotify (rewriter, linalgOp )))
509
510
return failure ();
510
511
511
512
// Pad the operation.
@@ -538,7 +539,7 @@ LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
538
539
}
539
540
540
541
// Replace the original operation to pad.
541
- rewriter.replaceOp (op , newResults.getValue ());
542
+ rewriter.replaceOp (linalgOp , newResults.getValue ());
542
543
filter.replaceLinalgTransformationFilter (rewriter, paddedOp);
543
544
return success ();
544
545
}
0 commit comments