
Commit e0a6df5

Revert "[mlir][tensor] Support more cases in MergeConsecutiveExtractSlice"
This reverts commit 5d4603a. The Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir test is failing when built with GCC.
1 parent b052eea commit e0a6df5

3 files changed (+42, -161 lines)

mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h

Lines changed: 0 additions & 31 deletions
@@ -14,37 +14,6 @@
 namespace mlir {
 namespace tensor {
 
-/// Fills the `combinedOffsets`, `combinedSizes` and `combinedStrides` to use
-/// when combining a producer slice **into** a consumer slice.
-///
-/// This function performs the following computation:
-/// - Combined offsets = producer_offsets * consumer_strides + consumer_offsets
-/// - Combined sizes = consumer_sizes
-/// - Combined strides = producer_strides * consumer_strides
-LogicalResult
-mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
-                            ArrayRef<OpFoldResult> producerOffsets,
-                            ArrayRef<OpFoldResult> producerSizes,
-                            ArrayRef<OpFoldResult> producerStrides,
-                            const llvm::SmallBitVector &droppedProducerDims,
-                            ArrayRef<OpFoldResult> consumerOffsets,
-                            ArrayRef<OpFoldResult> consumerSizes,
-                            ArrayRef<OpFoldResult> consumerStrides,
-                            SmallVector<OpFoldResult> &combinedOffsets,
-                            SmallVector<OpFoldResult> &combinedSizes,
-                            SmallVector<OpFoldResult> &combinedStrides);
-
-/// Fills the `combinedOffsets`, `combinedSizes` and `combinedStrides` to use
-/// when combining a `producer` slice op **into** a `consumer` slice op.
-LogicalResult
-mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
-                            OffsetSizeAndStrideOpInterface producer,
-                            OffsetSizeAndStrideOpInterface consumer,
-                            const llvm::SmallBitVector &droppedProducerDims,
-                            SmallVector<OpFoldResult> &combinedOffsets,
-                            SmallVector<OpFoldResult> &combinedSizes,
-                            SmallVector<OpFoldResult> &combinedStrides);
-
 //===----------------------------------------------------------------------===//
 // Extract slice from `tensor.collapse_shape`
 //===----------------------------------------------------------------------===//
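
For reference, these removed declarations backed the stride-aware folding that the extract_slice_non_one_stride test (deleted further down in this commit) exercised. A minimal sketch of that folding on 1-D slices, using the value names from the deleted test, not part of the diff itself:

// Before: two consecutive strided extract_slice ops.
%0 = tensor.extract_slice %src[%offset0] [%size0] [%stride0] : tensor<?xf32> to tensor<?xf32>
%1 = tensor.extract_slice %0[%offset1] [%size1] [%stride1] : tensor<?xf32> to tensor<?xf32>

// After the (now removed) merge: offset = %offset1 * %stride0 + %offset0,
// size = %size1, stride = %stride1 * %stride0, all applied directly to %src.
%offset = affine.apply affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>()[%offset1, %stride0, %offset0]
%stride = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%stride1, %stride0]
%merged = tensor.extract_slice %src[%offset] [%size1] [%stride] : tensor<?xf32> to tensor<?xf32>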

mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp

Lines changed: 39 additions & 102 deletions
@@ -7,8 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
@@ -17,101 +17,29 @@
 using namespace mlir;
 using namespace mlir::tensor;
 
-/// Creates AffineExpr from `ofr`: if the OpFoldResult is a Value, creates a
-/// AffineSymbolExpr and appends it to `symbols`; otherwise creates a
-/// AffineConstantExpr.
-static AffineExpr getAffineExpr(OpFoldResult ofr,
-                                SmallVector<OpFoldResult> &symbols) {
-  if (auto attr = ofr.dyn_cast<Attribute>()) {
-    return getAffineConstantExpr(attr.cast<IntegerAttr>().getInt(),
-                                 attr.getContext());
+/// Adds each corresponding pair of offsets in `offsets1` and `offsets2` and
+/// returns the results.
+static SmallVector<OpFoldResult> mergeOffsets(Location loc,
+                                              ArrayRef<OpFoldResult> offsets1,
+                                              ArrayRef<OpFoldResult> offsets2,
+                                              OpBuilder &builder) {
+  SmallVector<OpFoldResult> foldedOffsets;
+  assert(offsets1.size() == offsets2.size());
+  foldedOffsets.reserve(offsets1.size());
+
+  AffineExpr dim1, dim2;
+  bindDims(builder.getContext(), dim1, dim2);
+
+  for (const auto &pair : llvm::zip(offsets1, offsets2)) {
+    auto offset0 =
+        getValueOrCreateConstantIndexOp(builder, loc, std::get<0>(pair));
+    auto offset1 =
+        getValueOrCreateConstantIndexOp(builder, loc, std::get<1>(pair));
+    auto foldedOffset =
+        makeComposedAffineApply(builder, loc, dim1 + dim2, {offset0, offset1});
+    foldedOffsets.push_back(foldedOffset.getResult());
   }
-  Value v = ofr.get<Value>();
-  AffineExpr expr = getAffineSymbolExpr(symbols.size(), v.getContext());
-  symbols.push_back(v);
-  return expr;
-}
-
-/// Builds the AffineExpr incrementally for arithmetic operations.
-static AffineExpr add(AffineExpr expr, OpFoldResult ofr,
-                      SmallVector<OpFoldResult> &symbols) {
-  return expr + getAffineExpr(ofr, symbols);
-}
-static AffineExpr mul(OpFoldResult lhs, OpFoldResult rhs,
-                      SmallVector<OpFoldResult> &symbols) {
-  return getAffineExpr(lhs, symbols) * getAffineExpr(rhs, symbols);
-}
-
-/// Converts an AffineExpr to OpFoldResult by generating an `affine.apply`
-/// op and fold it.
-static OpFoldResult getOpFoldResult(OpBuilder &builder, Location loc,
-                                    AffineExpr expr,
-                                    SmallVector<OpFoldResult> &symbols) {
-  AffineMap m = AffineMap::get(0, symbols.size(), expr);
-  return makeComposedFoldedAffineApply(builder, loc, m, symbols);
-}
-
-LogicalResult tensor::mergeOffsetsSizesAndStrides(
-    OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> producerOffsets,
-    ArrayRef<OpFoldResult> producerSizes,
-    ArrayRef<OpFoldResult> producerStrides,
-    const llvm::SmallBitVector &droppedProducerDims,
-    ArrayRef<OpFoldResult> consumerOffsets,
-    ArrayRef<OpFoldResult> consumerSizes,
-    ArrayRef<OpFoldResult> consumerStrides,
-    SmallVector<OpFoldResult> &combinedOffsets,
-    SmallVector<OpFoldResult> &combinedSizes,
-    SmallVector<OpFoldResult> &combinedStrides) {
-  combinedOffsets.resize(producerOffsets.size());
-  combinedSizes.resize(producerOffsets.size());
-  combinedStrides.resize(producerOffsets.size());
-  unsigned consumerPos = 0;
-  for (auto i : llvm::seq<unsigned>(0, producerOffsets.size())) {
-    if (droppedProducerDims.test(i)) {
-      // For dropped dims, get the values from the producer.
-      combinedOffsets[i] = producerOffsets[i];
-      combinedSizes[i] = producerSizes[i];
-      combinedStrides[i] = producerStrides[i];
-      continue;
-    }
-    SmallVector<OpFoldResult> offsetSymbols, strideSymbols;
-    // The combined offset is computed as
-    //   producer_offset + consumer_offset * producer_strides.
-    combinedOffsets[i] =
-        getOpFoldResult(builder, loc,
-                        add(mul(consumerOffsets[consumerPos],
-                                producerStrides[i], offsetSymbols),
-                            producerOffsets[i], offsetSymbols),
-                        offsetSymbols);
-    combinedSizes[i] = consumerSizes[consumerPos];
-    // The combined stride is computed as
-    //   consumer_stride * producer_stride.
-    combinedStrides[i] = getOpFoldResult(
-        builder, loc,
-        mul(consumerStrides[consumerPos], producerStrides[i], strideSymbols),
-        strideSymbols);
-    consumerPos++;
-  }
-  return success();
-}
-
-LogicalResult tensor::mergeOffsetsSizesAndStrides(
-    OpBuilder &builder, Location loc, OffsetSizeAndStrideOpInterface producer,
-    OffsetSizeAndStrideOpInterface consumer,
-    const llvm::SmallBitVector &droppedProducerDims,
-    SmallVector<OpFoldResult> &combinedOffsets,
-    SmallVector<OpFoldResult> &combinedSizes,
-    SmallVector<OpFoldResult> &combinedStrides) {
-  SmallVector<OpFoldResult> consumerOffsets = consumer.getMixedOffsets();
-  SmallVector<OpFoldResult> consumerSizes = consumer.getMixedSizes();
-  SmallVector<OpFoldResult> consumerStrides = consumer.getMixedStrides();
-  SmallVector<OpFoldResult> producerOffsets = producer.getMixedOffsets();
-  SmallVector<OpFoldResult> producerSizes = producer.getMixedSizes();
-  SmallVector<OpFoldResult> producerStrides = producer.getMixedStrides();
-  return tensor::mergeOffsetsSizesAndStrides(
-      builder, loc, producerOffsets, producerSizes, producerStrides,
-      droppedProducerDims, consumerOffsets, consumerSizes, consumerStrides,
-      combinedOffsets, combinedSizes, combinedStrides);
+  return foldedOffsets;
 }
 
 namespace {
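
As a sketch of what the restored mergeOffsets-based path produces for the unit-stride, same-rank case (mirroring the extract_slice_same_rank CHECK lines kept in the test below; the constants and SSA names here are illustrative): the two offset lists are added pairwise through a composed affine.apply, while sizes and strides are taken from the consumer slice.

// Before: the second extract_slice reads from the result of the first.
%0 = tensor.extract_slice %src[7, 9, 11, %off0] [8, 16, 32, %sz0] [1, 1, 1, 1]
    : tensor<?x?x?x?xf32> to tensor<8x16x32x?xf32>
%1 = tensor.extract_slice %0[0, 0, 0, %off1] [8, 16, 32, %sz1] [1, 1, 1, 1]
    : tensor<8x16x32x?xf32> to tensor<8x16x32x?xf32>

// After: a single slice of the original source; constant offset sums fold away.
%offset = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%off0, %off1]
%merged = tensor.extract_slice %src[7, 9, 11, %offset] [8, 16, 32, %sz1] [1, 1, 1, 1]
    : tensor<?x?x?x?xf32> to tensor<8x16x32x?xf32>
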
@@ -125,15 +53,24 @@ struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {
     if (!prevOp)
       return failure();
 
-    SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
-    if (failed(mergeOffsetsSizesAndStrides(rewriter, nextOp.getLoc(), prevOp,
-                                           nextOp, prevOp.getDroppedDims(),
-                                           newOffsets, newSizes, newStrides)))
+    if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride())
       return failure();
 
-    rewriter.replaceOpWithNewOp<ExtractSliceOp>(nextOp, nextOp.getType(),
-                                                prevOp.getSource(), newOffsets,
-                                                newSizes, newStrides);
+    auto prevResultType = prevOp.getType().cast<ShapedType>();
+    if (prevOp.getSourceType().getRank() != prevResultType.getRank())
+      return rewriter.notifyMatchFailure(
+          prevOp, "rank-reducing producder case unimplemented");
+
+    Location loc = nextOp.getLoc();
+
+    SmallVector<OpFoldResult> prevOffsets = prevOp.getMixedOffsets();
+    SmallVector<OpFoldResult> nextOffsets = nextOp.getMixedOffsets();
+    SmallVector<OpFoldResult> foldedOffsets =
+        mergeOffsets(loc, prevOffsets, nextOffsets, rewriter);
+
+    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
+        nextOp, nextOp.getType(), prevOp.getSource(), foldedOffsets,
+        nextOp.getMixedSizes(), nextOp.getMixedStrides());
     return success();
   }
 };
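
Note the two guards restored above: after this revert the pattern bails out on non-unit strides and on rank-reducing producers. IR like the following, taken from the extract_slice_rank_reducing_producer test below (which now only checks that both ops survive), is therefore left untouched:

// Not folded: the producer drops two unit dimensions, so the restored pattern
// reports a match failure and both extract_slice ops remain.
%0 = tensor.extract_slice %src[0, 1, 2, %offset0] [1, 128, 1, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x?xf32>
%1 = tensor.extract_slice %0[7, %offset1] [8, %size1] [1, 1] : tensor<128x?xf32> to tensor<8x?xf32>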

mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir

Lines changed: 3 additions & 28 deletions
@@ -9,12 +9,10 @@ func.func @extract_slice_same_rank(
 
 // CHECK-LABEL: func.func @extract_slice_same_rank
 // CHECK-SAME: (%[[SOURCE:.+]]: tensor<?x?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index)
-// CHECK: %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET1]], %[[OFFSET0]]]
+// CHECK: %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET0]], %[[OFFSET1]]]
 // CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]][7, 9, 11, %[[OFFSET]]] [8, 16, 32, %[[SIZE1]]] [1, 1, 1, 1]
 // CHECK: return %[[EXTRACT]] : tensor<8x16x32x?xf32>
 
-// -----
-
 func.func @extract_slice_rank_reducing_consumer(
     %src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<16x?xf32> {
   %0 = tensor.extract_slice %src[0, 1, 2, %offset0] [128, 128, 128, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x128x128x?xf32>
@@ -25,36 +23,15 @@ func.func @extract_slice_rank_reducing_consumer(
 // CHECK-LABEL: func.func @extract_slice_rank_reducing_consumer
 // CHECK: tensor.extract_slice %{{.+}}[7, 9, 11, %{{.+}}] [1, 16, 1, %{{.+}}] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<16x?xf32>
 
-// -----
-
 func.func @extract_slice_rank_reducing_producer(
     %src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<8x?xf32> {
   %0 = tensor.extract_slice %src[0, 1, 2, %offset0] [1, 128, 1, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x?xf32>
   %1 = tensor.extract_slice %0[7, %offset1] [8, %size1] [1, 1] : tensor<128x?xf32> to tensor<8x?xf32>
   return %1: tensor<8x?xf32>
 }
 
-// CHECK-LABEL: func.func @extract_slice_rank_reducing_producer
-// CHECK-SAME: (%[[SRC:.+]]: tensor<?x?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index)
-// CHECK: %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET1]], %[[OFFSET0]]]
-// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SRC]][0, 8, 2, %[[OFFSET]]] [1, 8, 1, %[[SIZE1]]] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<8x?xf32>
-// CHECK: return %[[EXTRACT]] : tensor<8x?xf32>
-
-// -----
-
-func.func @extract_slice_non_one_stride(
-    %src: tensor<?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index, %stride0: index, %stride1: index) -> tensor<?xf32> {
-  %0 = tensor.extract_slice %src[%offset0] [%size0] [%stride0] : tensor<?xf32> to tensor<?xf32>
-  %1 = tensor.extract_slice %0[%offset1] [%size1] [%stride1] : tensor<?xf32> to tensor<?xf32>
-  return %1: tensor<?xf32>
-}
-
-// CHECK-LABEL: func.func @extract_slice_non_one_stride
-// CHECK-SAME: (%[[SRC:.+]]: tensor<?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index, %[[STRIDE0:.+]]: index, %[[STRIDE1:.+]]: index)
-// CHECK: %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>()[%[[OFFSET1]], %[[STRIDE0]], %[[OFFSET0]]]
-// CHECK: %[[STRIDE:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%[[STRIDE1]], %[[STRIDE0]]]
-// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SRC]][%[[OFFSET]]] [%[[SIZE1]]] [%[[STRIDE]]] : tensor<?xf32> to tensor<?xf32>
-// CHECK: return %[[EXTRACT]] : tensor<?xf32>
+// CHECK-LABEL: func.func @extract_slice_rank_reducing_producer
+// CHECK-COUNT-2: tensor.extract_slice
 
 // -----

@@ -70,8 +47,6 @@ func.func @insert_slice_rank_reducing(
 // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SRC]] into %[[DST]][6, 7, 8, %[[IDX]]] [1, 1, 16, 1] [1, 1, 1, 1]
 // CHECK: return %[[INSERT]]
 
-// -----
-
 func.func @insert_slice_rank_reducing_dynamic_shape(
     %dst: tensor<128x128x128x128xf32>, %mid: tensor<1x?x1xf32>, %src: tensor<?xf32>, %offset: index, %size: index) -> tensor<128x128x128x128xf32> {
   %0 = tensor.insert_slice %src into %mid[0, 0, 0] [1, %size, 1] [1, 1, 1] : tensor<?xf32> into tensor<1x?x1xf32>
