Skip to content

Commit 5267f5e

Browse files
committed
[mlir] Add a hook to PatternRewriter to allow for patterns to notify why a match failed.
Summary: This revision adds a new hook, `notifyMatchFailure`, that allows for notifying the rewriter that a match failure is coming with the provided reason. This hook takes as a parameter a callback that fills a `Diagnostic` instance with the reason why the match failed. This allows for the rewriter to decide how this information can be displayed to the end-user, and may completely ignore it if desired(opt mode). For now, DialectConversion is updated to include this information in the debug output. Differential Revision: https://reviews.llvm.org/D76203
1 parent 1bf0c99 commit 5267f5e

File tree

4 files changed

+35
-1
lines changed

4 files changed

+35
-1
lines changed

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,23 @@ class PatternRewriter : public OpBuilder {
334334
finalizeRootUpdate(root);
335335
}
336336

337+
/// Notify the pattern rewriter that the pattern is failing to match the given
338+
/// operation, and provide a callback to populate a diagnostic with the reason
339+
/// why the failure occurred. This method allows for derived rewriters to
340+
/// optionally hook into the reason why a pattern failed, and display it to
341+
/// users.
342+
virtual LogicalResult
343+
notifyMatchFailure(Operation *op,
344+
function_ref<void(Diagnostic &)> reasonCallback) {
345+
return failure();
346+
}
347+
LogicalResult notifyMatchFailure(Operation *op, const Twine &msg) {
348+
return notifyMatchFailure(op, [&](Diagnostic &diag) { diag << msg; });
349+
}
350+
LogicalResult notifyMatchFailure(Operation *op, const char *msg) {
351+
return notifyMatchFailure(op, Twine(msg));
352+
}
353+
337354
protected:
338355
explicit PatternRewriter(MLIRContext *ctx) : OpBuilder(ctx) {}
339356
virtual ~PatternRewriter();

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,12 @@ class ConversionPatternRewriter final : public PatternRewriter {
379379
/// PatternRewriter hook for updating the root operation in-place.
380380
void cancelRootUpdate(Operation *op) override;
381381

382+
/// PatternRewriter hook for notifying match failure reasons.
383+
LogicalResult
384+
notifyMatchFailure(Operation *op,
385+
function_ref<void(Diagnostic &)> reasonCallback) override;
386+
using PatternRewriter::notifyMatchFailure;
387+
382388
/// Return a reference to the internal implementation.
383389
detail::ConversionPatternRewriterImpl &getImpl();
384390

mlir/lib/Transforms/DialectConversion.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -989,6 +989,17 @@ void ConversionPatternRewriter::cancelRootUpdate(Operation *op) {
989989
rootUpdates.erase(rootUpdates.begin() + (rootUpdates.rend() - it));
990990
}
991991

992+
/// PatternRewriter hook for notifying match failure reasons.
993+
LogicalResult ConversionPatternRewriter::notifyMatchFailure(
994+
Operation *op, function_ref<void(Diagnostic &)> reasonCallback) {
995+
LLVM_DEBUG({
996+
Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
997+
reasonCallback(diag);
998+
impl->logger.startLine() << "** Failure : " << diag.str() << "\n";
999+
});
1000+
return failure();
1001+
}
1002+
9921003
/// Return a reference to the internal implementation.
9931004
detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
9941005
return *impl;

mlir/test/lib/TestDialect/TestPatterns.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
272272
ConversionPatternRewriter &rewriter) const final {
273273
// If the type is F32, change the type to F64.
274274
if (!Type(*op->result_type_begin()).isF32())
275-
return matchFailure();
275+
return rewriter.notifyMatchFailure(op, "expected single f32 operand");
276276
rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
277277
return matchSuccess();
278278
}

0 commit comments

Comments
 (0)