5
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
+
8
9
#include " SparseTensorStorageLayout.h"
9
10
#include " CodegenUtils.h"
10
11
11
12
#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
12
13
#include " mlir/Dialect/SparseTensor/IR/SparseTensor.h"
14
+ #include " mlir/Dialect/SparseTensor/Transforms/Passes.h"
13
15
#include " mlir/Transforms/DialectConversion.h"
14
16
15
17
using namespace mlir ;
16
18
using namespace sparse_tensor ;
17
19
20
+ // ===----------------------------------------------------------------------===//
21
+ // Private helper methods.
22
+ // ===----------------------------------------------------------------------===//
23
+
18
24
static Value createIndexCast (OpBuilder &builder, Location loc, Value value,
19
25
Type to) {
20
26
if (value.getType () != to)
@@ -47,6 +53,10 @@ convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl<Type> &fields) {
47
53
return success ();
48
54
}
49
55
56
+ // ===----------------------------------------------------------------------===//
57
+ // The sparse tensor type converter (defined in Passes.h).
58
+ // ===----------------------------------------------------------------------===//
59
+
50
60
SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter () {
51
61
addConversion ([](Type type) { return type; });
52
62
addConversion ([&](RankedTensorType rtp, SmallVectorImpl<Type> &fields) {
@@ -65,6 +75,10 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
65
75
});
66
76
}
67
77
78
+ // ===----------------------------------------------------------------------===//
79
+ // StorageLayout methods.
80
+ // ===----------------------------------------------------------------------===//
81
+
68
82
unsigned StorageLayout::getMemRefFieldIndex (SparseTensorFieldKind kind,
69
83
std::optional<unsigned > dim) const {
70
84
unsigned fieldIdx = -1u ;
@@ -89,6 +103,10 @@ unsigned StorageLayout::getMemRefFieldIndex(StorageSpecifierKind kind,
89
103
return getMemRefFieldIndex (toFieldKind (kind), dim);
90
104
}
91
105
106
+ // ===----------------------------------------------------------------------===//
107
+ // StorageTensorSpecifier methods.
108
+ // ===----------------------------------------------------------------------===//
109
+
92
110
Value SparseTensorSpecifier::getInitValue (OpBuilder &builder, Location loc,
93
111
RankedTensorType rtp) {
94
112
return builder.create <StorageSpecifierInitOp>(
@@ -114,6 +132,10 @@ void SparseTensorSpecifier::setSpecifierField(OpBuilder &builder, Location loc,
114
132
createIndexCast (builder, loc, v, getFieldType (kind, dim)));
115
133
}
116
134
135
+ // ===----------------------------------------------------------------------===//
136
+ // Public methods.
137
+ // ===----------------------------------------------------------------------===//
138
+
117
139
constexpr uint64_t kDataFieldStartingIdx = 0 ;
118
140
119
141
void sparse_tensor::foreachFieldInSparseTensor (
0 commit comments