Commit 5d4603a
[mlir][tensor] Support more cases in MergeConsecutiveExtractSlice
This commit adds utility functions to perform general merging of
OffsetSizeAndStrideOpInterface ops, supporting rank-reducing producers
and non-unit strides. With them we can extend MergeConsecutiveExtractSlice
to cover more cases.

Co-authored-by: Mahesh Ravishankar <[email protected]>
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D134294
1 parent: 06010fd

File tree

3 files changed (+161, -42 lines)

mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h

Lines changed: 31 additions & 0 deletions

@@ -14,6 +14,37 @@
 namespace mlir {
 namespace tensor {
 
+/// Fills the `combinedOffsets`, `combinedSizes` and `combinedStrides` to use
+/// when combining a producer slice **into** a consumer slice.
+///
+/// This function performs the following computation:
+///   - Combined offsets = consumer_offsets * producer_strides + producer_offsets
+///   - Combined sizes = consumer_sizes
+///   - Combined strides = producer_strides * consumer_strides
+LogicalResult
+mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
+                            ArrayRef<OpFoldResult> producerOffsets,
+                            ArrayRef<OpFoldResult> producerSizes,
+                            ArrayRef<OpFoldResult> producerStrides,
+                            const llvm::SmallBitVector &droppedProducerDims,
+                            ArrayRef<OpFoldResult> consumerOffsets,
+                            ArrayRef<OpFoldResult> consumerSizes,
+                            ArrayRef<OpFoldResult> consumerStrides,
+                            SmallVector<OpFoldResult> &combinedOffsets,
+                            SmallVector<OpFoldResult> &combinedSizes,
+                            SmallVector<OpFoldResult> &combinedStrides);
+
+/// Fills the `combinedOffsets`, `combinedSizes` and `combinedStrides` to use
+/// when combining a `producer` slice op **into** a `consumer` slice op.
+LogicalResult
+mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
+                            OffsetSizeAndStrideOpInterface producer,
+                            OffsetSizeAndStrideOpInterface consumer,
+                            const llvm::SmallBitVector &droppedProducerDims,
+                            SmallVector<OpFoldResult> &combinedOffsets,
+                            SmallVector<OpFoldResult> &combinedSizes,
+                            SmallVector<OpFoldResult> &combinedStrides);
+
 //===----------------------------------------------------------------------===//
 // Extract slice from `tensor.collapse_shape`
 //===----------------------------------------------------------------------===//
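The combination rule documented above is straightforward to sanity-check outside of MLIR. Below is a minimal standalone sketch of the same arithmetic on plain integers; `SliceParams` and `mergeSlices` are hypothetical names used for illustration only, whereas the real utility operates on `OpFoldResult`s and folds constants through `affine.apply`:

#include <cassert>
#include <cstdint>
#include <vector>

struct SliceParams {
  std::vector<int64_t> offsets, sizes, strides;
};

// Model of mergeOffsetsSizesAndStrides: producer dims marked as dropped are
// rank-reduced away, so the consumer has no corresponding dim and the
// producer values pass through unchanged.
SliceParams mergeSlices(const SliceParams &producer,
                        const std::vector<bool> &droppedProducerDims,
                        const SliceParams &consumer) {
  SliceParams combined;
  size_t consumerPos = 0;
  for (size_t i = 0; i < producer.offsets.size(); ++i) {
    if (droppedProducerDims[i]) {
      combined.offsets.push_back(producer.offsets[i]);
      combined.sizes.push_back(producer.sizes[i]);
      combined.strides.push_back(producer.strides[i]);
      continue;
    }
    // combined_offset = producer_offset + consumer_offset * producer_stride
    combined.offsets.push_back(producer.offsets[i] +
                               consumer.offsets[consumerPos] *
                                   producer.strides[i]);
    // combined_size = consumer_size
    combined.sizes.push_back(consumer.sizes[consumerPos]);
    // combined_stride = consumer_stride * producer_stride
    combined.strides.push_back(consumer.strides[consumerPos] *
                               producer.strides[i]);
    ++consumerPos;
  }
  return combined;
}

int main() {
  // 1-D: extract_slice [2][8][3] followed by [1][4][2] merges to [5][4][6].
  SliceParams m = mergeSlices({{2}, {8}, {3}}, {false}, {{1}, {4}, {2}});
  assert(m.offsets[0] == 5 && m.sizes[0] == 4 && m.strides[0] == 6);

  // Rank-reducing producer, as in the tests further below: unit dims 0 and 2
  // are dropped, so their offsets/sizes/strides come from the producer.
  SliceParams rr = mergeSlices({{0, 1, 2, 10}, {1, 128, 1, 32}, {1, 1, 1, 1}},
                               {true, false, true, false},
                               {{7, 3}, {8, 4}, {1, 1}});
  assert(rr.offsets == (std::vector<int64_t>{0, 8, 2, 13}));
  assert(rr.sizes == (std::vector<int64_t>{1, 8, 1, 4}));
  return 0;
}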

mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp

Lines changed: 102 additions & 39 deletions

@@ -7,8 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
@@ -17,29 +17,101 @@
 using namespace mlir;
 using namespace mlir::tensor;
 
-/// Adds each corresponding pair of offsets in `offsets1` and `offsets2` and
-/// returns the results.
-static SmallVector<OpFoldResult> mergeOffsets(Location loc,
-                                              ArrayRef<OpFoldResult> offsets1,
-                                              ArrayRef<OpFoldResult> offsets2,
-                                              OpBuilder &builder) {
-  SmallVector<OpFoldResult> foldedOffsets;
-  assert(offsets1.size() == offsets2.size());
-  foldedOffsets.reserve(offsets1.size());
-
-  AffineExpr dim1, dim2;
-  bindDims(builder.getContext(), dim1, dim2);
-
-  for (const auto &pair : llvm::zip(offsets1, offsets2)) {
-    auto offset0 =
-        getValueOrCreateConstantIndexOp(builder, loc, std::get<0>(pair));
-    auto offset1 =
-        getValueOrCreateConstantIndexOp(builder, loc, std::get<1>(pair));
-    auto foldedOffset =
-        makeComposedAffineApply(builder, loc, dim1 + dim2, {offset0, offset1});
-    foldedOffsets.push_back(foldedOffset.getResult());
+/// Creates an AffineExpr from `ofr`: if the OpFoldResult is a Value, creates
+/// an AffineSymbolExpr and appends it to `symbols`; otherwise creates an
+/// AffineConstantExpr.
+static AffineExpr getAffineExpr(OpFoldResult ofr,
+                                SmallVector<OpFoldResult> &symbols) {
+  if (auto attr = ofr.dyn_cast<Attribute>()) {
+    return getAffineConstantExpr(attr.cast<IntegerAttr>().getInt(),
+                                 attr.getContext());
   }
-  return foldedOffsets;
+  Value v = ofr.get<Value>();
+  AffineExpr expr = getAffineSymbolExpr(symbols.size(), v.getContext());
+  symbols.push_back(v);
+  return expr;
+}
+
+/// Builds the AffineExpr incrementally for arithmetic operations.
+static AffineExpr add(AffineExpr expr, OpFoldResult ofr,
+                      SmallVector<OpFoldResult> &symbols) {
+  return expr + getAffineExpr(ofr, symbols);
+}
+static AffineExpr mul(OpFoldResult lhs, OpFoldResult rhs,
+                      SmallVector<OpFoldResult> &symbols) {
+  return getAffineExpr(lhs, symbols) * getAffineExpr(rhs, symbols);
+}
+
+/// Converts an AffineExpr to an OpFoldResult by generating an `affine.apply`
+/// op and folding it.
+static OpFoldResult getOpFoldResult(OpBuilder &builder, Location loc,
+                                    AffineExpr expr,
+                                    SmallVector<OpFoldResult> &symbols) {
+  AffineMap m = AffineMap::get(0, symbols.size(), expr);
+  return makeComposedFoldedAffineApply(builder, loc, m, symbols);
+}
+
+LogicalResult tensor::mergeOffsetsSizesAndStrides(
+    OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> producerOffsets,
+    ArrayRef<OpFoldResult> producerSizes,
+    ArrayRef<OpFoldResult> producerStrides,
+    const llvm::SmallBitVector &droppedProducerDims,
+    ArrayRef<OpFoldResult> consumerOffsets,
+    ArrayRef<OpFoldResult> consumerSizes,
+    ArrayRef<OpFoldResult> consumerStrides,
+    SmallVector<OpFoldResult> &combinedOffsets,
+    SmallVector<OpFoldResult> &combinedSizes,
+    SmallVector<OpFoldResult> &combinedStrides) {
+  combinedOffsets.resize(producerOffsets.size());
+  combinedSizes.resize(producerOffsets.size());
+  combinedStrides.resize(producerOffsets.size());
+  unsigned consumerPos = 0;
+  for (auto i : llvm::seq<unsigned>(0, producerOffsets.size())) {
+    if (droppedProducerDims.test(i)) {
+      // For dropped dims, get the values from the producer.
+      combinedOffsets[i] = producerOffsets[i];
+      combinedSizes[i] = producerSizes[i];
+      combinedStrides[i] = producerStrides[i];
+      continue;
+    }
+    SmallVector<OpFoldResult> offsetSymbols, strideSymbols;
+    // The combined offset is computed as
+    //    producer_offset + consumer_offset * producer_strides.
+    combinedOffsets[i] =
+        getOpFoldResult(builder, loc,
+                        add(mul(consumerOffsets[consumerPos],
+                                producerStrides[i], offsetSymbols),
+                            producerOffsets[i], offsetSymbols),
+                        offsetSymbols);
+    combinedSizes[i] = consumerSizes[consumerPos];
+    // The combined stride is computed as
+    //    consumer_stride * producer_stride.
+    combinedStrides[i] = getOpFoldResult(
+        builder, loc,
+        mul(consumerStrides[consumerPos], producerStrides[i], strideSymbols),
+        strideSymbols);
+    consumerPos++;
+  }
+  return success();
+}
+
+LogicalResult tensor::mergeOffsetsSizesAndStrides(
+    OpBuilder &builder, Location loc, OffsetSizeAndStrideOpInterface producer,
+    OffsetSizeAndStrideOpInterface consumer,
+    const llvm::SmallBitVector &droppedProducerDims,
+    SmallVector<OpFoldResult> &combinedOffsets,
+    SmallVector<OpFoldResult> &combinedSizes,
+    SmallVector<OpFoldResult> &combinedStrides) {
+  SmallVector<OpFoldResult> consumerOffsets = consumer.getMixedOffsets();
+  SmallVector<OpFoldResult> consumerSizes = consumer.getMixedSizes();
+  SmallVector<OpFoldResult> consumerStrides = consumer.getMixedStrides();
+  SmallVector<OpFoldResult> producerOffsets = producer.getMixedOffsets();
+  SmallVector<OpFoldResult> producerSizes = producer.getMixedSizes();
+  SmallVector<OpFoldResult> producerStrides = producer.getMixedStrides();
+  return tensor::mergeOffsetsSizesAndStrides(
+      builder, loc, producerOffsets, producerSizes, producerStrides,
      droppedProducerDims, consumerOffsets, consumerSizes, consumerStrides,
+      combinedOffsets, combinedSizes, combinedStrides);
 }
 
 namespace {
@@ -53,24 +125,15 @@ struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {
     if (!prevOp)
       return failure();
 
-    if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride())
+    SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
+    if (failed(mergeOffsetsSizesAndStrides(rewriter, nextOp.getLoc(), prevOp,
+                                           nextOp, prevOp.getDroppedDims(),
+                                           newOffsets, newSizes, newStrides)))
       return failure();
 
-    auto prevResultType = prevOp.getType().cast<ShapedType>();
-    if (prevOp.getSourceType().getRank() != prevResultType.getRank())
-      return rewriter.notifyMatchFailure(
-          prevOp, "rank-reducing producder case unimplemented");
-
-    Location loc = nextOp.getLoc();
-
-    SmallVector<OpFoldResult> prevOffsets = prevOp.getMixedOffsets();
-    SmallVector<OpFoldResult> nextOffsets = nextOp.getMixedOffsets();
-    SmallVector<OpFoldResult> foldedOffsets =
-        mergeOffsets(loc, prevOffsets, nextOffsets, rewriter);
-
-    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
-        nextOp, nextOp.getType(), prevOp.getSource(), foldedOffsets,
-        nextOp.getMixedSizes(), nextOp.getMixedStrides());
+    rewriter.replaceOpWithNewOp<ExtractSliceOp>(nextOp, nextOp.getType(),
+                                                prevOp.getSource(), newOffsets,
+                                                newSizes, newStrides);
     return success();
   }
 };
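As an end-to-end sanity check on the rewrite itself, chaining two strided 1-D extractions over concrete data matches a single extraction with the merged offset, size, and stride. This is again a standalone sketch with a hypothetical extractSlice helper rather than MLIR code:

#include <cassert>
#include <vector>

// Reference semantics of a 1-D strided extract_slice on concrete data.
std::vector<int> extractSlice(const std::vector<int> &src, int offset,
                              int size, int stride) {
  std::vector<int> out;
  for (int i = 0; i < size; ++i)
    out.push_back(src[offset + i * stride]);
  return out;
}

int main() {
  std::vector<int> src;
  for (int i = 0; i < 64; ++i)
    src.push_back(i);
  // Producer slice [2][20][3], then consumer slice [1][4][2] on its result.
  std::vector<int> chained =
      extractSlice(extractSlice(src, 2, 20, 3), 1, 4, 2);
  // One slice with merged parameters: offset = 2 + 1 * 3, size = 4,
  // stride = 2 * 3, exactly what the pattern now emits.
  std::vector<int> merged = extractSlice(src, 5, 4, 6);
  assert(chained == merged); // Both yield {5, 11, 17, 23}.
  return 0;
}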

mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir

Lines changed: 28 additions & 3 deletions

@@ -9,10 +9,12 @@ func.func @extract_slice_same_rank(
 
 // CHECK-LABEL: func.func @extract_slice_same_rank
 // CHECK-SAME: (%[[SOURCE:.+]]: tensor<?x?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index)
-// CHECK: %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET0]], %[[OFFSET1]]]
+// CHECK: %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET1]], %[[OFFSET0]]]
 // CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]][7, 9, 11, %[[OFFSET]]] [8, 16, 32, %[[SIZE1]]] [1, 1, 1, 1]
 // CHECK: return %[[EXTRACT]] : tensor<8x16x32x?xf32>
 
+// -----
+
 func.func @extract_slice_rank_reducing_consumer(
     %src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<16x?xf32> {
   %0 = tensor.extract_slice %src[0, 1, 2, %offset0] [128, 128, 128, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x128x128x?xf32>
@@ -23,15 +25,36 @@ func.func @extract_slice_rank_reducing_consumer(
 // CHECK-LABEL: func.func @extract_slice_rank_reducing_consumer
 // CHECK: tensor.extract_slice %{{.+}}[7, 9, 11, %{{.+}}] [1, 16, 1, %{{.+}}] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<16x?xf32>
 
+// -----
+
 func.func @extract_slice_rank_reducing_producer(
     %src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<8x?xf32> {
   %0 = tensor.extract_slice %src[0, 1, 2, %offset0] [1, 128, 1, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x?xf32>
   %1 = tensor.extract_slice %0[7, %offset1] [8, %size1] [1, 1] : tensor<128x?xf32> to tensor<8x?xf32>
   return %1: tensor<8x?xf32>
 }
 
-// CHECK-LABEL: func.func @extract_slice_rank_reducing_producer
-// CHECK-COUNT-2: tensor.extract_slice
+// CHECK-LABEL: func.func @extract_slice_rank_reducing_producer
+// CHECK-SAME: (%[[SRC:.+]]: tensor<?x?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index)
+// CHECK: %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET1]], %[[OFFSET0]]]
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SRC]][0, 8, 2, %[[OFFSET]]] [1, 8, 1, %[[SIZE1]]] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<8x?xf32>
+// CHECK: return %[[EXTRACT]] : tensor<8x?xf32>
+
+// -----
+
+func.func @extract_slice_non_one_stride(
+    %src: tensor<?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index, %stride0: index, %stride1: index) -> tensor<?xf32> {
+  %0 = tensor.extract_slice %src[%offset0] [%size0] [%stride0] : tensor<?xf32> to tensor<?xf32>
+  %1 = tensor.extract_slice %0[%offset1] [%size1] [%stride1] : tensor<?xf32> to tensor<?xf32>
+  return %1: tensor<?xf32>
+}
+
+// CHECK-LABEL: func.func @extract_slice_non_one_stride
+// CHECK-SAME: (%[[SRC:.+]]: tensor<?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index, %[[STRIDE0:.+]]: index, %[[STRIDE1:.+]]: index)
+// CHECK: %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>()[%[[OFFSET1]], %[[STRIDE0]], %[[OFFSET0]]]
+// CHECK: %[[STRIDE:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%[[STRIDE1]], %[[STRIDE0]]]
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SRC]][%[[OFFSET]]] [%[[SIZE1]]] [%[[STRIDE]]] : tensor<?xf32> to tensor<?xf32>
+// CHECK: return %[[EXTRACT]] : tensor<?xf32>
 
 // -----
 
@@ -47,6 +70,8 @@ func.func @insert_slice_rank_reducing(
 // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SRC]] into %[[DST]][6, 7, 8, %[[IDX]]] [1, 1, 16, 1] [1, 1, 1, 1]
 // CHECK: return %[[INSERT]]
 
+// -----
+
 func.func @insert_slice_rank_reducing_dynamic_shape(
     %dst: tensor<128x128x128x128xf32>, %mid: tensor<1x?x1xf32>, %src: tensor<?xf32>, %offset: index, %size: index) -> tensor<128x128x128x128xf32> {
   %0 = tensor.insert_slice %src into %mid[0, 0, 0] [1, %size, 1] [1, 1, 1] : tensor<?xf32> into tensor<1x?x1xf32>