7
7
// ===----------------------------------------------------------------------===//
8
8
9
9
#include " mlir/Dialect/Affine/IR/AffineOps.h"
10
- #include " mlir/Dialect/Arithmetic/Utils/Utils.h"
11
10
#include " mlir/Dialect/Tensor/IR/Tensor.h"
11
+ #include " mlir/Dialect/Tensor/Transforms/TransformUtils.h"
12
12
#include " mlir/Dialect/Tensor/Transforms/Transforms.h"
13
13
#include " mlir/IR/BuiltinTypes.h"
14
14
#include " mlir/IR/OpDefinition.h"
17
17
using namespace mlir ;
18
18
using namespace mlir ::tensor;
19
19
20
- // / Adds each corresponding pair of offsets in `offsets1` and `offsets2` and
21
- // / returns the results.
22
- static SmallVector<OpFoldResult> mergeOffsets (Location loc,
23
- ArrayRef<OpFoldResult> offsets1,
24
- ArrayRef<OpFoldResult> offsets2,
25
- OpBuilder &builder) {
26
- SmallVector<OpFoldResult> foldedOffsets;
27
- assert (offsets1.size () == offsets2.size ());
28
- foldedOffsets.reserve (offsets1.size ());
29
-
30
- AffineExpr dim1, dim2;
31
- bindDims (builder.getContext (), dim1, dim2);
32
-
33
- for (const auto &pair : llvm::zip (offsets1, offsets2)) {
34
- auto offset0 =
35
- getValueOrCreateConstantIndexOp (builder, loc, std::get<0 >(pair));
36
- auto offset1 =
37
- getValueOrCreateConstantIndexOp (builder, loc, std::get<1 >(pair));
38
- auto foldedOffset =
39
- makeComposedAffineApply (builder, loc, dim1 + dim2, {offset0, offset1});
40
- foldedOffsets.push_back (foldedOffset.getResult ());
20
+ // / Creates AffineExpr from `ofr`: if the OpFoldResult is a Value, creates a
21
+ // / AffineSymbolExpr and appends it to `symbols`; otherwise creates a
22
+ // / AffineConstantExpr.
23
+ static AffineExpr getAffineExpr (OpFoldResult ofr,
24
+ SmallVector<OpFoldResult> &symbols) {
25
+ if (auto attr = ofr.dyn_cast <Attribute>()) {
26
+ return getAffineConstantExpr (attr.cast <IntegerAttr>().getInt (),
27
+ attr.getContext ());
41
28
}
42
- return foldedOffsets;
29
+ Value v = ofr.get <Value>();
30
+ AffineExpr expr = getAffineSymbolExpr (symbols.size (), v.getContext ());
31
+ symbols.push_back (v);
32
+ return expr;
33
+ }
34
+
35
+ // / Builds the AffineExpr incrementally for arithmetic operations.
36
+ static AffineExpr add (AffineExpr expr, OpFoldResult ofr,
37
+ SmallVector<OpFoldResult> &symbols) {
38
+ return expr + getAffineExpr (ofr, symbols);
39
+ }
40
+ static AffineExpr mul (OpFoldResult lhs, OpFoldResult rhs,
41
+ SmallVector<OpFoldResult> &symbols) {
42
+ return getAffineExpr (lhs, symbols) * getAffineExpr (rhs, symbols);
43
+ }
44
+
45
+ // / Converts an AffineExpr to OpFoldResult by generating an `affine.apply`
46
+ // / op and fold it.
47
+ static OpFoldResult getOpFoldResult (OpBuilder &builder, Location loc,
48
+ AffineExpr expr,
49
+ SmallVector<OpFoldResult> &symbols) {
50
+ AffineMap m = AffineMap::get (0 , symbols.size (), expr);
51
+ return makeComposedFoldedAffineApply (builder, loc, m, symbols);
52
+ }
53
+
54
+ LogicalResult tensor::mergeOffsetsSizesAndStrides (
55
+ OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> producerOffsets,
56
+ ArrayRef<OpFoldResult> producerSizes,
57
+ ArrayRef<OpFoldResult> producerStrides,
58
+ const llvm::SmallBitVector &droppedProducerDims,
59
+ ArrayRef<OpFoldResult> consumerOffsets,
60
+ ArrayRef<OpFoldResult> consumerSizes,
61
+ ArrayRef<OpFoldResult> consumerStrides,
62
+ SmallVector<OpFoldResult> &combinedOffsets,
63
+ SmallVector<OpFoldResult> &combinedSizes,
64
+ SmallVector<OpFoldResult> &combinedStrides) {
65
+ combinedOffsets.resize (producerOffsets.size ());
66
+ combinedSizes.resize (producerOffsets.size ());
67
+ combinedStrides.resize (producerOffsets.size ());
68
+ unsigned consumerPos = 0 ;
69
+ for (auto i : llvm::seq<unsigned >(0 , producerOffsets.size ())) {
70
+ if (droppedProducerDims.test (i)) {
71
+ // For dropped dims, get the values from the producer.
72
+ combinedOffsets[i] = producerOffsets[i];
73
+ combinedSizes[i] = producerSizes[i];
74
+ combinedStrides[i] = producerStrides[i];
75
+ continue ;
76
+ }
77
+ SmallVector<OpFoldResult> offsetSymbols, strideSymbols;
78
+ // The combined offset is computed as
79
+ // producer_offset + consumer_offset * producer_strides.
80
+ combinedOffsets[i] =
81
+ getOpFoldResult (builder, loc,
82
+ add (mul (consumerOffsets[consumerPos],
83
+ producerStrides[i], offsetSymbols),
84
+ producerOffsets[i], offsetSymbols),
85
+ offsetSymbols);
86
+ combinedSizes[i] = consumerSizes[consumerPos];
87
+ // The combined stride is computed as
88
+ // consumer_stride * producer_stride.
89
+ combinedStrides[i] = getOpFoldResult (
90
+ builder, loc,
91
+ mul (consumerStrides[consumerPos], producerStrides[i], strideSymbols),
92
+ strideSymbols);
93
+ consumerPos++;
94
+ }
95
+ return success ();
96
+ }
97
+
98
+ LogicalResult tensor::mergeOffsetsSizesAndStrides (
99
+ OpBuilder &builder, Location loc, OffsetSizeAndStrideOpInterface producer,
100
+ OffsetSizeAndStrideOpInterface consumer,
101
+ const llvm::SmallBitVector &droppedProducerDims,
102
+ SmallVector<OpFoldResult> &combinedOffsets,
103
+ SmallVector<OpFoldResult> &combinedSizes,
104
+ SmallVector<OpFoldResult> &combinedStrides) {
105
+ SmallVector<OpFoldResult> consumerOffsets = consumer.getMixedOffsets ();
106
+ SmallVector<OpFoldResult> consumerSizes = consumer.getMixedSizes ();
107
+ SmallVector<OpFoldResult> consumerStrides = consumer.getMixedStrides ();
108
+ SmallVector<OpFoldResult> producerOffsets = producer.getMixedOffsets ();
109
+ SmallVector<OpFoldResult> producerSizes = producer.getMixedSizes ();
110
+ SmallVector<OpFoldResult> producerStrides = producer.getMixedStrides ();
111
+ return tensor::mergeOffsetsSizesAndStrides (
112
+ builder, loc, producerOffsets, producerSizes, producerStrides,
113
+ droppedProducerDims, consumerOffsets, consumerSizes, consumerStrides,
114
+ combinedOffsets, combinedSizes, combinedStrides);
43
115
}
44
116
45
117
namespace {
@@ -53,24 +125,15 @@ struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {
53
125
if (!prevOp)
54
126
return failure ();
55
127
56
- if (!prevOp.hasUnitStride () || !nextOp.hasUnitStride ())
128
+ SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
129
+ if (failed (mergeOffsetsSizesAndStrides (rewriter, nextOp.getLoc (), prevOp,
130
+ nextOp, prevOp.getDroppedDims (),
131
+ newOffsets, newSizes, newStrides)))
57
132
return failure ();
58
133
59
- auto prevResultType = prevOp.getType ().cast <ShapedType>();
60
- if (prevOp.getSourceType ().getRank () != prevResultType.getRank ())
61
- return rewriter.notifyMatchFailure (
62
- prevOp, " rank-reducing producder case unimplemented" );
63
-
64
- Location loc = nextOp.getLoc ();
65
-
66
- SmallVector<OpFoldResult> prevOffsets = prevOp.getMixedOffsets ();
67
- SmallVector<OpFoldResult> nextOffsets = nextOp.getMixedOffsets ();
68
- SmallVector<OpFoldResult> foldedOffsets =
69
- mergeOffsets (loc, prevOffsets, nextOffsets, rewriter);
70
-
71
- rewriter.replaceOpWithNewOp <ExtractSliceOp>(
72
- nextOp, nextOp.getType (), prevOp.getSource (), foldedOffsets,
73
- nextOp.getMixedSizes (), nextOp.getMixedStrides ());
134
+ rewriter.replaceOpWithNewOp <ExtractSliceOp>(nextOp, nextOp.getType (),
135
+ prevOp.getSource (), newOffsets,
136
+ newSizes, newStrides);
74
137
return success ();
75
138
}
76
139
};
0 commit comments