
Commit bb44a6b

[mlir][sparse] migrate more to new surface syntax

Replaced the "NEW_SYNTAX" keyword with the more readable "map" (which we may or may not keep). Minor improvement in keyword parsing; migrated a few more examples over.

Reviewed By: Peiming, yinying-lisa-li

Differential Revision: https://reviews.llvm.org/D158325
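To make the change concrete, here is the same CSR encoding written both ways; both forms are taken from the round-trip test updates in this commit and shown side by side only for comparison:

// Old surface syntax: separate lvlTypes and dimToLvl keys.
#CSR = #sparse_tensor.encoding<{
  lvlTypes = [ "dense", "compressed" ],
  dimToLvl = affine_map<(i,j) -> (i,j)>
}>

// New surface syntax: a single map key carries both.
#CSR = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : dense, d1 : compressed)
}>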
1 parent 58fe7b7

3 files changed: +57 additions, -95 deletions

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 29 additions & 26 deletions
@@ -451,24 +451,22 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
   AffineMap dimToLvl = {};
   unsigned posWidth = 0;
   unsigned crdWidth = 0;
-
   StringRef attrName;
-  // Exactly 6 keys.
   SmallVector<StringRef, 6> keys = {"lvlTypes", "dimToLvl", "posWidth",
-                                    "crdWidth", "dimSlices", "NEW_SYNTAX"};
+                                    "crdWidth", "dimSlices", "map"};
   while (succeeded(parser.parseOptionalKeyword(&attrName))) {
-    if (!llvm::is_contained(keys, attrName)) {
+    // Detect admissible keyword.
+    auto *it = find(keys, attrName);
+    if (it == keys.end()) {
       parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
       return {};
     }
-
+    unsigned keyWordIndex = it - keys.begin();
     // Consume the `=` after keys
     RETURN_ON_FAIL(parser.parseEqual())
-    // FIXME: using `operator==` below duplicates the string comparison
-    // cost of the `is_contained` check above. Should instead use some
-    // "find" function that returns the index into `keys` so that we can
-    // dispatch on that instead.
-    if (attrName == "lvlTypes") {
+    // Dispatch on keyword.
+    switch (keyWordIndex) {
+    case 0: { // lvlTypes
       Attribute attr;
       RETURN_ON_FAIL(parser.parseAttribute(attr));
       auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attr);
@@ -485,25 +483,33 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
           return {};
         }
       }
-    } else if (attrName == "dimToLvl") {
+      break;
+    }
+    case 1: { // dimToLvl
       Attribute attr;
       RETURN_ON_FAIL(parser.parseAttribute(attr))
       auto affineAttr = llvm::dyn_cast<AffineMapAttr>(attr);
       ERROR_IF(!affineAttr, "expected an affine map for dimToLvl")
       dimToLvl = affineAttr.getValue();
-    } else if (attrName == "posWidth") {
+      break;
+    }
+    case 2: { // posWidth
       Attribute attr;
       RETURN_ON_FAIL(parser.parseAttribute(attr))
       auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
       ERROR_IF(!intAttr, "expected an integral position bitwidth")
       posWidth = intAttr.getInt();
-    } else if (attrName == "crdWidth") {
+      break;
+    }
+    case 3: { // crdWidth
       Attribute attr;
       RETURN_ON_FAIL(parser.parseAttribute(attr))
       auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
       ERROR_IF(!intAttr, "expected an integral index bitwidth")
       crdWidth = intAttr.getInt();
-    } else if (attrName == "dimSlices") {
+      break;
+    }
+    case 4: { // dimSlices
       RETURN_ON_FAIL(parser.parseLSquare())
       // Dispatches to DimSliceAttr to skip mnemonic
       bool finished = false;
@@ -519,26 +525,22 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
       if (!finished)
         return {};
       RETURN_ON_FAIL(parser.parseRSquare())
-    } else if (attrName == "NEW_SYNTAX") {
-      // Note that we are in the process of migrating to a new STEA surface
-      // syntax. While this is ongoing we use the temporary "NEW_SYNTAX = ...."
-      // to switch to the new parser. This allows us to gradually migrate
-      // examples over to the new surface syntax before making the complete
-      // switch once work is completed.
-      // TODO: replace everything here with new STEA surface syntax parser
+      break;
+    }
+    case 5: { // map (new STEA surface syntax)
       ir_detail::DimLvlMapParser cParser(parser);
       auto res = cParser.parseDimLvlMap();
       RETURN_ON_FAIL(res);
       // TODO: use DimLvlMap directly as storage representation, rather
       // than converting things over.
       const auto &dlm = *res;
 
-      ERROR_IF(!lvlTypes.empty(), "Cannot mix `lvlTypes` with `NEW_SYNTAX`")
+      ERROR_IF(!lvlTypes.empty(), "Cannot mix `lvlTypes` with `map`")
       const Level lvlRank = dlm.getLvlRank();
       for (Level lvl = 0; lvl < lvlRank; lvl++)
         lvlTypes.push_back(dlm.getLvlType(lvl));
 
-      ERROR_IF(!dimSlices.empty(), "Cannot mix `dimSlices` with `NEW_SYNTAX`")
+      ERROR_IF(!dimSlices.empty(), "Cannot mix `dimSlices` with `map`")
       const Dimension dimRank = dlm.getDimRank();
       for (Dimension dim = 0; dim < dimRank; dim++)
         dimSlices.push_back(dlm.getDimSlice(dim));
@@ -558,11 +560,12 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
         dimSlices.clear();
       }
 
-      ERROR_IF(dimToLvl, "Cannot mix `dimToLvl` with `NEW_SYNTAX`")
+      ERROR_IF(dimToLvl, "Cannot mix `dimToLvl` with `map`")
       dimToLvl = dlm.getDimToLvlMap(parser.getContext());
+      break;
     }
-
-    // Only the last item can omit the comma
+    } // switch
+    // Only last item can omit the comma.
     if (parser.parseOptionalComma().failed())
       break;
   }
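The new case 5 above also enforces that the map key cannot be combined with the legacy keys it subsumes. A minimal sketch of an encoding the parser now rejects (the name #Illegal is hypothetical, for illustration only):

// Rejected with: Cannot mix `lvlTypes` with `map`.
#Illegal = #sparse_tensor.encoding<{
  lvlTypes = [ "dense", "compressed" ],
  map = (d0, d1) -> (d0 : dense, d1 : compressed)
}>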

mlir/test/Dialect/SparseTensor/invalid_encoding.mlir

Lines changed: 6 additions & 21 deletions
@@ -70,47 +70,32 @@ func.func private @tensor_invalid_key(%arg0: tensor<16x32xf32, #a>) -> ()
 }>
 func.func private @sparse_slice(tensor<?x?xf64, #CSR_SLICE>)
 
-///////////////////////////////////////////////////////////////////////////////
-// Migration plan for new STEA surface syntax,
-// use the NEW_SYNTAX on selected examples
-// and then TODO: remove when fully migrated
-///////////////////////////////////////////////////////////////////////////////
-
 // -----
 
-// expected-error@+3 {{Level-rank mismatch between forward-declarations and specifiers. Declared 3 level-variables; but got 2 level-specifiers.}}
+// expected-error@+2 {{Level-rank mismatch between forward-declarations and specifiers. Declared 3 level-variables; but got 2 level-specifiers.}}
 #TooManyLvlDecl = #sparse_tensor.encoding<{
-  NEW_SYNTAX =
-  {l0, l1, l2} (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed)
+  map = {l0, l1, l2} (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed)
 }>
 func.func private @too_many_lvl_decl(%arg0: tensor<?x?xf64, #TooManyLvlDecl>) {
   return
 }
 
 // -----
 
-// NOTE: We don't get the "level-rank mismatch" error here, because this
-// "undeclared identifier" error occurs first. The error message is a bit
-// misleading because `parseLvlVarBinding` calls `parseVarUsage` rather
-// than `parseVarBinding` (and the error message generated by `parseVar`
-// is assuming that `parseVarUsage` is only called for *uses* of variables).
-// expected-error@+3 {{use of undeclared identifier 'l1'}}
+// expected-error@+2 {{use of undeclared identifier 'l1'}}
 #TooFewLvlDecl = #sparse_tensor.encoding<{
-  NEW_SYNTAX =
-  {l0} (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed)
+  map = {l0} (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed)
 }>
 func.func private @too_few_lvl_decl(%arg0: tensor<?x?xf64, #TooFewLvlDecl>) {
   return
 }
 
 // -----
 
-// expected-error@+3 {{Level-variable ordering mismatch. The variable 'l0' was forward-declared as the 1st level; but is bound by the 0th specification.}}
+// expected-error@+2 {{Level-variable ordering mismatch. The variable 'l0' was forward-declared as the 1st level; but is bound by the 0th specification.}}
 #WrongOrderLvlDecl = #sparse_tensor.encoding<{
-  NEW_SYNTAX =
-  {l1, l0} (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed)
+  map = {l1, l0} (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed)
 }>
 func.func private @wrong_order_lvl_decl(%arg0: tensor<?x?xf64, #WrongOrderLvlDecl>) {
   return
 }
-
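For contrast with the three rejected forms above, a well-formed forward declaration binds exactly the level variables it declares, in declaration order; this sketch mirrors the #CSR_explicit round-trip test added in this commit (the name #OkLvlDecl is hypothetical):

// Accepted: two level-variables declared, two bound, in order.
#OkLvlDecl = #sparse_tensor.encoding<{
  map = {l0, l1} (d0 = l0, d1 = l1) -> (l0 = d0 : dense, l1 = d1 : compressed)
}>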

mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir

Lines changed: 22 additions & 48 deletions
@@ -2,13 +2,12 @@
 
 // CHECK-LABEL: func private @sparse_1d_tensor(
 // CHECK-SAME: tensor<32xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>>)
-func.func private @sparse_1d_tensor(tensor<32xf64, #sparse_tensor.encoding<{ lvlTypes = ["compressed"] }>>)
+func.func private @sparse_1d_tensor(tensor<32xf64, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>>)
 
 // -----
 
 #CSR = #sparse_tensor.encoding<{
-  lvlTypes = [ "dense", "compressed" ],
-  dimToLvl = affine_map<(i,j) -> (i,j)>,
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
   posWidth = 64,
   crdWidth = 64
 }>
@@ -19,9 +18,20 @@ func.func private @sparse_csr(tensor<?x?xf32, #CSR>)
 
 // -----
 
+#CSR_explicit = #sparse_tensor.encoding<{
+  map = {l0, l1} (d0 = l0, d1 = l1) -> (l0 = d0 : dense, l1 = d1 : compressed)
+}>
+
+// CHECK-LABEL: func private @CSR_explicit(
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>>
+func.func private @CSR_explicit(%arg0: tensor<?x?xf64, #CSR_explicit>) {
+  return
+}
+
+// -----
+
 #CSC = #sparse_tensor.encoding<{
-  lvlTypes = [ "dense", "compressed" ],
-  dimToLvl = affine_map<(i,j) -> (j,i)>,
+  map = (d0, d1) -> (d1 : dense, d0 : compressed),
   posWidth = 0,
   crdWidth = 0
 }>
@@ -33,8 +43,7 @@ func.func private @sparse_csc(tensor<?x?xf32, #CSC>)
 // -----
 
 #DCSC = #sparse_tensor.encoding<{
-  lvlTypes = [ "compressed", "compressed" ],
-  dimToLvl = affine_map<(i,j) -> (j,i)>,
+  map = (d0, d1) -> (d1 : compressed, d0 : compressed),
   posWidth = 0,
   crdWidth = 64
 }>
@@ -129,7 +138,6 @@ func.func private @sparse_slice(tensor<?x?xf64, #CSR_SLICE>)
 // CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ], dimSlices = [ (1, ?, 1), (?, 4, 2) ] }>>
 func.func private @sparse_slice(tensor<?x?xf64, #CSR_SLICE>)
 
-
 // -----
 
 // TODO: It is probably better to use [dense, dense, 2:4] (see NV_24 defined using new syntax
@@ -143,60 +151,27 @@ func.func private @sparse_slice(tensor<?x?xf64, #CSR_SLICE>)
 // CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed24" ] }>>
 func.func private @sparse_2_out_of_4(tensor<?x?xf64, #NV_24>)
 
-///////////////////////////////////////////////////////////////////////////////
-// Migration plan for new STEA surface syntax,
-// use the NEW_SYNTAX on selected examples
-// and then TODO: remove when fully migrated
-///////////////////////////////////////////////////////////////////////////////
-
 // -----
 
-#CSR_implicit = #sparse_tensor.encoding<{
-  NEW_SYNTAX =
-  (d0, d1) -> (d0 : dense, d1 : compressed)
-}>
-
-// CHECK-LABEL: func private @CSR_implicit(
-// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>>
-func.func private @CSR_implicit(%arg0: tensor<?x?xf64, #CSR_implicit>) {
-  return
-}
-
-// -----
-
-#CSR_explicit = #sparse_tensor.encoding<{
-  NEW_SYNTAX =
-  {l0, l1} (d0 = l0, d1 = l1) -> (l0 = d0 : dense, l1 = d1 : compressed)
-}>
-
-// CHECK-LABEL: func private @CSR_explicit(
-// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>>
-func.func private @CSR_explicit(%arg0: tensor<?x?xf64, #CSR_explicit>) {
-  return
-}
-
-// -----
-
-#BCSR_implicit = #sparse_tensor.encoding<{
-  NEW_SYNTAX =
-  ( i, j ) ->
+#BCSR = #sparse_tensor.encoding<{
+  map = ( i, j ) ->
   ( i floordiv 2 : compressed,
     j floordiv 3 : compressed,
     i mod 2 : dense,
     j mod 3 : dense
   )
 }>
 
-// CHECK-LABEL: func private @BCSR_implicit(
+// CHECK-LABEL: func private @BCSR(
 // CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed", "dense", "dense" ], dimToLvl = affine_map<(d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)> }>>
-func.func private @BCSR_implicit(%arg0: tensor<?x?xf64, #BCSR_implicit>) {
+func.func private @BCSR(%arg0: tensor<?x?xf64, #BCSR>) {
   return
 }
 
 // -----
 
 #BCSR_explicit = #sparse_tensor.encoding<{
-  NEW_SYNTAX =
+  map =
   {il, jl, ii, jj}
   ( i = il * 2 + ii,
     j = jl * 3 + jj
@@ -217,8 +192,7 @@ func.func private @BCSR_explicit(%arg0: tensor<?x?xf64, #BCSR_explicit>) {
 // -----
 
 #NV_24 = #sparse_tensor.encoding<{
-  NEW_SYNTAX =
-  ( i, j ) ->
+  map = ( i, j ) ->
   ( i : dense,
     j floordiv 4 : dense,
     j mod 4 : compressed24
