7
7
// ===----------------------------------------------------------------------===//
8
8
9
9
#include " mlir/Dialect/Affine/IR/AffineOps.h"
10
+ #include " mlir/Dialect/Arithmetic/Utils/Utils.h"
10
11
#include " mlir/Dialect/Tensor/IR/Tensor.h"
11
- #include " mlir/Dialect/Tensor/Transforms/TransformUtils.h"
12
12
#include " mlir/Dialect/Tensor/Transforms/Transforms.h"
13
13
#include " mlir/IR/BuiltinTypes.h"
14
14
#include " mlir/IR/OpDefinition.h"
17
17
using namespace mlir ;
18
18
using namespace mlir ::tensor;
19
19
20
- // / Creates AffineExpr from `ofr`: if the OpFoldResult is a Value, creates a
21
- // / AffineSymbolExpr and appends it to `symbols`; otherwise creates a
22
- // / AffineConstantExpr.
23
- static AffineExpr getAffineExpr (OpFoldResult ofr,
24
- SmallVector<OpFoldResult> &symbols) {
25
- if (auto attr = ofr.dyn_cast <Attribute>()) {
26
- return getAffineConstantExpr (attr.cast <IntegerAttr>().getInt (),
27
- attr.getContext ());
20
+ // / Adds each corresponding pair of offsets in `offsets1` and `offsets2` and
21
+ // / returns the results.
22
+ static SmallVector<OpFoldResult> mergeOffsets (Location loc,
23
+ ArrayRef<OpFoldResult> offsets1,
24
+ ArrayRef<OpFoldResult> offsets2,
25
+ OpBuilder &builder) {
26
+ SmallVector<OpFoldResult> foldedOffsets;
27
+ assert (offsets1.size () == offsets2.size ());
28
+ foldedOffsets.reserve (offsets1.size ());
29
+
30
+ AffineExpr dim1, dim2;
31
+ bindDims (builder.getContext (), dim1, dim2);
32
+
33
+ for (const auto &pair : llvm::zip (offsets1, offsets2)) {
34
+ auto offset0 =
35
+ getValueOrCreateConstantIndexOp (builder, loc, std::get<0 >(pair));
36
+ auto offset1 =
37
+ getValueOrCreateConstantIndexOp (builder, loc, std::get<1 >(pair));
38
+ auto foldedOffset =
39
+ makeComposedAffineApply (builder, loc, dim1 + dim2, {offset0, offset1});
40
+ foldedOffsets.push_back (foldedOffset.getResult ());
28
41
}
29
- Value v = ofr.get <Value>();
30
- AffineExpr expr = getAffineSymbolExpr (symbols.size (), v.getContext ());
31
- symbols.push_back (v);
32
- return expr;
33
- }
34
-
35
- // / Builds the AffineExpr incrementally for arithmetic operations.
36
- static AffineExpr add (AffineExpr expr, OpFoldResult ofr,
37
- SmallVector<OpFoldResult> &symbols) {
38
- return expr + getAffineExpr (ofr, symbols);
39
- }
40
- static AffineExpr mul (OpFoldResult lhs, OpFoldResult rhs,
41
- SmallVector<OpFoldResult> &symbols) {
42
- return getAffineExpr (lhs, symbols) * getAffineExpr (rhs, symbols);
43
- }
44
-
45
- // / Converts an AffineExpr to OpFoldResult by generating an `affine.apply`
46
- // / op and folding it.
47
- static OpFoldResult getOpFoldResult (OpBuilder &builder, Location loc,
48
- AffineExpr expr,
49
- SmallVector<OpFoldResult> &symbols) {
50
- AffineMap m = AffineMap::get (0 , symbols.size (), expr);
51
- return makeComposedFoldedAffineApply (builder, loc, m, symbols);
52
- }
53
-
54
- LogicalResult tensor::mergeOffsetsSizesAndStrides (
55
- OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> producerOffsets,
56
- ArrayRef<OpFoldResult> producerSizes,
57
- ArrayRef<OpFoldResult> producerStrides,
58
- const llvm::SmallBitVector &droppedProducerDims,
59
- ArrayRef<OpFoldResult> consumerOffsets,
60
- ArrayRef<OpFoldResult> consumerSizes,
61
- ArrayRef<OpFoldResult> consumerStrides,
62
- SmallVector<OpFoldResult> &combinedOffsets,
63
- SmallVector<OpFoldResult> &combinedSizes,
64
- SmallVector<OpFoldResult> &combinedStrides) {
65
- combinedOffsets.resize (producerOffsets.size ());
66
- combinedSizes.resize (producerOffsets.size ());
67
- combinedStrides.resize (producerOffsets.size ());
68
- unsigned consumerPos = 0 ;
69
- for (auto i : llvm::seq<unsigned >(0 , producerOffsets.size ())) {
70
- if (droppedProducerDims.test (i)) {
71
- // For dropped dims, get the values from the producer.
72
- combinedOffsets[i] = producerOffsets[i];
73
- combinedSizes[i] = producerSizes[i];
74
- combinedStrides[i] = producerStrides[i];
75
- continue ;
76
- }
77
- SmallVector<OpFoldResult> offsetSymbols, strideSymbols;
78
- // The combined offset is computed as
79
- // producer_offset + consumer_offset * producer_stride.
80
- combinedOffsets[i] =
81
- getOpFoldResult (builder, loc,
82
- add (mul (consumerOffsets[consumerPos],
83
- producerStrides[i], offsetSymbols),
84
- producerOffsets[i], offsetSymbols),
85
- offsetSymbols);
86
- combinedSizes[i] = consumerSizes[consumerPos];
87
- // The combined stride is computed as
88
- // consumer_stride * producer_stride.
89
- combinedStrides[i] = getOpFoldResult (
90
- builder, loc,
91
- mul (consumerStrides[consumerPos], producerStrides[i], strideSymbols),
92
- strideSymbols);
93
- consumerPos++;
94
- }
95
- return success ();
96
- }
97
-
98
- LogicalResult tensor::mergeOffsetsSizesAndStrides (
99
- OpBuilder &builder, Location loc, OffsetSizeAndStrideOpInterface producer,
100
- OffsetSizeAndStrideOpInterface consumer,
101
- const llvm::SmallBitVector &droppedProducerDims,
102
- SmallVector<OpFoldResult> &combinedOffsets,
103
- SmallVector<OpFoldResult> &combinedSizes,
104
- SmallVector<OpFoldResult> &combinedStrides) {
105
- SmallVector<OpFoldResult> consumerOffsets = consumer.getMixedOffsets ();
106
- SmallVector<OpFoldResult> consumerSizes = consumer.getMixedSizes ();
107
- SmallVector<OpFoldResult> consumerStrides = consumer.getMixedStrides ();
108
- SmallVector<OpFoldResult> producerOffsets = producer.getMixedOffsets ();
109
- SmallVector<OpFoldResult> producerSizes = producer.getMixedSizes ();
110
- SmallVector<OpFoldResult> producerStrides = producer.getMixedStrides ();
111
- return tensor::mergeOffsetsSizesAndStrides (
112
- builder, loc, producerOffsets, producerSizes, producerStrides,
113
- droppedProducerDims, consumerOffsets, consumerSizes, consumerStrides,
114
- combinedOffsets, combinedSizes, combinedStrides);
42
+ return foldedOffsets;
115
43
}
116
44
117
45
namespace {
@@ -125,15 +53,24 @@ struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {
125
53
if (!prevOp)
126
54
return failure ();
127
55
128
- SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
129
- if (failed (mergeOffsetsSizesAndStrides (rewriter, nextOp.getLoc (), prevOp,
130
- nextOp, prevOp.getDroppedDims (),
131
- newOffsets, newSizes, newStrides)))
56
+ if (!prevOp.hasUnitStride () || !nextOp.hasUnitStride ())
132
57
return failure ();
133
58
134
- rewriter.replaceOpWithNewOp <ExtractSliceOp>(nextOp, nextOp.getType (),
135
- prevOp.getSource (), newOffsets,
136
- newSizes, newStrides);
59
+ auto prevResultType = prevOp.getType ().cast <ShapedType>();
60
+ if (prevOp.getSourceType ().getRank () != prevResultType.getRank ())
61
+ return rewriter.notifyMatchFailure (
62
+ prevOp, " rank-reducing producder case unimplemented" );
63
+
64
+ Location loc = nextOp.getLoc ();
65
+
66
+ SmallVector<OpFoldResult> prevOffsets = prevOp.getMixedOffsets ();
67
+ SmallVector<OpFoldResult> nextOffsets = nextOp.getMixedOffsets ();
68
+ SmallVector<OpFoldResult> foldedOffsets =
69
+ mergeOffsets (loc, prevOffsets, nextOffsets, rewriter);
70
+
71
+ rewriter.replaceOpWithNewOp <ExtractSliceOp>(
72
+ nextOp, nextOp.getType (), prevOp.getSource (), foldedOffsets,
73
+ nextOp.getMixedSizes (), nextOp.getMixedStrides ());
137
74
return success ();
138
75
}
139
76
};
0 commit comments