Skip to content

Commit e4351f2

Browse files
authored
[TOSA] Don't run validation pass on non TOSA operations (llvm#120205)
This commit ensures the validation pass is not run on operations from other dialects. In doing so, operations from other dialects that, for example, use types not supported by TOSA don't result in an error. Signed-off-by: Luke Hutton <[email protected]>
1 parent 9fa109a commit e4351f2

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,10 @@ bool TosaValidation::isValidElementType(Type type) {
543543
void TosaValidation::runOnOperation() {
544544
configLevelAndProfile();
545545
getOperation().walk([&](Operation *op) {
546+
if (!op->getDialect() ||
547+
op->getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
548+
return;
549+
546550
for (Value operand : op->getOperands()) {
547551
auto elementTy = getElementTypeOrSelf(operand);
548552
if (!isValidElementType(elementTy)) {

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,6 @@ func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1
625625
func.func @test_unsupported_int64_data_type(%arg0: tensor<1x13x13x5xf32>) -> tensor<1x13x13xi64> {
626626
// expected-error@+1 {{'tosa.argmax' op is not profile-aligned: element type 'i64' is not legal}}
627627
%0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x13x13x5xf32>) -> tensor<1x13x13xi64>
628-
// expected-error@+1 {{'func.return' op is not profile-aligned: element type 'i64' is not legal}}
629628
return %0 : tensor<1x13x13xi64>
630629
}
631630

@@ -879,4 +878,13 @@ func.func @test_mismatch_in_out_shape_logical_not(%arg0: tensor<1x21x3xi1>) -> t
879878
// expected-error@+1 {{'tosa.logical_not' op requires the same shape for all operands and results}}
880879
%0 = tosa.logical_not %arg0 : (tensor<1x21x3xi1>) -> tensor<13x21x3xi1>
881880
return %0 : tensor<13x21x3xi1>
882-
}
881+
}
882+
883+
// -----
884+
885+
// Check validate pass doesn't run on non TOSA ops
886+
func.func @test_non_tosa_ops() {
887+
%0 = arith.constant 6 : index
888+
%2 = tensor.empty(%0) : tensor<?x27xi64>
889+
return
890+
}

0 commit comments

Comments
 (0)