Skip to content

Commit 96064e1

Browse files
authored
[mlir][tosa] Add table size check for Table Op (#135262)
Add table size check for Table Op and add lit tests to error_if_check.mlir also corrected some existing tests that violated the table size checks Signed-off-by: Tai Ly <[email protected]>
1 parent bd9c511 commit 96064e1

File tree

4 files changed

+43
-5
lines changed

4 files changed

+43
-5
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1012,8 +1012,30 @@ bool checkErrorIfMul(Operation *op) {
10121012
return true;
10131013
}
10141014

1015+
bool checkErrorIfTable(Operation *op) {
1016+
auto table = dyn_cast<tosa::TableOp>(op);
1017+
if (!table)
1018+
return true;
1019+
1020+
// REQUIRE(length(table) == TABLE_SIZE) where TABLE_SIZE is 256 or 513
1021+
const auto inputElemType = getElementTypeOrSelf(table.getInput1().getType());
1022+
const int tableSize = inputElemType.isInteger(8) ? 256 : 513;
1023+
1024+
const ShapeAdaptor tableShape(table.getTable().getType());
1025+
if (tableShape.hasStaticShape()) {
1026+
const auto numElements = tableShape.getNumElements();
1027+
if (numElements != tableSize) {
1028+
op->emitOpError() << "requires table size of " << tableSize << ", got "
1029+
<< numElements;
1030+
return false;
1031+
}
1032+
}
1033+
1034+
return true;
1035+
}
1036+
10151037
LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
1016-
if (!checkErrorIfResize(op) || !checkErrorIfMul(op))
1038+
if (!checkErrorIfResize(op) || !checkErrorIfMul(op) || !checkErrorIfTable(op))
10171039
return failure();
10181040
return success();
10191041
}

mlir/test/Dialect/Tosa/dynamic_extension.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8
1313

1414
// -----
1515

16-
func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
17-
%0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<513xi8>) -> tensor<4x5xi8>
16+
func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<256xi8>) -> () {
17+
%0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<256xi8>) -> tensor<4x5xi8>
1818
return
1919
}
2020

mlir/test/Dialect/Tosa/error_if_check.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,19 @@ func.func @test_mul_non_zero_shift(%arg0: tensor<1x8x8x8xi16>, %arg1: tensor<1x8
113113
%mul = tosa.mul %arg0, %arg1, %shift : (tensor<1x8x8x8xi16>, tensor<1x8x8x8xi16>, tensor<1xi8>) -> tensor<1x8x8x8xi32>
114114
return %mul : tensor<1x8x8x8xi32>
115115
}
116+
117+
// -----
118+
// CHECK-LABEL: test_i16_table_size
119+
func.func @test_i16_table_size(%arg0: tensor<2x64xi16>, %arg1: tensor<256xi16>) -> tensor<2x64xi32> {
120+
// expected-error@+1 {{'tosa.table' op requires table size of 513, got 256}}
121+
%0 = tosa.table %arg0, %arg1 : (tensor<2x64xi16>, tensor<256xi16>) -> tensor<2x64xi32>
122+
return %0 : tensor<2x64xi32>
123+
}
124+
125+
// -----
126+
// CHECK-LABEL: test_i8_table_size
127+
func.func @test_i8_table_size(%arg0: tensor<2x64xi8>, %arg1: tensor<513xi8>) -> tensor<2x64xi8> {
128+
// expected-error@+1 {{'tosa.table' op requires table size of 256, got 513}}
129+
%0 = tosa.table %arg0, %arg1 : (tensor<2x64xi8>, tensor<513xi8>) -> tensor<2x64xi8>
130+
return %0 : tensor<2x64xi8>
131+
}

mlir/test/Dialect/Tosa/invalid_extension.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -497,9 +497,9 @@ func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8
497497

498498
// -----
499499

500-
func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
500+
func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<256xi8>) -> () {
501501
// expected-error@+1 {{'tosa.table' op expected compile time resolvable constant, but got variable value for operand #1}}
502-
%0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<513xi8>) -> tensor<4x5xi8>
502+
%0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<256xi8>) -> tensor<4x5xi8>
503503
return
504504
}
505505

0 commit comments

Comments
 (0)