Skip to content

Commit bb4c53b

Browse files
committed
[mlir][tensor] Merge consecutive insert_slice/extract_slice ops
Consecutive tensor.insert_slice/tensor.extract_slice ops can be created in cases like tiling convolutions and then downsizing 2-D convolutions into 1-D ones. They hinder further transformations, so this adds patterns to clean them up. Given that bufferization is sensitive and has requirements on the IR structure (see https://reviews.llvm.org/D132666), these patterns are put in Transforms/ with separate entry points for explicit collection. Reviewed By: ThomasRaoux, mravishankar Differential Revision: https://reviews.llvm.org/D133871
1 parent 42bcb35 commit bb4c53b

File tree

5 files changed

+197
-0
lines changed

5 files changed

+197
-0
lines changed

mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@ void populateSplitPaddingPatterns(RewritePatternSet &patterns,
2929
FailureOr<Value> replaceExtractSliceWithTiledProducer(
3030
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
3131

32+
/// Collects patterns to merge consecutive tensor.insert_slice/extract_slice
33+
/// into one. These patterns are in this separate entry point because the
34+
/// bufferization is sensitive over IR structure, particularly those
35+
/// tensor.extract_slice and tensor.insert_slice ops for creating the slices.
36+
void populateMergeConsecutiveInsertExtractSlicePatterns(
37+
RewritePatternSet &patterns);
38+
3239
} // namespace tensor
3340
} // namespace mlir
3441

mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
22
BufferizableOpInterfaceImpl.cpp
33
Bufferize.cpp
44
ExtractSliceFromReshape.cpp
5+
MergeConsecutiveInsertExtractSlicePatterns.cpp
56
SplitPadding.cpp
67
SwapExtractSliceWithProducer.cpp
78

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
//===- MergeConsecutiveInsertExtractSlicePatterns.cpp ---------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
10+
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
11+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
12+
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
13+
#include "mlir/IR/BuiltinTypes.h"
14+
#include "mlir/IR/OpDefinition.h"
15+
#include "mlir/IR/PatternMatch.h"
16+
17+
using namespace mlir;
18+
using namespace mlir::tensor;
19+
20+
/// Adds each corresponding pair of offsets in `offsets1` and `offsets2` and
21+
/// returns the results.
22+
static SmallVector<OpFoldResult> mergeOffsets(Location loc,
23+
ArrayRef<OpFoldResult> offsets1,
24+
ArrayRef<OpFoldResult> offsets2,
25+
OpBuilder &builder) {
26+
SmallVector<OpFoldResult> foldedOffsets;
27+
assert(offsets1.size() == offsets2.size());
28+
foldedOffsets.reserve(offsets1.size());
29+
30+
AffineExpr dim1, dim2;
31+
bindDims(builder.getContext(), dim1, dim2);
32+
33+
for (const auto &pair : llvm::zip(offsets1, offsets2)) {
34+
auto offset0 =
35+
getValueOrCreateConstantIndexOp(builder, loc, std::get<0>(pair));
36+
auto offset1 =
37+
getValueOrCreateConstantIndexOp(builder, loc, std::get<1>(pair));
38+
auto foldedOffset =
39+
makeComposedAffineApply(builder, loc, dim1 + dim2, {offset0, offset1});
40+
foldedOffsets.push_back(foldedOffset.getResult());
41+
}
42+
return foldedOffsets;
43+
}
44+
45+
namespace {
46+
/// Merges consecutive tensor.extract_slice ops into one.
47+
struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {
48+
using OpRewritePattern::OpRewritePattern;
49+
50+
LogicalResult matchAndRewrite(ExtractSliceOp nextOp,
51+
PatternRewriter &rewriter) const override {
52+
auto prevOp = nextOp.getSource().getDefiningOp<ExtractSliceOp>();
53+
if (!prevOp)
54+
return failure();
55+
56+
if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride())
57+
return failure();
58+
59+
auto prevResultType = prevOp.getType().cast<ShapedType>();
60+
if (prevOp.getSourceType().getRank() != prevResultType.getRank())
61+
return rewriter.notifyMatchFailure(
62+
prevOp, "rank-reducing producder case unimplemented");
63+
64+
Location loc = nextOp.getLoc();
65+
66+
SmallVector<OpFoldResult> prevOffsets = prevOp.getMixedOffsets();
67+
SmallVector<OpFoldResult> nextOffsets = nextOp.getMixedOffsets();
68+
SmallVector<OpFoldResult> foldedOffsets =
69+
mergeOffsets(loc, prevOffsets, nextOffsets, rewriter);
70+
71+
rewriter.replaceOpWithNewOp<ExtractSliceOp>(
72+
nextOp, nextOp.getType(), prevOp.getSource(), foldedOffsets,
73+
nextOp.getMixedSizes(), nextOp.getMixedStrides());
74+
return success();
75+
}
76+
};
77+
78+
/// Merges consecutive tensor.insert_slice ops into one.
79+
struct MergeConsecutiveInsertSlice : public OpRewritePattern<InsertSliceOp> {
80+
using OpRewritePattern::OpRewritePattern;
81+
82+
LogicalResult matchAndRewrite(InsertSliceOp nextOp,
83+
PatternRewriter &rewriter) const override {
84+
auto prevOp = nextOp.getSource().getDefiningOp<InsertSliceOp>();
85+
if (!prevOp)
86+
return failure();
87+
88+
if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride())
89+
return failure();
90+
91+
// The first insert_slice op should be rank reducing to make sure we cover
92+
// the full source tensor to be inserted in the second insert_slice op.
93+
SliceVerificationResult result =
94+
isRankReducedType(prevOp.getDestType(), prevOp.getSourceType());
95+
if (result != SliceVerificationResult::Success)
96+
return failure();
97+
98+
// Dynamic dimensions can pass rank reducing check in the above, e.g,
99+
// inserting <?xf32> into <1x?x1xf32>. For such cases we cannot be certain
100+
// the dynamic size covers the full tensor.
101+
if (!prevOp.getSourceType().hasStaticShape() ||
102+
!prevOp.getDestType().hasStaticShape())
103+
return failure();
104+
105+
rewriter.replaceOpWithNewOp<InsertSliceOp>(
106+
nextOp, prevOp.getSource(), nextOp.getDest(), nextOp.getMixedOffsets(),
107+
nextOp.getMixedSizes(), nextOp.getMixedStrides());
108+
return success();
109+
}
110+
};
111+
} // namespace
112+
113+
/// Registers the consecutive extract_slice and insert_slice merging patterns
/// in `patterns`, using the context that `patterns` was created with.
void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<MergeConsecutiveExtractSlice>(context);
  patterns.add<MergeConsecutiveInsertSlice>(context);
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-consecutive-insert-extract-slice -canonicalize -mlir-print-local-scope %s | FileCheck %s
2+
3+
// Two chained unit-stride extract_slice ops that both keep rank 4. The expectations after this function are a single extract_slice of %src with summed offsets and the second slice's sizes.
func.func @extract_slice_same_rank(
4+
%src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<8x16x32x?xf32> {
5+
%0 = tensor.extract_slice %src[0, 1, 2, %offset0] [128, 128, 128, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x128x128x?xf32>
6+
%1 = tensor.extract_slice %0[7, 8, 9, %offset1] [8, 16, 32, %size1] [1, 1, 1, 1] : tensor<128x128x128x?xf32> to tensor<8x16x32x?xf32>
7+
return %1: tensor<8x16x32x?xf32>
8+
}
9+
10+
// CHECK-LABEL: func.func @extract_slice_same_rank
11+
// CHECK-SAME: (%[[SOURCE:.+]]: tensor<?x?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index)
12+
// CHECK: %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET0]], %[[OFFSET1]]]
13+
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]][7, 9, 11, %[[OFFSET]]] [8, 16, 32, %[[SIZE1]]] [1, 1, 1, 1]
14+
// CHECK: return %[[EXTRACT]] : tensor<8x16x32x?xf32>
15+
16+
// Only the second (consumer) extract_slice is rank-reducing (4-D to 2-D). The expectation after this function is still one merged extract_slice going from the 4-D source to tensor<16x?xf32>.
func.func @extract_slice_rank_reducing_consumer(
17+
%src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<16x?xf32> {
18+
%0 = tensor.extract_slice %src[0, 1, 2, %offset0] [128, 128, 128, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x128x128x?xf32>
19+
%1 = tensor.extract_slice %0[7, 8, 9, %offset1] [1, 16, 1, %size1] [1, 1, 1, 1] : tensor<128x128x128x?xf32> to tensor<16x?xf32>
20+
return %1: tensor<16x?xf32>
21+
}
22+
23+
// CHECK-LABEL: func.func @extract_slice_rank_reducing_consumer
24+
// CHECK: tensor.extract_slice %{{.+}}[7, 9, 11, %{{.+}}] [1, 16, 1, %{{.+}}] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<16x?xf32>
25+
26+
// The first (producer) extract_slice is rank-reducing (4-D to 2-D). This is a negative case: the expectation after this function is that both extract_slice ops remain.
func.func @extract_slice_rank_reducing_producer(
27+
%src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<8x?xf32> {
28+
%0 = tensor.extract_slice %src[0, 1, 2, %offset0] [1, 128, 1, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x?xf32>
29+
%1 = tensor.extract_slice %0[7, %offset1] [8, %size1] [1, 1] : tensor<128x?xf32> to tensor<8x?xf32>
30+
return %1: tensor<8x?xf32>
31+
}
32+
33+
// CHECK-LABEL: func.func @extract_slice_rank_reducing_producer
34+
// CHECK-COUNT-2: tensor.extract_slice
35+
36+
// -----
37+
38+
// Statically shaped case: %src fully covers %mid through a rank-reducing insert_slice, and the result is inserted into %dst. The expectation after this function is a single insert_slice of %src into %dst.
func.func @insert_slice_rank_reducing(
39+
%dst: tensor<128x128x128x128xf32>, %mid: tensor<1x16x1xf32>, %src: tensor<16xf32>, %offset: index) -> tensor<128x128x128x128xf32> {
40+
%0 = tensor.insert_slice %src into %mid[0, 0, 0] [1, 16, 1] [1, 1, 1] : tensor<16xf32> into tensor<1x16x1xf32>
41+
%1 = tensor.insert_slice %0 into %dst[6, 7, 8, %offset] [1, 1, 16, 1] [1, 1, 1, 1] : tensor<1x16x1xf32> into tensor<128x128x128x128xf32>
42+
return %1: tensor<128x128x128x128xf32>
43+
}
44+
45+
// CHECK-LABEL: func.func @insert_slice_rank_reducing
46+
// CHECK-SAME: (%[[DST:.+]]: tensor<128x128x128x128xf32>, %{{.+}}: tensor<1x16x1xf32>, %[[SRC:.+]]: tensor<16xf32>, %[[IDX:.+]]: index)
47+
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SRC]] into %[[DST]][6, 7, 8, %[[IDX]]] [1, 1, 16, 1] [1, 1, 1, 1]
48+
// CHECK: return %[[INSERT]]
49+
50+
// Same chain as the static insert_slice case but with dynamically sized %src and %mid. This is a negative case: the expectation after this function is that both insert_slice ops remain.
func.func @insert_slice_rank_reducing_dynamic_shape(
51+
%dst: tensor<128x128x128x128xf32>, %mid: tensor<1x?x1xf32>, %src: tensor<?xf32>, %offset: index, %size: index) -> tensor<128x128x128x128xf32> {
52+
%0 = tensor.insert_slice %src into %mid[0, 0, 0] [1, %size, 1] [1, 1, 1] : tensor<?xf32> into tensor<1x?x1xf32>
53+
%1 = tensor.insert_slice %0 into %dst[6, 7, 8, %offset] [1, 1, %size, 1] [1, 1, 1, 1] : tensor<1x?x1xf32> into tensor<128x128x128x128xf32>
54+
return %1: tensor<128x128x128x128xf32>
55+
}
56+
57+
// CHECK-LABEL: func.func @insert_slice_rank_reducing_dynamic_shape
58+
// CHECK-COUNT-2: tensor.insert_slice

mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ struct TestTensorTransforms
5353
llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"),
5454
llvm::cl::init(false)};
5555

56+
/// Pass option `test-fold-consecutive-insert-extract-slice` (off by default).
/// When set, the pass applies the patterns that merge consecutive
/// tensor.insert_slice/tensor.extract_slice ops.
Option<bool> testFoldConsecutiveInsertExtractSlice{
57+
*this, "test-fold-consecutive-insert-extract-slice",
58+
llvm::cl::desc(
59+
"Test folding consecutive tensor.insert_slice/tensor.extract_slice"),
60+
llvm::cl::init(false)};
61+
5662
Option<bool> testRewriteExtractSliceWithTiledCollapseShape{
5763
*this, "test-rewrite-extract-slice-from-collapse-shape",
5864
llvm::cl::desc("Test swapping tensor.extract_slice of a collapse_shape "
@@ -90,6 +96,12 @@ static void applyFoldConstantExtractSlicePatterns(Operation *rootOp) {
9096
(void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
9197
}
9298

99+
/// Greedily applies the consecutive insert_slice/extract_slice merging
/// patterns to `rootOp`. The result of the greedy driver is intentionally
/// discarded.
static void applyFoldConsecutiveInsertExtractSlicePatterns(Operation *rootOp) {
  MLIRContext *context = rootOp->getContext();
  RewritePatternSet mergePatterns(context);
  tensor::populateMergeConsecutiveInsertExtractSlicePatterns(mergePatterns);
  (void)applyPatternsAndFoldGreedily(rootOp, std::move(mergePatterns));
}
104+
93105
namespace {
94106
/// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
95107
/// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -233,6 +245,8 @@ void TestTensorTransforms::runOnOperation() {
233245
applySplitPaddingPatterns(rootOp);
234246
if (testFoldConstantExtractSlice)
235247
applyFoldConstantExtractSlicePatterns(rootOp);
248+
if (testFoldConsecutiveInsertExtractSlice)
249+
applyFoldConsecutiveInsertExtractSlicePatterns(rootOp);
236250
if (testRewriteExtractSliceWithTiledCollapseShape) {
237251
if (failed(
238252
applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))

0 commit comments

Comments
 (0)