
Commit 06a65ce

[mlir][sparse] schedule sparse kernels in a separate pass from sparsification. (#72423)
1 parent dce7a7c commit 06a65ce
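
At a high level: loop scheduling for sparse kernels (picking an admissible loop order by topologically sorting the iteration graph) moves out of sparsification into a separate pass, backed by the new IterationGraphSorter utility below; CodegenEnv then merely verifies, rather than computes, the admissibility of the incoming loop order.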


53 files changed: +674 -723 lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

Lines changed: 4 additions & 0 deletions
@@ -127,6 +127,10 @@ inline bool hasAnySparseOperandOrResult(Operation *op) {
   return hasAnySparseOperand(op) || hasAnySparseResult(op);
 }
 
+/// Returns true iff MLIR operation has any sparse tensor with non-identity
+/// dim2lvl maps.
+bool hasAnyNonIdentityOperandsOrResults(Operation *op);
+
 //
 // Inference.
 //
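
For context, "dim2lvl" is the mapping from a tensor's dimensions to its storage levels. A concrete case (illustrative, not part of this diff): a CSC-like matrix encoding stores dimensions (d0, d1) at levels (d1, d0), so its dim2lvl map is the non-identity permutation (d0, d1) -> (d1, d0); the new helper returns true for any operation whose sparse operands or results carry such a map.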

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 10 additions & 0 deletions
@@ -875,6 +875,16 @@ bool mlir::sparse_tensor::isUniqueCOOType(Type tp) {
   return isCOOType(getSparseTensorEncoding(tp), 0, /*isUnique=*/true);
 }
 
+bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
+  auto hasNonIdentityMap = [](Value v) {
+    auto stt = tryGetSparseTensorType(v);
+    return stt && !stt->isIdentity();
+  };
+
+  return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
+         llvm::any_of(op->getResults(), hasNonIdentityMap);
+}
+
 Level mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) {
   // We only consider COO region with at least two levels for the purpose
   // of AOS storage optimization.
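
A minimal usage sketch of the new helper (a hypothetical caller, not part of the commit; the pattern shape is the standard OpRewritePattern interface): reject kernels that still carry non-identity dim2lvl maps, since the scheduler introduced by this commit assumes demapped tensors.

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Hypothetical guard: demap to identity dim2lvl form before scheduling.
    if (sparse_tensor::hasAnyNonIdentityOperandsOrResults(op))
      return failure();
    // ... schedule loops and sparsify ...
    return success();
  }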

mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   BufferizableOpInterfaceImpl.cpp
   CodegenEnv.cpp
   CodegenUtils.cpp
+  IterationGraphSorter.cpp
   LoopEmitter.cpp
   SparseBufferRewriting.cpp
   SparseGPUCodegen.cpp

mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp

Lines changed: 17 additions & 26 deletions
@@ -57,7 +57,11 @@ CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
       loopEmitter(), topSort(), sparseOut(nullptr), outerParNest(-1u),
       insChain(), expValues(), expFilled(), expAdded(), expCount(), redVal(),
       redExp(detail::kInvalidId), redCustom(detail::kInvalidId),
-      redValidLexInsert() {}
+      redValidLexInsert() {
+  // TODO: remove topSort, loops should be already sorted by previous pass.
+  for (unsigned l = 0; l < latticeMerger.getNumLoops(); l++)
+    topSort.push_back(l);
+}
 
 LogicalResult CodegenEnv::initTensorExp() {
   // Builds the tensor expression for the Linalg operation in SSA form.
@@ -181,36 +185,23 @@ bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) {
   // Accept "truly dynamic" if the output tensor materializes uninitialized
   // into the computation and insertions occur in lexicographic index order.
   sparseOut = lhs;
-  return isMaterializing(lhs->get());
-}
 
-bool CodegenEnv::isAdmissibleTopoOrder() {
-  if (!hasSparseOutput())
-    return true;
-
-  OpOperand *lhs = linalgOp.getDpsInitOperand(0);
-  // Accept "truly dynamic" if the output tensor materializes uninitialized
-  // into the computation and insertions occur in lexicographic index order.
-  LoopOrd nest = 0;
+  // Find the outermost parallel nest to determine whether compress/expand is
+  // needed.
+  outerParNest = 0;
   const auto iteratorTypes = linalgOp.getIteratorTypesArray();
   assert(topSortSize() == latticeMerger.getNumLoops());
   for (const LoopId i : topSort) {
-    if (!latticeMerger.isFilterLoop(i)) {
-      // We only count non-filter loops as filter loops should be considered
-      // a special type of parallel loops.
-      if (linalg::isReductionIterator(iteratorTypes[i]))
-        break; // terminate at first reduction
-      nest++;
-    }
-  }
-  // Determine admissible dynamic insertion situations:
-  // (1) fully injective, since there are no reductions,
-  // (2) admissible 1-d expansion in innermost dimension.
-  if (static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1) {
-    outerParNest = nest;
-    return true;
+    if (linalg::isReductionIterator(iteratorTypes[i]))
+      break; // terminate at first reduction
+    outerParNest++;
   }
-  return false;
+
+  // Inadmissible kernel should have already been rejected by the previous
+  // pass during loop scheduling.
+  assert(static_cast<int64_t>(outerParNest) >=
+         linalgOp.getRank(linalgOp.getDpsInitOperand(0)) - 1);
+  return isMaterializing(lhs->get());
 }
 
 //===----------------------------------------------------------------------===//
mlir/lib/Dialect/SparseTensor/Transforms/IterationGraphSorter.cpp

Lines changed: 273 additions & 0 deletions

@@ -0,0 +1,273 @@
+//===- IterationGraphSorter.cpp -------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "IterationGraphSorter.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+namespace {
+
+/// A helper class that visits an affine expression and tries to find an
+/// AffineDimExpr to which the corresponding iterator from a GenericOp matches
+/// the desired iterator type.
+/// If there is no matched iterator type, returns the first DimExpr in the
+/// expression.
+class AffineDimFinder : public AffineExprVisitor<AffineDimFinder> {
+public:
+  explicit AffineDimFinder(ArrayRef<utils::IteratorType> itTypes)
+      : iterTypes(itTypes) {}
+
+  // Overrides method from AffineExprVisitor.
+  void visitDimExpr(AffineDimExpr expr) {
+    if (pickedDim == nullptr || pickIterType == iterTypes[expr.getPosition()])
+      pickedDim = expr;
+  }
+
+  /// Sets the desired iterator type that we want to pick.
+  void setPickedIterType(utils::IteratorType iterType) {
+    pickIterType = iterType;
+  }
+
+  /// Gets the desired AffineDimExpr.
+  AffineDimExpr getDimExpr() const {
+    return llvm::cast<AffineDimExpr>(pickedDim);
+  }
+
+  void walkPostOrder(AffineExpr expr) {
+    pickedDim = nullptr;
+    AffineExprVisitor<AffineDimFinder>::walkPostOrder(expr);
+  }
+
+private:
+  /// The picked AffineDimExpr after visit.
+  AffineExpr pickedDim;
+  /// The iterator type that we want.
+  utils::IteratorType pickIterType;
+  /// The mapping between dim=>iterator type.
+  ArrayRef<utils::IteratorType> iterTypes;
+};
+
+// Flattens an affine expression into a list of AffineDimExprs.
+struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
+  // Overrides method from AffineExprVisitor.
+  void visitDimExpr(AffineDimExpr expr) { dims.push_back(expr); }
+  SmallVector<AffineDimExpr> dims;
+};
+
+} // namespace
+
+inline static bool includesAny(SortMask mask1, SortMask mask2) {
+  return static_cast<unsigned>(mask1) & static_cast<unsigned>(mask2);
+}
+
+inline static bool includesDenseInput(SortMask mask) {
+  return includesAny(mask, SortMask::kIncludeDenseInput);
+}
+
+inline static bool includesDenseOutput(SortMask mask) {
+  return includesAny(mask, SortMask::kIncludeDenseOutput);
+}
+
+/// A helper to compute a topological sort. O(n^2) time complexity
+/// as we use adj matrix for the graph.
+/// The sorted result will put the first Reduction iterator to the
+/// latest possible position.
+AffineMap IterationGraphSorter::topoSort() {
+  std::vector<unsigned> redIt; // reduce iterator with 0 degree
+  std::vector<unsigned> parIt; // parallel iterator with 0 degree
+  const unsigned numLoops = getNumLoops();
+  for (unsigned i = 0; i < numLoops; i++) {
+    if (inDegree[i] == 0) {
+      if (iterTypes[i] == utils::IteratorType::reduction)
+        redIt.push_back(i);
+      else
+        parIt.push_back(i);
+    }
+  }
+
+  SmallVector<unsigned> loopOrder;
+  while (!redIt.empty() || !parIt.empty()) {
+    // We always prefer parallel loop over reduction loop because putting
+    // reduction loop early might make the loop sequence inadmissible.
+    auto &it = !parIt.empty() ? parIt : redIt;
+    auto src = it.back();
+    loopOrder.push_back(src);
+    it.pop_back();
+    // Update in-degree, and push 0-degree node into worklist.
+    for (unsigned dst = 0; dst < numLoops; dst++) {
+      if (itGraph[src][dst] && --inDegree[dst] == 0) {
+        if (iterTypes[dst] == utils::IteratorType::reduction)
+          redIt.push_back(dst);
+        else
+          parIt.push_back(dst);
+      }
+    }
+  }
+
+  if (loopOrder.size() == numLoops)
+    return AffineMap::getPermutationMap(loopOrder, out.getContext());
+
+  // Cycle detected.
+  return AffineMap();
+}
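
topoSort above is Kahn's algorithm over an adjacency matrix, with one twist: zero in-degree parallel loops are always drained before zero in-degree reduction loops, pushing reductions as late as the constraints allow. A self-contained sketch of the same policy (illustrative names, not from the commit):

  #include <cassert>
  #include <vector>

  // Priority topological sort: prefer parallel loops, fall back to
  // reduction loops; a short result signals a cycle in the graph.
  static std::vector<unsigned>
  prioritySort(const std::vector<std::vector<bool>> &g,
               std::vector<unsigned> inDeg, const std::vector<bool> &isRed) {
    const unsigned n = g.size();
    std::vector<unsigned> par, red, order;
    for (unsigned i = 0; i < n; i++)
      if (inDeg[i] == 0)
        (isRed[i] ? red : par).push_back(i);
    while (!par.empty() || !red.empty()) {
      auto &wl = !par.empty() ? par : red; // parallel loops first
      unsigned src = wl.back();
      wl.pop_back();
      order.push_back(src);
      for (unsigned dst = 0; dst < n; dst++)
        if (g[src][dst] && --inDeg[dst] == 0)
          (isRed[dst] ? red : par).push_back(dst);
    }
    return order; // order.size() < n means a cycle was detected
  }

  int main() {
    // Constraints d0 -> d2 and d1 -> d2, where d1 is a reduction loop.
    std::vector<std::vector<bool>> g = {{false, false, true},
                                        {false, false, true},
                                        {false, false, false}};
    std::vector<unsigned> inDeg = {0, 0, 2};
    std::vector<bool> isRed = {false, true, false};
    auto order = prioritySort(g, inDeg, isRed);
    assert(order.size() == 3 && order[0] == 0); // parallel d0 comes first
    return 0;
  }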
+
+IterationGraphSorter
+IterationGraphSorter::fromGenericOp(linalg::GenericOp genericOp) {
+  // Must be a demapped sparse kernel.
+  assert(!hasAnyNonIdentityOperandsOrResults(genericOp) &&
+         hasAnySparseOperandOrResult(genericOp) &&
+         genericOp.getNumDpsInits() == 1);
+
+  SmallVector<AffineMap> loopMap = genericOp.getIndexingMapsArray();
+  SmallVector<Value> ins = genericOp.getDpsInputs();
+
+  AffineMap outMap = loopMap.back();
+  loopMap.pop_back();
+
+  Value out = genericOp.getDpsInitOperand(0)->get();
+  SmallVector<utils::IteratorType> iterTypes =
+      genericOp.getIteratorTypesArray();
+
+  return IterationGraphSorter(std::move(ins), std::move(loopMap), out, outMap,
+                              std::move(iterTypes));
+}
+
+IterationGraphSorter::IterationGraphSorter(
+    SmallVector<Value> &&ins, SmallVector<AffineMap> &&loop2InsLvl, Value out,
+    AffineMap loop2OutLvl, SmallVector<utils::IteratorType> &&iterTypes)
+    : ins(std::move(ins)), loop2InsLvl(std::move(loop2InsLvl)), out(out),
+      loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)) {
+  // One map per tensor.
+  assert(loop2InsLvl.size() == ins.size());
+  // All the affine maps have the same number of dimensions (loops).
+  assert(llvm::all_equal(llvm::map_range(
+      loop2InsLvl, [](AffineMap m) { return m.getNumDims(); })));
+  // The number of results of the map should match the rank of the tensor.
+  assert(llvm::all_of(llvm::zip(loop2InsLvl, ins), [](auto mvPair) {
+    auto [m, v] = mvPair;
+    return m.getNumResults() ==
+           v.getType().template cast<ShapedType>().getRank();
+  }));
+
+  itGraph.resize(getNumLoops(), std::vector<bool>(getNumLoops(), false));
+  inDegree.resize(getNumLoops());
+}
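
As a concrete reading of the constructor's asserts: for a matrix product C(i, j) = sum_k A(i, k) * B(k, j) expressed as a linalg.generic with loops (d0, d1, d2) = (i, j, k), loop2InsLvl holds (d0, d1, d2) -> (d0, d2) for A and (d0, d1, d2) -> (d2, d1) for B, while loop2OutLvl is (d0, d1, d2) -> (d0, d1); there is one map per input tensor, all maps agree on the number of loops (3), and each map's result count equals its tensor's rank (2).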
+
+AffineMap IterationGraphSorter::sort(SortMask mask, Value ignored) {
+  // Reset the iteration graph.
+  for (auto &row : itGraph)
+    std::fill(row.begin(), row.end(), false);
+  // Reset cached in-degree.
+  std::fill(inDegree.begin(), inDegree.end(), 0);
+
+  for (auto [in, map] : llvm::zip(ins, loop2InsLvl)) {
+    // Get map and encoding.
+    const auto enc = getSparseTensorEncoding(in.getType());
+    // Skip dense inputs when not requested.
+    if ((!enc && !includesDenseInput(mask)) || in == ignored)
+      continue;
+
+    addConstraints(in, map);
+  }
+
+  // Get map and encoding.
+  const auto enc = getSparseTensorEncoding(out.getType());
+  if ((enc || includesDenseOutput(mask)) && out != ignored)
+    addConstraints(out, loop2OutLvl);
+
+  return topoSort();
+}
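
How sort is meant to be driven is not visible in this section; a hedged sketch of a caller (the retry strategy and the SortMask values other than kIncludeDenseInput/kIncludeDenseOutput are assumptions, not taken from this file):

  // Hypothetical caller: relax the sort mask until an acyclic iteration
  // graph is found; an empty AffineMap from sort() signals a cycle.
  AffineMap scheduleLoops(IterationGraphSorter &sorter) {
    // kIncludeAll and kSparseOnly are assumed mask values.
    for (SortMask mask :
         {SortMask::kIncludeAll, SortMask::kIncludeDenseInput,
          SortMask::kIncludeDenseOutput, SortMask::kSparseOnly})
      if (AffineMap order = sorter.sort(mask, /*ignored=*/Value()))
        return order;
    return AffineMap(); // no admissible loop order
  }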
+
+void IterationGraphSorter::addConstraints(Value t, AffineMap loop2LvlMap) {
+  auto addIterOrdering = [this](unsigned f, unsigned t) {
+    if (!itGraph[f][t] && f != t) {
+      itGraph[f][t] = true;
+      inDegree[t]++;
+    }
+  };
+
+  AffineDimFinder finder(iterTypes);
+  finder.setPickedIterType(utils::IteratorType::reduction);
+
+  // To compute the iteration graph for tensor[d0 + d1 + d3, d4 + d5 + d6],
+  // we require there exist d_x \in {d0, d1, d3} and d_y \in {d4, d5, d6},
+  // and d_x > d_y && {d0, d1, d3} - d_x > {d4, d5, d6} - d_y
+  const Level lvlRank = loop2LvlMap.getNumResults();
+  for (Level lvl = 1; lvl < lvlRank; lvl++) {
+    const AffineExpr fa = loop2LvlMap.getResult(lvl - 1);
+    const AffineExpr ta = loop2LvlMap.getResult(lvl);
+
+    if (llvm::isa<AffineDimExpr>(fa) || llvm::isa<AffineDimExpr>(ta)) {
+      // Special case when at least one loop2LvlExp is a simple AffineDimExpr
+      // (say, d0) and we require d0 > {d1, d2, ...} or {d1, d2, ...} > d0
+      AffineDimCollector fCollector;
+      fCollector.walkPostOrder(fa);
+      AffineDimCollector tCollector;
+      tCollector.walkPostOrder(ta);
+
+      for (auto fd : fCollector.dims) {
+        for (auto td : tCollector.dims) {
+          const unsigned f = fd.getPosition();
+          const unsigned t = td.getPosition();
+          addIterOrdering(f, t);
+        }
+      }
+      continue;
+    }
+
+    // When both loop2LvlExprs are compound, we pick an arbitrary reduction
+    // loop from lhs and rhs and use them as d_x and d_y.
+    finder.walkPostOrder(fa);
+    const AffineDimExpr fexp = finder.getDimExpr();
+    const unsigned fldx = fexp.getPosition();
+
+    finder.walkPostOrder(ta);
+    const AffineDimExpr texp = finder.getDimExpr();
+    const unsigned tldx = texp.getPosition();
+
+    // d_x > d_y
+    addIterOrdering(fldx, tldx);
+
+    AffineDimCollector fCollector;
+    fCollector.walkPostOrder(fa);
+    AffineDimCollector tCollector;
+    tCollector.walkPostOrder(ta);
+
+    // Make sure dx and dy are the last.
+    for (auto fd : fCollector.dims) {
+      const unsigned f = fd.getPosition();
+      addIterOrdering(f, fldx);
+    }
+    for (auto td : tCollector.dims) {
+      const unsigned t = td.getPosition();
+      addIterOrdering(t, tldx);
+    }
+    // {d0, d1, d3} - d_x > {d4, d5, d6} - d_y
+    // This is to ensure that the affine expressions are reduced in sparse
+    // tensor level ordering.
+    for (auto fd : fCollector.dims) {
+      const unsigned f = fd.getPosition();
+      if (f == fldx) // skip d_x
+        continue;
+      for (auto td : tCollector.dims) {
+        const unsigned t = td.getPosition();
+        if (t == tldx) // skip d_y
+          continue;
+        addIterOrdering(f, t);
+      }
+    }
+  }
+}
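
A short worked trace of addConstraints (illustrative): for an access tensor[d0 + d1, d2], the pair (fa, ta) = (d0 + d1, d2) hits the special case since ta is a simple AffineDimExpr; the collectors yield {d0, d1} x {d2}, so edges d0 -> d2 and d1 -> d2 are added, forcing both loops of the compound first-level expression to be scheduled before the loop of the second level.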
