
Commit c1fd430

[mlir] Add basic support for dynamic tensor results in TensorToBuffers.cpp.

The simplest case is when the indexing maps are DimIds in every component.
This covers cwise ops.

Also:
* Expose populateConvertLinalgOnTensorsToBuffersPatterns in Transforms.h
* Expose emitLoopRanges in Transforms.h

Differential Revision: https://reviews.llvm.org/D88781
1 parent aa47962 commit c1fd430
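The intuition behind the change: for cwise ops whose indexing maps are all DimIds (identity maps), every dynamic dimension of a tensor result has the same size as the corresponding dimension of a tensor operand, so the bufferization pattern can size the result allocation from an operand. The following is only a hedged sketch of that idea, not the exact code added to TensorToBuffers.cpp; the helper name `allocateResultBuffer` and its arguments are illustrative.

// Sketch only (assumed names, not the exact TensorToBuffers.cpp code): with
// DimId-only indexing maps, dynamic result dimension i matches dimension i of
// any tensor operand, so the result buffer is allocated from `dim` values
// taken off an operand.
static Value allocateResultBuffer(OpBuilder &b, Location loc, Value operand,
                                  RankedTensorType resultType) {
  SmallVector<Value, 4> dynSizes;
  for (unsigned i = 0, e = resultType.getRank(); i < e; ++i)
    if (resultType.isDynamicDim(i))
      dynSizes.push_back(b.create<DimOp>(loc, operand, i));
  auto memRefType =
      MemRefType::get(resultType.getShape(), resultType.getElementType());
  return b.create<AllocOp>(loc, memRefType, dynSizes);
}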

4 files changed: +288 additions, -154 deletions

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 19 additions & 0 deletions
@@ -16,6 +16,9 @@
 #include "llvm/ADT/SmallBitVector.h"
 
 namespace mlir {
+
+class BufferAssignmentTypeConverter;
+
 namespace linalg {
 
 struct LinalgFusionOptions;
@@ -45,6 +48,12 @@ void populateConvVectorizationPatterns(
     MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
     ArrayRef<int64_t> tileSizes);
 
+/// Populates the given list with patterns to convert Linalg operations on
+/// tensors to buffers.
+void populateConvertLinalgOnTensorsToBuffersPatterns(
+    MLIRContext *context, BufferAssignmentTypeConverter *converter,
+    OwningRewritePatternList *patterns);
+
 /// Performs standalone tiling of a single LinalgOp by `tileSizes`.
 /// and permute the loop nest according to `interchangeVector`
 /// The permutation is expressed as a list of integers that specify
@@ -246,6 +255,16 @@ Optional<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
                                    LinalgPromotionOptions options,
                                    OperationFolder *folder = nullptr);
 
+/// Creates a number of ranges equal to the number of dimensions in the `map`.
+/// The returned ranges correspond to the loop ranges, in the proper order, for
+/// which new loops will be created.
+/// The function supports only maps that are invertible and have results of type
+/// DimExpr or (DimExpr + DimExpr - SymbolExpr floordiv ConstExpr).
+/// It expects a non-inverted, concatenated map and last values in
+/// allViewSizes will be applied to the symbols in the map if it contains any.
+SmallVector<Range, 4> emitLoopRanges(OpBuilder &b, Location loc, AffineMap map,
+                                     ValueRange viewSizes);
+
 /// Emit a suitable vector form for a Linalg op with fully static shape.
 void vectorizeLinalgOp(OpBuilder &builder, Operation *op);
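
Exposing populateConvertLinalgOnTensorsToBuffersPatterns lets other bufferization pipelines reuse these patterns. Below is a rough usage sketch; the pass class `MyLinalgBufferizePass` and the conversion-target setup are illustrative assumptions, not part of this commit.

// Hypothetical pass body reusing the newly exposed entry point (sketch only).
void MyLinalgBufferizePass::runOnOperation() {
  MLIRContext &context = getContext();
  ConversionTarget target(context);
  target.addLegalDialect<StandardOpsDialect>();
  BufferAssignmentTypeConverter converter;
  OwningRewritePatternList patterns;
  populateConvertLinalgOnTensorsToBuffersPatterns(&context, &converter,
                                                  &patterns);
  if (failed(applyPartialConversion(getOperation(), target,
                                    std::move(patterns))))
    signalPassFailure();
}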

mlir/lib/Dialect/Linalg/Transforms/Loops.cpp

Lines changed: 64 additions & 71 deletions
@@ -58,77 +58,6 @@ static SmallVector<Value, 4> permuteIvs(ArrayRef<Value> ivs,
              : SmallVector<Value, 4>(ivs.begin(), ivs.end());
 }
 
-/// Creates a number of ranges equal to the number of dimensions in the `map`.
-/// The returned ranges correspond to the loop ranges, in the proper order, for
-/// which new loops will be created.
-/// The function supports only maps that are invertible and have results of type
-/// DimExpr or (DimExpr + DimExpr - SymbolExpr floordiv ConstExpr).
-/// It expects a non-inverted, concatenated map and last values in
-/// allViewSizes will be applied to the symbols in the map if it contains any.
-static SmallVector<Range, 4> emitLoopRanges(OpBuilder &b, Location loc,
-                                            AffineMap map,
-                                            ValueRange viewSizes) {
-  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
-  unsigned numSym = map.getNumSymbols();
-  assert(viewSizes.size() == numRes + numSym &&
-         "viewSizes must contain sizes of all views and values for symbols");
-  SmallVector<Range, 4> res(numDims);
-  for (unsigned idx = 0; idx < numRes; ++idx) {
-    auto result = map.getResult(idx);
-    if (auto d = result.dyn_cast<AffineDimExpr>()) {
-      if (res[d.getPosition()].offset)
-        continue;
-      res[d.getPosition()] =
-          Range{std_constant_index(0), viewSizes[idx], std_constant_index(1)};
-    }
-
-    // If the access pattern is of form (m, n)[s] -> (m + n - s floordiv 2),
-    // then the bounds are:
-    //   (s floordiv 2) <= m <= (size(m) + s floordiv 2 - s + 1).
-    // where size(n) is applied to the symbol s.
-    // This is done statically now.
-    if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
-      auto lhs = binOp.getLHS().dyn_cast<AffineBinaryOpExpr>();
-      auto rhs = binOp.getRHS().dyn_cast<AffineBinaryOpExpr>();
-      if (!lhs || !rhs || binOp.getKind() != AffineExprKind::Add ||
-          lhs.getKind() != AffineExprKind::Add ||
-          rhs.getKind() != mlir::AffineExprKind::Mul)
-        continue;
-
-      auto m = lhs.getLHS().dyn_cast<AffineDimExpr>();
-      auto n = lhs.getRHS().dyn_cast<AffineDimExpr>();
-      auto fDiv = rhs.getLHS().dyn_cast<AffineBinaryOpExpr>();
-      auto minusOne = rhs.getRHS().dyn_cast<AffineConstantExpr>();
-      if (!m || !n || !fDiv || !minusOne ||
-          fDiv.getKind() != AffineExprKind::FloorDiv ||
-          fDiv.getLHS().getKind() != AffineExprKind::SymbolId ||
-          fDiv.getRHS().getKind() != AffineExprKind::Constant)
-        continue;
-
-      auto s = fDiv.getLHS().dyn_cast<AffineSymbolExpr>();
-      if (minusOne.getValue() != -1)
-        continue;
-
-      int mPos = m.getPosition();
-      AffineExpr one = getAffineConstantExpr(1, s.getContext());
-      AffineExpr sizeOfM = getAffineSymbolExpr(numSym, s.getContext());
-      // Construction of upper bound (size(m) + s floordiv 2 - s + 1).
-      AffineExpr upperOffsetExpr = sizeOfM + fDiv + one - s;
-      AffineMap fromMap = AffineMap::get(numDims, numSym + 1, fDiv);
-      AffineMap toMap = AffineMap::get(numDims, numSym + 1, upperOffsetExpr);
-      SmallVector<Value, 8> values(viewSizes.begin(),
-                                   viewSizes.begin() + numDims);
-      values.insert(values.end(), viewSizes.begin() + numRes, viewSizes.end());
-      values.push_back(viewSizes[mPos]);
-      // Construction of the lower bound (s floordiv 2).
-      Value from = applyMapToValues(b, loc, fromMap, values).front();
-      Value to = applyMapToValues(b, loc, toMap, values).front();
-      res[mPos] = Range{from, to, std_constant_index(1)};
-    }
-  }
-  return res;
-}
-
 template <typename IndexedValueType, typename OpType>
 static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
                                      ArrayRef<SmallVector<Value, 8>> indexing,
@@ -708,6 +637,70 @@ static Optional<LinalgLoops> linalgOpToLoopsImplSwitch(Operation *op,
   llvm_unreachable("Unexpected op in linalgOpToLoopsImpl");
 }
 
+SmallVector<Range, 4> mlir::linalg::emitLoopRanges(OpBuilder &b, Location loc,
+                                                   AffineMap map,
+                                                   ValueRange viewSizes) {
+  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
+  unsigned numSym = map.getNumSymbols();
+  assert(viewSizes.size() == numRes + numSym &&
+         "viewSizes must contain sizes of all views and values for symbols");
+  SmallVector<Range, 4> res(numDims);
+  for (unsigned idx = 0; idx < numRes; ++idx) {
+    auto result = map.getResult(idx);
+    if (auto d = result.dyn_cast<AffineDimExpr>()) {
+      if (res[d.getPosition()].offset)
+        continue;
+      res[d.getPosition()] =
+          Range{std_constant_index(0), viewSizes[idx], std_constant_index(1)};
+    }
+
+    // If the access pattern is of form (m, n)[s] -> (m + n - s floordiv 2),
+    // then the bounds are:
+    //   (s floordiv 2) <= m <= (size(m) + s floordiv 2 - s + 1).
+    // where size(n) is applied to the symbol s.
+    // This is done statically now.
+    if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
+      auto lhs = binOp.getLHS().dyn_cast<AffineBinaryOpExpr>();
+      auto rhs = binOp.getRHS().dyn_cast<AffineBinaryOpExpr>();
+      if (!lhs || !rhs || binOp.getKind() != AffineExprKind::Add ||
+          lhs.getKind() != AffineExprKind::Add ||
+          rhs.getKind() != mlir::AffineExprKind::Mul)
+        continue;
+
+      auto m = lhs.getLHS().dyn_cast<AffineDimExpr>();
+      auto n = lhs.getRHS().dyn_cast<AffineDimExpr>();
+      auto fDiv = rhs.getLHS().dyn_cast<AffineBinaryOpExpr>();
+      auto minusOne = rhs.getRHS().dyn_cast<AffineConstantExpr>();
+      if (!m || !n || !fDiv || !minusOne ||
+          fDiv.getKind() != AffineExprKind::FloorDiv ||
+          fDiv.getLHS().getKind() != AffineExprKind::SymbolId ||
+          fDiv.getRHS().getKind() != AffineExprKind::Constant)
+        continue;
+
+      auto s = fDiv.getLHS().dyn_cast<AffineSymbolExpr>();
+      if (minusOne.getValue() != -1)
+        continue;
+
+      int mPos = m.getPosition();
+      AffineExpr one = getAffineConstantExpr(1, s.getContext());
+      AffineExpr sizeOfM = getAffineSymbolExpr(numSym, s.getContext());
+      // Construction of upper bound (size(m) + s floordiv 2 - s + 1).
+      AffineExpr upperOffsetExpr = sizeOfM + fDiv + one - s;
+      AffineMap fromMap = AffineMap::get(numDims, numSym + 1, fDiv);
+      AffineMap toMap = AffineMap::get(numDims, numSym + 1, upperOffsetExpr);
+      SmallVector<Value, 8> values(viewSizes.begin(),
+                                   viewSizes.begin() + numDims);
+      values.insert(values.end(), viewSizes.begin() + numRes, viewSizes.end());
+      values.push_back(viewSizes[mPos]);
+      // Construction of the lower bound (s floordiv 2).
+      Value from = applyMapToValues(b, loc, fromMap, values).front();
+      Value to = applyMapToValues(b, loc, toMap, values).front();
+      res[mPos] = Range{from, to, std_constant_index(1)};
+    }
+  }
+  return res;
+}
+
 /// Emits a loop nest with the proper body for `op`.
 template <typename LoopTy>
 Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder,
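
To make the convolution-like bound computation in emitLoopRanges concrete, here is a small worked example with made-up numbers; the constants are purely illustrative and not part of the commit.

// Worked example for the (m, n)[s] -> (m + n - s floordiv 2) case above,
// with illustrative values size(m) = 10 and symbol s = 5.
int64_t sizeM = 10, s = 5;
int64_t lower = s / 2;                  // s floordiv 2                   -> 2
int64_t upper = sizeM + s / 2 - s + 1;  // size(m) + s floordiv 2 - s + 1 -> 8
// emitLoopRanges would produce Range{2, 8, 1} for m, i.e. the loop runs over
// [2, 8) with step 1; `fromMap`/`toMap` build exactly these two expressions
// and applyMapToValues materializes them as index values.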
