Commit f248d0b

[mlir][sparse] implement sparse_tensor.reorder_coo (llvm#68916)
As a side effect, the change also unifies the ConvertOp implementation between the library and codegen paths.
1 parent: 220244b

22 files changed (+265, -1345 lines)

mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
Lines changed: 2 additions & 0 deletions

@@ -151,6 +151,8 @@ enum class Action : uint32_t {
   kToCOO = 5,
   kToIterator = 6,
   kPack = 7,
+  // Sort an unordered COO in place.
+  kSortCOOInPlace = 8,
 };
 
 /// This enum defines all the sparse representations supportable by
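The new enumerator is the value that compiler-generated code hands to the sparse runtime to request the in-place sort implemented further down in SparseTensorStorage::sortInPlace(). As a rough illustration of how such an action value gets serviced, consider the sketch below; everything except the Action enumerators themselves is a hypothetical stand-in, not the actual MLIR runtime dispatch.

// Illustrative sketch only: how a runtime shim might act on the new action
// value. The applyAction helper and its signature are invented for this
// example; the real dispatch lives in the MLIR sparse runtime.
#include <cassert>
#include <cstdint>

enum class Action : uint32_t {
  // ... earlier actions elided ...
  kPack = 7,
  // Sort an unordered COO in place (added by this commit).
  kSortCOOInPlace = 8,
};

template <typename Storage>
Storage *applyAction(Action action, Storage *tensor) {
  switch (action) {
  case Action::kSortCOOInPlace:
    // The sort mutates the existing coordinate/value buffers, so the very
    // same storage object is handed back to the caller.
    tensor->sortInPlace();
    return tensor;
  default:
    assert(false && "all other actions elided in this sketch");
    return nullptr;
  }
}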

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
Lines changed: 0 additions & 4 deletions

@@ -200,10 +200,6 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
     // Whether the convert can be done by a single step (either a sort or a foreach),
     // or it would require a tmp buffer (sort, then foreach).
     bool directConvertable();
-
-    // Whether the convert is actually a sort coo
-    // TODO: The method will be removed when sort_coo operation is introduced.
-    bool isSortCOOConvert();
  }];
 
  let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";

mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
Lines changed: 2 additions & 11 deletions

@@ -88,6 +88,8 @@ struct SparseCompilerOptions
       *this, "enable-buffer-initialization",
       desc("Enable zero-initialization of memory buffers"), init(false)};
 
+  // TODO: Delete the option, it should also be false after switching to
+  // buffer-deallocation-pass
   PassOptions::Option<bool> createSparseDeallocs{
       *this, "create-sparse-deallocs",
       desc("Specify if the temporary buffers created by the sparse "
@@ -100,11 +102,6 @@ struct SparseCompilerOptions
       *this, "vl", desc("Set the vector length (0 disables vectorization)"),
       init(0)};
 
-  // These options must be kept in sync with `SparseTensorConversionBase`.
-  PassOptions::Option<int32_t> sparseToSparse{
-      *this, "s2s-strategy",
-      desc("Set the strategy for sparse-to-sparse conversion"), init(0)};
-
   // These options must be kept in sync with the `ConvertVectorToLLVM`
   // (defined in include/mlir/Dialect/SparseTensor/Pipelines/Passes.h).
   PassOptions::Option<bool> reassociateFPReductions{
@@ -174,12 +171,6 @@ struct SparseCompilerOptions
                                 enableRuntimeLibrary);
   }
 
-  /// Projects out the options for `createSparseTensorConversionPass`.
-  SparseTensorConversionOptions sparseTensorConversionOptions() const {
-    return SparseTensorConversionOptions(
-        sparseToSparseConversionStrategy(sparseToSparse));
-  }
-
   /// Projects out the options for `createConvertVectorToLLVMPass`.
   ConvertVectorToLLVMPassOptions lowerVectorToLLVMOptions() const {
     ConvertVectorToLLVMPassOptions opts{};

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
Lines changed: 2 additions & 29 deletions

@@ -119,37 +119,11 @@ class SparseTensorTypeToPtrConverter : public TypeConverter {
   SparseTensorTypeToPtrConverter();
 };
 
-/// Defines a strategy for implementing sparse-to-sparse conversion.
-/// `kAuto` leaves it up to the compiler to automatically determine
-/// the method used. `kViaCOO` converts the source tensor to COO and
-/// then converts the COO to the target format. `kDirect` converts
-/// directly via the algorithm in <https://arxiv.org/abs/2001.02609>;
-/// however, beware that there are many formats not supported by this
-/// conversion method.
-enum class SparseToSparseConversionStrategy { kAuto, kViaCOO, kDirect };
-
-/// Converts command-line sparse2sparse flag to the strategy enum.
-SparseToSparseConversionStrategy sparseToSparseConversionStrategy(int32_t flag);
-
-/// SparseTensorConversion options.
-struct SparseTensorConversionOptions {
-  SparseTensorConversionOptions(SparseToSparseConversionStrategy s2s)
-      : sparseToSparseStrategy(s2s) {}
-  SparseTensorConversionOptions()
-      : SparseTensorConversionOptions(SparseToSparseConversionStrategy::kAuto) {
-  }
-  SparseToSparseConversionStrategy sparseToSparseStrategy;
-};
-
 /// Sets up sparse tensor conversion rules.
-void populateSparseTensorConversionPatterns(
-    TypeConverter &typeConverter, RewritePatternSet &patterns,
-    const SparseTensorConversionOptions &options =
-        SparseTensorConversionOptions());
+void populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
+                                            RewritePatternSet &patterns);
 
 std::unique_ptr<Pass> createSparseTensorConversionPass();
-std::unique_ptr<Pass>
-createSparseTensorConversionPass(const SparseTensorConversionOptions &options);
 
 //===----------------------------------------------------------------------===//
 // The SparseTensorCodegen pass.
@@ -235,7 +209,6 @@ std::unique_ptr<Pass> createSparsificationAndBufferizationPass();
 std::unique_ptr<Pass> createSparsificationAndBufferizationPass(
     const bufferization::OneShotBufferizationOptions &bufferizationOptions,
     const SparsificationOptions &sparsificationOptions,
-    const SparseTensorConversionOptions &sparseTensorConversionOptions,
     bool createSparseDeallocs, bool enableRuntimeLibrary,
     bool enableBufferInitialization, unsigned vectorLength,
     bool enableVLAVectorization, bool enableSIMDIndex32);
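With the sparse-to-sparse strategy machinery gone, wiring up the conversion needs nothing beyond a type converter and a pattern set. The sketch below shows a driver built on the simplified entry point; only SparseTensorTypeToPtrConverter and populateSparseTensorConversionPatterns are taken from this header, while the surrounding conversion boilerplate (target legality, partial conversion) is generic MLIR usage written here as an assumption, not a copy of the actual pass.

// Hedged sketch of a caller of the simplified API; the legality choices are
// illustrative and differ from the real SparseTensorConversionPass setup.
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

LogicalResult convertSparseTensorsToRuntimeCalls(ModuleOp module) {
  MLIRContext *ctx = module.getContext();
  // Declared earlier in this header: maps sparse tensor types to opaque pointers.
  SparseTensorTypeToPtrConverter converter;

  RewritePatternSet patterns(ctx);
  // New option-free signature introduced by this commit.
  populateSparseTensorConversionPatterns(converter, patterns);

  ConversionTarget target(*ctx);
  // Require all sparse_tensor ops to be rewritten; partial conversion leaves
  // unlisted ops untouched.
  target.addIllegalDialect<SparseTensorDialect>();
  return applyPartialConversion(module, target, std::move(patterns));
}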

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
Lines changed: 0 additions & 4 deletions

@@ -201,10 +201,6 @@ def SparseTensorConversionPass : Pass<"sparse-tensor-conversion", "ModuleOp"> {
     "scf::SCFDialect",
     "sparse_tensor::SparseTensorDialect",
   ];
-  let options = [
-    Option<"sparseToSparse", "s2s-strategy", "int32_t", "0",
-           "Set the strategy for sparse-to-sparse conversion">,
-  ];
 }
 
 def SparseTensorCodegen : Pass<"sparse-tensor-codegen", "ModuleOp"> {

mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
Lines changed: 70 additions & 0 deletions

@@ -374,6 +374,19 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   /// Partially specialize lexicographical insertions based on template types.
   void lexInsert(const uint64_t *lvlCoords, V val) final {
     assert(lvlCoords && "Received nullptr for level-coordinates");
+    // TODO: get rid of this! canonicalize all-dense "sparse" array into dense
+    // tensors.
+    bool allDense = std::all_of(getLvlTypes().begin(), getLvlTypes().end(),
+                                [](DimLevelType lt) { return isDenseDLT(lt); });
+    if (allDense) {
+      uint64_t lvlRank = getLvlRank();
+      uint64_t valIdx = 0;
+      // Linearize the address
+      for (size_t lvl = 0; lvl < lvlRank; lvl++)
+        valIdx = valIdx * getLvlSize(lvl) + lvlCoords[lvl];
+      values[valIdx] = val;
+      return;
+    }
     // First, wrap up pending insertion path.
     uint64_t diffLvl = 0;
     uint64_t full = 0;
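For an all-dense level layout the insertion degenerates to a plain dense store: the loop folds the level coordinates into one offset with the usual row-major recurrence valIdx = valIdx * size(lvl) + coord(lvl). A tiny standalone check of that arithmetic is shown below; the sizes and coordinates are invented example data, independent of the storage class.

// Standalone illustration of the row-major linearization used by the new
// all-dense fast path in lexInsert(); the data here is made up.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<uint64_t> lvlSizes = {2, 3, 4};  // an all-dense 2x3x4 layout
  std::vector<uint64_t> lvlCoords = {1, 2, 3}; // the last element
  uint64_t valIdx = 0;
  // Same recurrence as in lexInsert(): fold coordinates level by level.
  for (size_t lvl = 0; lvl < lvlSizes.size(); lvl++)
    valIdx = valIdx * lvlSizes[lvl] + lvlCoords[lvl];
  // 1*(3*4) + 2*4 + 3 = 23, i.e. the last slot of the values array.
  std::printf("linearized index = %llu\n",
              static_cast<unsigned long long>(valIdx));
  return 0;
}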
@@ -457,6 +470,63 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
     return coo;
   }
 
+  /// Sort the unordered tensor in place, the method assumes that it is
+  /// an unordered COO tensor.
+  void sortInPlace() {
+    uint64_t nnz = values.size();
+#ifndef NDEBUG
+    for (uint64_t l = 0; l < getLvlRank(); l++)
+      assert(nnz == coordinates[l].size());
+#endif
+
+    // In-place permutation.
+    auto applyPerm = [this](std::vector<uint64_t> &perm) {
+      size_t length = perm.size();
+      size_t lvlRank = getLvlRank();
+      // Cache for the current level coordinates.
+      std::vector<P> lvlCrds(lvlRank);
+      for (size_t i = 0; i < length; i++) {
+        size_t current = i;
+        if (i != perm[current]) {
+          for (size_t l = 0; l < lvlRank; l++)
+            lvlCrds[l] = coordinates[l][i];
+          V val = values[i];
+          // Deals with a permutation cycle.
+          while (i != perm[current]) {
+            size_t next = perm[current];
+            // Swaps the level coordinates and value.
+            for (size_t l = 0; l < lvlRank; l++)
+              coordinates[l][current] = coordinates[l][next];
+            values[current] = values[next];
+            perm[current] = current;
+            current = next;
+          }
+          for (size_t l = 0; l < lvlRank; l++)
+            coordinates[l][current] = lvlCrds[l];
+          values[current] = val;
+          perm[current] = current;
+        }
+      }
+    };
+
+    std::vector<uint64_t> sortedIdx(nnz, 0);
+    for (uint64_t i = 0; i < nnz; i++)
+      sortedIdx[i] = i;
+
+    std::sort(sortedIdx.begin(), sortedIdx.end(),
+              [this](uint64_t lhs, uint64_t rhs) {
+                for (uint64_t l = 0; l < getLvlRank(); l++) {
+                  if (coordinates[l][lhs] == coordinates[l][rhs])
+                    continue;
+                  return coordinates[l][lhs] < coordinates[l][rhs];
+                }
+                assert(false && "duplicate coordinates");
+                return false;
+              });
+
+    applyPerm(sortedIdx);
+  }
+
 private:
   /// Appends an arbitrary new position to `positions[lvl]`. This method
   /// checks that `pos` is representable in the `P` type; however, it
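The new sortInPlace() avoids materializing a second COO buffer: it sorts an index array with a lexicographic comparator over the level coordinates and then applies that permutation by walking its cycles, so every coordinate tuple and value is written exactly once and only a single tuple is cached at a time. The self-contained program below demonstrates the same cycle-following technique on a toy two-level COO; it mirrors the structure of the committed method but uses invented data and plain std::vectors rather than the MLIR storage class.

// Self-contained sketch of the technique behind sortInPlace(): sort an index
// array lexicographically over the COO coordinates, then apply the resulting
// permutation in place by walking its cycles.  The data below is made up.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Unordered 2-D COO: coordinates per level plus the values (3 nonzeros).
  std::vector<std::vector<uint64_t>> coordinates = {{1, 0, 1}, {0, 2, 1}};
  std::vector<double> values = {3.0, 1.0, 2.0};
  const size_t lvlRank = coordinates.size();
  const uint64_t nnz = values.size();

  // Step 1: sort indices so that perm[i] is the source position of the i-th
  // smallest coordinate tuple (gather semantics, as in the committed code).
  std::vector<uint64_t> perm(nnz);
  for (uint64_t i = 0; i < nnz; i++)
    perm[i] = i;
  std::sort(perm.begin(), perm.end(), [&](uint64_t lhs, uint64_t rhs) {
    for (size_t l = 0; l < lvlRank; l++) {
      if (coordinates[l][lhs] == coordinates[l][rhs])
        continue;
      return coordinates[l][lhs] < coordinates[l][rhs];
    }
    return false; // equal tuples keep their relative order
  });

  // Step 2: apply the permutation in place by following cycles; each entry
  // is written exactly once and only one tuple is cached at a time.
  std::vector<uint64_t> lvlCrds(lvlRank);
  for (uint64_t i = 0; i < nnz; i++) {
    uint64_t current = i;
    if (i != perm[current]) {
      for (size_t l = 0; l < lvlRank; l++)
        lvlCrds[l] = coordinates[l][i];
      double val = values[i];
      while (i != perm[current]) {
        uint64_t next = perm[current];
        for (size_t l = 0; l < lvlRank; l++)
          coordinates[l][current] = coordinates[l][next];
        values[current] = values[next];
        perm[current] = current; // mark slot as settled
        current = next;
      }
      for (size_t l = 0; l < lvlRank; l++)
        coordinates[l][current] = lvlCrds[l];
      values[current] = val;
      perm[current] = current;
    }
  }

  // Expected order: (0,2)->1.0, (1,0)->3.0, (1,1)->2.0.
  for (uint64_t i = 0; i < nnz; i++)
    std::printf("(%llu, %llu) -> %g\n",
                (unsigned long long)coordinates[0][i],
                (unsigned long long)coordinates[1][i], values[i]);
  return 0;
}

One subtlety carried over from the committed comparator: it asserts on exact duplicates, since a unique COO tensor should never contain two identical coordinate tuples.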

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Lines changed: 1 addition & 18 deletions

@@ -1060,20 +1060,12 @@ LogicalResult ConvertOp::verify() {
 }
 
 OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
-  Type dstType = getType();
-  // Fold trivial dense-to-dense convert and leave trivial sparse-to-sparse
-  // convert for codegen to remove. This is because we use trivial
-  // sparse-to-sparse convert to tell bufferization that the sparse codegen
-  // will expand the tensor buffer into sparse tensor storage.
-  if (!getSparseTensorEncoding(dstType) && dstType == getSource().getType())
+  if (getType() == getSource().getType())
     return getSource();
   return {};
 }
 
 bool ConvertOp::directConvertable() {
-  if (isSortCOOConvert())
-    return false;
-
   SparseTensorType srcStt = getSparseTensorType(getSource());
   SparseTensorType dstStt = getSparseTensorType(getDest());
 
@@ -1099,15 +1091,6 @@ bool ConvertOp::directConvertable() {
   return false;
 }
 
-bool ConvertOp::isSortCOOConvert() {
-  // TODO: we should instead use a different sort_coo operation to handle
-  // the conversion between COOs (but with different ordering).
-  return isUniqueCOOType(getSource().getType()) &&
-         isUniqueCOOType(getDest().getType()) &&
-         !getSparseTensorType(getSource()).isAllOrdered() &&
-         getSparseTensorType(getDest()).isAllOrdered();
-}
-
 LogicalResult ToPositionsOp::verify() {
   auto e = getSparseTensorEncoding(getTensor().getType());
   if (failed(lvlIsInBounds(getLevel(), getTensor())))

mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
Lines changed: 3 additions & 3 deletions

@@ -35,9 +35,9 @@ void mlir::sparse_tensor::buildSparseCompiler(
   pm.addPass(createSparsificationAndBufferizationPass(
       getBufferizationOptionsForSparsification(
           options.testBufferizationAnalysisOnly),
-      options.sparsificationOptions(), options.sparseTensorConversionOptions(),
-      options.createSparseDeallocs, options.enableRuntimeLibrary,
-      options.enableBufferInitialization, options.vectorLength,
+      options.sparsificationOptions(), options.createSparseDeallocs,
+      options.enableRuntimeLibrary, options.enableBufferInitialization,
+      options.vectorLength,
       /*enableVLAVectorization=*/options.armSVE,
       /*enableSIMDIndex32=*/options.force32BitVectorIndices));
   if (options.testBufferizationAnalysisOnly)

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
Lines changed: 10 additions & 19 deletions

@@ -680,31 +680,26 @@ class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
 };
 
 // TODO: use a new SortCOO operation here instead of reusing convert op.
-struct SparseSortCOOConverter : public OpConversionPattern<ConvertOp> {
+struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ConvertOp op, ConvertOpAdaptor adaptor,
+  matchAndRewrite(ReorderCOOOp op, ReorderCOOOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // Direct conversion should have already been lowered.
-    if (!op.isSortCOOConvert())
-      return failure();
-
     Location loc = op.getLoc();
     MLIRContext *ctx = op.getContext();
 
-    SparseTensorType srcStt = getSparseTensorType(op.getSource());
-    SparseTensorType dstStt = getSparseTensorType(op.getDest());
+    SparseTensorType srcStt = getSparseTensorType(op.getInputCoo());
+    SparseTensorType dstStt = getSparseTensorType(op.getResultCoo());
 
-    // TODO: This should be verification rules for sort_coo operation.
+    // Should have been verified.
     assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
            isUniqueCOOType(srcStt.getRankedTensorType()) &&
           isUniqueCOOType(dstStt.getRankedTensorType()));
-
     assert(dstStt.hasSameDimToLvl(srcStt));
 
     // We don't need a mutable descriptor here as we perform sorting in-place.
-    auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getSource());
-    auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
+    auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getInputCoo());
+    auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo());
     auto crd = desc.getAOSMemRef();
     auto val = desc.getValMemRef();
 
@@ -715,12 +710,11 @@ struct SparseSortCOOConverter : public OpConversionPattern<ConvertOp> {
     auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);
 
     rewriter.create<SortOp>(loc, nnz, crd, ValueRange{val}, id,
-                            rewriter.getIndexAttr(0),
-                            SparseTensorSortKind::HybridQuickSort);
+                            rewriter.getIndexAttr(0), op.getAlgorithm());
 
     // Since we do in-place sorting, the destinate tensor will have the same set
    // of memrefs as the source tensor.
-    rewriter.replaceOp(op, adaptor.getSource());
+    rewriter.replaceOp(op, adaptor.getInputCoo());
     return success();
   }
 };
@@ -1147,9 +1141,6 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
   LogicalResult
   matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (op.isSortCOOConvert())
-      return failure();
-
     SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
     SparseTensorEncodingAttr encSrc =
         getSparseTensorEncoding(op.getSource().getType());
@@ -1603,7 +1594,7 @@ void mlir::populateSparseTensorCodegenPatterns(
                SparseCastConverter, SparseExtractSliceConverter,
                SparseTensorLoadConverter, SparseExpandConverter,
                SparseCompressConverter, SparseInsertConverter,
-               SparseSortCOOConverter,
+               SparseReorderCOOConverter,
                SparseSliceGetterOpConverter<ToSliceOffsetOp,
                                             StorageSpecifierKind::DimOffset>,
                SparseSliceGetterOpConverter<ToSliceStrideOp,
