Skip to content

Commit e568d00

Browse files
committed
[mlir][sparse] minor code layout edits
Reviewed By: bixia Differential Revision: https://reviews.llvm.org/D140934
1 parent 4bbcbda commit e568d00

File tree

3 files changed

+52
-4
lines changed

3 files changed

+52
-4
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@
1414
using namespace mlir;
1515
using namespace sparse_tensor;
1616

17+
namespace {
18+
19+
//===----------------------------------------------------------------------===//
20+
// Helper methods.
21+
//===----------------------------------------------------------------------===//
22+
1723
static SmallVector<Type, 2> getSpecifierFields(StorageSpecifierType tp) {
1824
MLIRContext *ctx = tp.getContext();
1925
auto enc = tp.getEncoding();
@@ -34,10 +40,9 @@ static Type convertSpecifier(StorageSpecifierType tp) {
3440
getSpecifierFields(tp));
3541
}
3642

37-
StorageSpecifierToLLVMTypeConverter::StorageSpecifierToLLVMTypeConverter() {
38-
addConversion([](Type type) { return type; });
39-
addConversion([](StorageSpecifierType tp) { return convertSpecifier(tp); });
40-
}
43+
//===----------------------------------------------------------------------===//
44+
// Specifier struct builder.
45+
//===----------------------------------------------------------------------===//
4146

4247
constexpr uint64_t kDimSizePosInSpecifier = 0;
4348
constexpr uint64_t kMemSizePosInSpecifier = 1;
@@ -102,6 +107,21 @@ void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc,
102107
loc, value, size, ArrayRef<int64_t>({kMemSizePosInSpecifier, pos}));
103108
}
104109

110+
} // namespace
111+
112+
//===----------------------------------------------------------------------===//
113+
// The sparse storage specifier type converter (defined in Passes.h).
114+
//===----------------------------------------------------------------------===//
115+
116+
StorageSpecifierToLLVMTypeConverter::StorageSpecifierToLLVMTypeConverter() {
117+
addConversion([](Type type) { return type; });
118+
addConversion([](StorageSpecifierType tp) { return convertSpecifier(tp); });
119+
}
120+
121+
//===----------------------------------------------------------------------===//
122+
// Storage specifier conversion rules.
123+
//===----------------------------------------------------------------------===//
124+
105125
template <typename Base, typename SourceOp>
106126
class SpecifierGetterSetterOpConverter : public OpConversionPattern<SourceOp> {
107127
public:
@@ -176,6 +196,10 @@ struct StorageSpecifierInitOpConverter
176196
}
177197
};
178198

199+
//===----------------------------------------------------------------------===//
200+
// Public method for populating conversion rules.
201+
//===----------------------------------------------------------------------===//
202+
179203
void mlir::populateStorageSpecifierToLLVMPatterns(TypeConverter &converter,
180204
RewritePatternSet &patterns) {
181205
patterns.add<StorageSpecifierGetOpConverter, StorageSpecifierSetOpConverter,

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,22 @@
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
8+
89
#include "SparseTensorStorageLayout.h"
910
#include "CodegenUtils.h"
1011

1112
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1213
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
14+
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
1315
#include "mlir/Transforms/DialectConversion.h"
1416

1517
using namespace mlir;
1618
using namespace sparse_tensor;
1719

20+
//===----------------------------------------------------------------------===//
21+
// Private helper methods.
22+
//===----------------------------------------------------------------------===//
23+
1824
static Value createIndexCast(OpBuilder &builder, Location loc, Value value,
1925
Type to) {
2026
if (value.getType() != to)
@@ -47,6 +53,10 @@ convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl<Type> &fields) {
4753
return success();
4854
}
4955

56+
//===----------------------------------------------------------------------===//
57+
// The sparse tensor type converter (defined in Passes.h).
58+
//===----------------------------------------------------------------------===//
59+
5060
SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
5161
addConversion([](Type type) { return type; });
5262
addConversion([&](RankedTensorType rtp, SmallVectorImpl<Type> &fields) {
@@ -65,6 +75,10 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
6575
});
6676
}
6777

78+
//===----------------------------------------------------------------------===//
79+
// StorageLayout methods.
80+
//===----------------------------------------------------------------------===//
81+
6882
unsigned StorageLayout::getMemRefFieldIndex(SparseTensorFieldKind kind,
6983
std::optional<unsigned> dim) const {
7084
unsigned fieldIdx = -1u;
@@ -89,6 +103,10 @@ unsigned StorageLayout::getMemRefFieldIndex(StorageSpecifierKind kind,
89103
return getMemRefFieldIndex(toFieldKind(kind), dim);
90104
}
91105

106+
//===----------------------------------------------------------------------===//
107+
// StorageTensorSpecifier methods.
108+
//===----------------------------------------------------------------------===//
109+
92110
Value SparseTensorSpecifier::getInitValue(OpBuilder &builder, Location loc,
93111
RankedTensorType rtp) {
94112
return builder.create<StorageSpecifierInitOp>(
@@ -114,6 +132,10 @@ void SparseTensorSpecifier::setSpecifierField(OpBuilder &builder, Location loc,
114132
createIndexCast(builder, loc, v, getFieldType(kind, dim)));
115133
}
116134

135+
//===----------------------------------------------------------------------===//
136+
// Public methods.
137+
//===----------------------------------------------------------------------===//
138+
117139
constexpr uint64_t kDataFieldStartingIdx = 0;
118140

119141
void sparse_tensor::foreachFieldInSparseTensor(

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ namespace sparse_tensor {
5050
// };
5151
//
5252
//===----------------------------------------------------------------------===//
53+
5354
enum class SparseTensorFieldKind : uint32_t {
5455
StorageSpec = 0,
5556
PtrMemRef = 1,
@@ -355,4 +356,5 @@ getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) {
355356

356357
} // namespace sparse_tensor
357358
} // namespace mlir
359+
358360
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_

0 commit comments

Comments
 (0)