Skip to content

Commit 6874726

Browse files
committed
[PatternMatching] Add convenience insert method to OwningRewritePatternList. NFC.
This allows adding a C function pointer as a matchAndRewrite style pattern, which is a very common case. This adopts it in ExpandTanh to show how it reduces a level of nesting. We could allow C++ lambdas here, but that doesn't work as well with type inference in the common case. Instead of: patterns.insert(convertTanhOp); you need to specify: patterns.insert<math::TanhOp>(convertTanhOp); which is boilerplate'y. Capturing state like this is very uncommon, so we choose to require clients to define their own structs and use the non-convenience method when they need to do so. Differential Revision: https://reviews.llvm.org/D99039
1 parent f21704e commit 6874726

File tree

3 files changed

+66
-41
lines changed

3 files changed

+66
-41
lines changed

mlir/docs/Tutorials/QuickstartRewrites.md

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ struct ConvertTFLeakyRelu : public RewritePattern {
189189
: RewritePattern("tf.LeakyRelu", 1, context) {}
190190

191191
LogicalResult matchAndRewrite(Operation *op,
192-
PatternRewriter &rewriter) const override {
192+
PatternRewriter &rewriter) const override {
193193
rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
194194
op, op->getResult(0).getType(), op->getOperand(0),
195195
/*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
@@ -202,6 +202,19 @@ In the C++ rewrite the static benefit of the rewrite pattern is specified at
202202
construction. While in the pattern generator a simple heuristic is currently
203203
employed based around the number of ops matched and replaced.
204204
205+
In the case where you have a registered op and want to use a benefit of 1, you
206+
can even define the pattern as a C function:
207+
208+
```c++
209+
static LogicalResult
210+
convertTFLeakyRelu(TFLeakyReluOp op, PatternRewriter &rewriter) {
211+
rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
212+
op, op->getResult(0).getType(), op->getOperand(0),
213+
/*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
214+
return success();
215+
}
216+
```
217+
205218
The above rule did not capture the matching operands/attributes, but in general
206219
the `match` function in a multi-step rewrite may populate and return a
207220
`PatternState` (or class derived from one) to pass information extracted during

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,27 @@ class OwningRewritePatternList {
790790
return *this;
791791
}
792792

793+
// Add a matchAndRewrite style pattern represented as a C function pointer.
794+
template <typename OpType>
795+
OwningRewritePatternList &
796+
insert(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter)) {
797+
struct FnPattern final : public OpRewritePattern<OpType> {
798+
FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
799+
MLIRContext *context)
800+
: OpRewritePattern<OpType>(context), implFn(implFn) {}
801+
802+
LogicalResult matchAndRewrite(OpType op,
803+
PatternRewriter &rewriter) const override {
804+
return implFn(op, rewriter);
805+
}
806+
807+
private:
808+
LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
809+
};
810+
insert(std::make_unique<FnPattern>(std::move(implFn), getContext()));
811+
return *this;
812+
}
813+
793814
private:
794815
/// Add an instance of the pattern type 'T'. Return a reference to `this` for
795816
/// chaining insertions.

mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp

Lines changed: 31 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -15,51 +15,42 @@
1515
#include "mlir/Dialect/StandardOps/IR/Ops.h"
1616
#include "mlir/IR/Builders.h"
1717
#include "mlir/Transforms/DialectConversion.h"
18-
1918
using namespace mlir;
2019

21-
namespace {
2220
/// Expands tanh op into
2321
/// 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
2422
/// 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0
25-
struct TanhOpConverter : public OpRewritePattern<math::TanhOp> {
26-
public:
27-
using OpRewritePattern::OpRewritePattern;
28-
29-
LogicalResult matchAndRewrite(math::TanhOp op,
30-
PatternRewriter &rewriter) const final {
31-
auto floatType = op.operand().getType();
32-
Location loc = op.getLoc();
33-
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
34-
auto floatTwo = rewriter.getFloatAttr(floatType, 2.0);
35-
Value one = rewriter.create<ConstantOp>(loc, floatOne);
36-
Value two = rewriter.create<ConstantOp>(loc, floatTwo);
37-
Value doubledX = rewriter.create<MulFOp>(loc, op.operand(), two);
38-
39-
// Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
40-
Value negDoubledX = rewriter.create<NegFOp>(loc, doubledX);
41-
Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
42-
Value dividend = rewriter.create<SubFOp>(loc, one, exp2x);
43-
Value divisor = rewriter.create<AddFOp>(loc, one, exp2x);
44-
Value positiveRes = rewriter.create<DivFOp>(loc, dividend, divisor);
45-
46-
// Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
47-
exp2x = rewriter.create<math::ExpOp>(loc, doubledX);
48-
dividend = rewriter.create<SubFOp>(loc, exp2x, one);
49-
divisor = rewriter.create<AddFOp>(loc, exp2x, one);
50-
Value negativeRes = rewriter.create<DivFOp>(loc, dividend, divisor);
51-
52-
// tanh(x) = x >= 0 ? positiveRes : negativeRes
53-
auto floatZero = rewriter.getFloatAttr(floatType, 0.0);
54-
Value zero = rewriter.create<ConstantOp>(loc, floatZero);
55-
Value cmpRes =
56-
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGE, op.operand(), zero);
57-
rewriter.replaceOpWithNewOp<SelectOp>(op, cmpRes, positiveRes, negativeRes);
58-
return success();
59-
}
60-
};
61-
} // namespace
23+
static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
24+
auto floatType = op.operand().getType();
25+
Location loc = op.getLoc();
26+
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
27+
auto floatTwo = rewriter.getFloatAttr(floatType, 2.0);
28+
Value one = rewriter.create<ConstantOp>(loc, floatOne);
29+
Value two = rewriter.create<ConstantOp>(loc, floatTwo);
30+
Value doubledX = rewriter.create<MulFOp>(loc, op.operand(), two);
31+
32+
// Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
33+
Value negDoubledX = rewriter.create<NegFOp>(loc, doubledX);
34+
Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
35+
Value dividend = rewriter.create<SubFOp>(loc, one, exp2x);
36+
Value divisor = rewriter.create<AddFOp>(loc, one, exp2x);
37+
Value positiveRes = rewriter.create<DivFOp>(loc, dividend, divisor);
38+
39+
// Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
40+
exp2x = rewriter.create<math::ExpOp>(loc, doubledX);
41+
dividend = rewriter.create<SubFOp>(loc, exp2x, one);
42+
divisor = rewriter.create<AddFOp>(loc, exp2x, one);
43+
Value negativeRes = rewriter.create<DivFOp>(loc, dividend, divisor);
44+
45+
// tanh(x) = x >= 0 ? positiveRes : negativeRes
46+
auto floatZero = rewriter.getFloatAttr(floatType, 0.0);
47+
Value zero = rewriter.create<ConstantOp>(loc, floatZero);
48+
Value cmpRes =
49+
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGE, op.operand(), zero);
50+
rewriter.replaceOpWithNewOp<SelectOp>(op, cmpRes, positiveRes, negativeRes);
51+
return success();
52+
}
6253

6354
void mlir::populateExpandTanhPattern(OwningRewritePatternList &patterns) {
64-
patterns.insert<TanhOpConverter>(patterns.getContext());
55+
patterns.insert(convertTanhOp);
6556
}

0 commit comments

Comments
 (0)