Skip to content

Commit d12d05a

Browse files
author
Nicolas Vasilache
committed
[mlir][Linalg] Introduce a helper function for staged pattern application
Summary: This revision introduces a helper function to allow applying rewrite patterns, interleaved with more global transformations, in a staged fashion: 1. the first stage consists of an OwningRewritePatternList. The RewritePattern in this list are applied once and in order. 2. the second stage consists of a single OwningRewritePattern that is applied greedily until convergence. 3. the third stage consists of applying a lambda, generally used for non-local transformation effects. This allows creating custom fused transformations where patterns can be ordered and applied at a finer granularity than a sequence of traditional compiler passes. A test that exercises these behaviors is added. Differential Revision: https://reviews.llvm.org/D79518
1 parent cd7cb1f commit d12d05a

File tree

5 files changed

+154
-4
lines changed

5 files changed

+154
-4
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,23 @@ struct LinalgLoweringPattern : public RewritePattern {
367367
LinalgLoweringType loweringType;
368368
};
369369

370+
//===----------------------------------------------------------------------===//
371+
// Support for staged pattern application.
372+
//===----------------------------------------------------------------------===//
373+
/// Helper function to allow applying rewrite patterns, interleaved with more
374+
/// global transformations, in a staged fashion:
375+
/// 1. the first stage consists of a list of OwningRewritePatternList. Each
376+
/// OwningRewritePatternList in this list is applied once, in order.
377+
/// 2. the second stage consists of a single OwningRewritePattern that is
378+
/// applied greedily until convergence.
379+
/// 3. the third stage consists of applying a lambda, generally used for
380+
/// non-local transformation effects. This allows creating custom fused
381+
/// transformations where patterns can be ordered and applied at a finer
382+
/// granularity than a sequence of traditional compiler passes.
383+
LogicalResult applyStagedPatterns(
384+
Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns,
385+
const OwningRewritePatternList &stage2Patterns,
386+
llvm::function_ref<LogicalResult(Operation *)> stage3Lambda = nullptr);
370387
} // namespace linalg
371388
} // namespace mlir
372389

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,15 @@ class OwningRewritePatternList {
388388
using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
389389

390390
public:
391+
OwningRewritePatternList() = default;
392+
393+
/// Construct a OwningRewritePatternList populated with the pattern `t` of
394+
/// type `T`.
395+
template <typename T>
396+
OwningRewritePatternList(T &&t) {
397+
patterns.emplace_back(std::make_unique<T>(t));
398+
}
399+
391400
PatternListT::iterator begin() { return patterns.begin(); }
392401
PatternListT::iterator end() { return patterns.end(); }
393402
PatternListT::const_iterator begin() const { return patterns.begin(); }
@@ -399,19 +408,21 @@ class OwningRewritePatternList {
399408
//===--------------------------------------------------------------------===//
400409

401410
/// Add an instance of each of the pattern types 'Ts' to the pattern list with
402-
/// the given arguments.
411+
/// the given arguments. Return a reference to `this` for chaining insertions.
403412
/// Note: ConstructorArg is necessary here to separate the two variadic lists.
404413
template <typename... Ts, typename ConstructorArg,
405414
typename... ConstructorArgs,
406415
typename = std::enable_if_t<sizeof...(Ts) != 0>>
407-
void insert(ConstructorArg &&arg, ConstructorArgs &&... args) {
416+
OwningRewritePatternList &insert(ConstructorArg &&arg,
417+
ConstructorArgs &&... args) {
408418
// The following expands a call to emplace_back for each of the pattern
409419
// types 'Ts'. This magic is necessary due to a limitation in the places
410420
// that a parameter pack can be expanded in c++11.
411421
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
412422
using dummy = int[];
413423
(void)dummy{
414424
0, (patterns.emplace_back(std::make_unique<Ts>(arg, args...)), 0)...};
425+
return *this;
415426
}
416427

417428
private:

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,24 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
198198
rewriter.eraseOp(op);
199199
return success();
200200
}
201+
202+
LogicalResult mlir::linalg::applyStagedPatterns(
203+
Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns,
204+
const OwningRewritePatternList &stage2Patterns,
205+
llvm::function_ref<LogicalResult(Operation *)> stage3Lambda) {
206+
for (const auto &patterns : stage1Patterns) {
207+
if (!applyPatternsAndFoldGreedily(op, patterns)) {
208+
llvm::dbgs() << "Underlying first stage rewrite did not converge";
209+
return failure();
210+
}
211+
if (!applyPatternsAndFoldGreedily(op, stage2Patterns)) {
212+
llvm::dbgs() << "Underlying second stage rewrite did not converge";
213+
return failure();
214+
}
215+
if (stage3Lambda) {
216+
if (failed(stage3Lambda(op)))
217+
return failure();
218+
}
219+
}
220+
return success();
221+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// TODO: this needs a fix to land before being reactivated.
2+
// RUN: ls
3+
// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s
4+
// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s
5+
6+
func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
7+
%B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
8+
%C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) {
9+
linalg.matmul(%A, %B, %C) {__internal_linalg_transform__ = "START"} :
10+
memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
11+
memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
12+
memref<1584x1584xf32, offset: 0, strides: [1584, 1]>
13+
return
14+
}
15+
16+
// CHECK-LABEL:func @matmul
17+
// CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32>
18+
// CHECK: store {{.*}}[] : memref<vector<8x16xf32>>
19+
//
20+
// CHECK: vector.broadcast {{.*}} : f32 to vector<16x12xf32>
21+
// CHECK: store {{.*}}[] : memref<vector<16x12xf32>>
22+
//
23+
// CHECK: vector.broadcast {{.*}} : f32 to vector<8x12xf32>
24+
// CHECK: store {{.*}}[] : memref<vector<8x12xf32>>
25+
//
26+
// CHECK: linalg.copy
27+
// CHECK: linalg.copy
28+
// CHECK: linalg.copy
29+
//
30+
// CHECK: vector.contract
31+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
32+
// CHECK-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
33+
//
34+
// CHECK: linalg.copy

mlir/test/lib/Transforms/TestLinalgTransforms.cpp

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,18 @@ struct TestLinalgTransforms
3333
Option<bool> testPatterns{*this, "test-patterns",
3434
llvm::cl::desc("Test a mixed set of patterns"),
3535
llvm::cl::init(false)};
36+
Option<bool> testMatmulToVectorPatterns1dTiling{
37+
*this, "test-matmul-to-vector-patterns-tile-1d",
38+
llvm::cl::desc(
39+
"Test a fused pass that applies patterns from matmul to vectors via "
40+
"1-d tiling"),
41+
llvm::cl::init(false)};
42+
Option<bool> testMatmulToVectorPatterns2dTiling{
43+
*this, "test-matmul-to-vector-patterns-tile-2d",
44+
llvm::cl::desc(
45+
"Test a fused pass that applies patterns from matmul to vectors via "
46+
"2-d tiling"),
47+
llvm::cl::init(false)};
3648
};
3749
} // end anonymous namespace
3850

@@ -137,10 +149,65 @@ static void applyPatterns(FuncOp funcOp) {
137149
});
138150
}
139151

152+
OwningRewritePatternList
153+
getMatmulToVectorCanonicalizationPatterns(MLIRContext *context) {
154+
OwningRewritePatternList patterns;
155+
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
156+
AffineMinOp::getCanonicalizationPatterns(patterns, context);
157+
AffineMaxOp::getCanonicalizationPatterns(patterns, context);
158+
AllocOp::getCanonicalizationPatterns(patterns, context);
159+
SubViewOp::getCanonicalizationPatterns(patterns, context);
160+
ViewOp::getCanonicalizationPatterns(patterns, context);
161+
MatmulOp::getCanonicalizationPatterns(patterns, context);
162+
return patterns;
163+
}
164+
165+
void fillL1TilingAndMatmulToVectorPatterns(
166+
MLIRContext *context, StringRef startMarker,
167+
SmallVectorImpl<OwningRewritePatternList> &patternsVector) {
168+
patternsVector.emplace_back(LinalgTilingPattern<MatmulOp>(
169+
context,
170+
LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}),
171+
LinalgMarker({startMarker}, "L1")));
172+
173+
patternsVector.emplace_back(LinalgPromotionPattern<MatmulOp>(
174+
context, LinalgPromotionOptions(), LinalgMarker({"L1"}, "VEC")));
175+
176+
patternsVector.emplace_back(
177+
LinalgVectorizationPattern<MatmulOp>(context, LinalgMarker({"VEC"})));
178+
patternsVector.back()
179+
.insert<LinalgVectorizationPattern<FillOp>,
180+
LinalgVectorizationPattern<CopyOp>>(context);
181+
}
182+
140183
/// Apply transformations specified as patterns.
141184
void TestLinalgTransforms::runOnFunction() {
142-
if (testPatterns)
143-
return applyPatterns(getFunction());
185+
if (testPatterns) {
186+
applyPatterns(getFunction());
187+
} else {
188+
SmallVector<OwningRewritePatternList, 4> stage1Patterns;
189+
if (testMatmulToVectorPatterns1dTiling) {
190+
fillL1TilingAndMatmulToVectorPatterns(&getContext(), "START",
191+
stage1Patterns);
192+
} else if (testMatmulToVectorPatterns2dTiling) {
193+
stage1Patterns.emplace_back(
194+
LinalgTilingPattern<MatmulOp>(&getContext(),
195+
LinalgTilingOptions()
196+
.setTileSizes({768, 264, 768})
197+
.setInterchange({1, 2, 0}),
198+
LinalgMarker({"START"}, "L2")));
199+
fillL1TilingAndMatmulToVectorPatterns(&getContext(), "L2",
200+
stage1Patterns);
201+
}
202+
OwningRewritePatternList stage2Patterns =
203+
getMatmulToVectorCanonicalizationPatterns(&getContext());
204+
applyStagedPatterns(getFunction(), stage1Patterns, stage2Patterns);
205+
}
206+
207+
// Drop the marker.
208+
getFunction().walk([](LinalgOp op) {
209+
op.removeAttr(LinalgTransforms::kLinalgTransformMarker);
210+
});
144211
}
145212

146213
namespace mlir {

0 commit comments

Comments
 (0)