Skip to content

Commit bd9c511

Browse files
authored
[mlir][tosa] Add error_if checks for Transpose (#135219)
This adds missing error_if checking for Transpose Op also moved all transpose op's verifier tests from invalid.mlir to verifier.mlir Signed-off-by: Tai Ly <[email protected]>
1 parent b581bd3 commit bd9c511

File tree

3 files changed

+155
-131
lines changed

3 files changed

+155
-131
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1981,23 +1981,28 @@ LogicalResult tosa::TransposeOp::verify() {
19811981
.failed()) {
19821982
return failure();
19831983
}
1984-
TensorType inputType = getInput1().getType();
1985-
TensorType outputType = getOutput().getType();
1984+
1985+
const ShapeAdaptor inputShape(getInput1().getType());
1986+
const ShapeAdaptor outputShape(getOutput().getType());
1987+
19861988
const llvm::ArrayRef<int32_t> constantPerms = getPerms();
19871989

1988-
if (inputType.hasRank() &&
1989-
constantPerms.size() != static_cast<size_t>(inputType.getRank()))
1990+
if (inputShape.hasRank() &&
1991+
constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
19901992
return emitOpError() << "expected perms attribute to have size "
1991-
<< inputType.getRank() << " (input rank) but got size "
1993+
<< inputShape.getRank()
1994+
<< " (input rank) but got size "
19921995
<< constantPerms.size();
1993-
if (inputType.hasRank() && outputType.hasRank() &&
1994-
inputType.getRank() != outputType.getRank())
1996+
1997+
if (inputShape.hasRank() && outputShape.hasRank() &&
1998+
inputShape.getRank() != outputShape.getRank())
19951999
return emitOpError()
19962000
<< "expected input tensor rank to equal result tensor rank";
1997-
if (outputType.hasRank() &&
1998-
constantPerms.size() != static_cast<size_t>(outputType.getRank()))
2001+
2002+
if (outputShape.hasRank() &&
2003+
constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
19992004
return emitOpError() << "expected perms attribute to have size "
2000-
<< outputType.getRank()
2005+
<< outputShape.getRank()
20012006
<< " (output rank) but got size "
20022007
<< constantPerms.size();
20032008

@@ -2010,22 +2015,27 @@ LogicalResult tosa::TransposeOp::verify() {
20102015
constantPerms, [](int32_t v) -> int64_t { return v; }))))
20112016
return emitOpError() << "expected valid permutation indices";
20122017

2018+
// ERROR_IF(tensor_size(shape1) != tensor_size(shape))
2019+
if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
2020+
inputShape.getNumElements() != outputShape.getNumElements())
2021+
return emitOpError() << "expected input1 and output to have same numbers "
2022+
"of elements, got "
2023+
<< inputShape.getNumElements() << " and "
2024+
<< outputShape.getNumElements();
2025+
20132026
// Verify that the types of the input and output tensors are properly
20142027
// permuted.
2015-
if (inputType.hasRank() && outputType.hasRank()) {
2016-
assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
2017-
inputType.getRank() == outputType.getRank());
2018-
2019-
for (auto i = 0; i < outputType.getRank(); i++) {
2020-
if (inputType.isDynamicDim(constantPerms[i]) ||
2021-
outputType.isDynamicDim(i))
2028+
if (inputShape.hasRank() && outputShape.hasRank()) {
2029+
for (auto i = 0; i < outputShape.getRank(); i++) {
2030+
if (inputShape.isDynamicDim(constantPerms[i]) ||
2031+
outputShape.isDynamicDim(i))
20222032
continue;
20232033

2024-
if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
2034+
if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
20252035
return emitOpError()
20262036
<< "expected output tensor dim " << i << " to match "
20272037
<< "input dim " << constantPerms[i] << " with value of "
2028-
<< inputType.getDimSize(constantPerms[i]);
2038+
<< inputShape.getDimSize(constantPerms[i]);
20292039
}
20302040
}
20312041

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 0 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -368,79 +368,6 @@ func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor
368368

369369
// -----
370370

371-
func.func @test_transpose_io_rank_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21x1xf32> {
372-
// expected-error@+1 {{'tosa.transpose' op expected input tensor rank to equal result tensor rank}}
373-
%0 = tosa.transpose %arg0 {perms = array<i32: 2, 1, 0>}: (tensor<13x21x3xf32>) -> tensor<3x13x21x1xf32>
374-
return %0 : tensor<3x13x21x1xf32>
375-
}
376-
377-
// -----
378-
379-
func.func @test_transpose_rank0_perms() {
380-
%14 = tensor.empty() : tensor<5x27xi64>
381-
// expected-error@+1 {{'tosa.transpose' op expected perms attribute to have size 2 (input rank) but got size 0}}
382-
%72 = tosa.transpose %14 {perms = array<i32> }: (tensor<5x27xi64>) -> tensor<?x?xi64>
383-
return
384-
}
385-
386-
// -----
387-
388-
func.func @test_transpose_invalid_perms_size(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> {
389-
// expected-error@+1 {{'tosa.transpose' op expected perms attribute to have size 3 (input rank) but got size 7}}
390-
%0 = tosa.transpose %arg0 {perms = array<i32: 6, 5, 4, 3, 2, 1, 0> }: (tensor<13x21x3xf32>) -> tensor<3x13x21xf32>
391-
return %0 : tensor<3x13x21xf32>
392-
}
393-
394-
// -----
395-
396-
func.func @test_transpose_invalid_permutation_tensor(%arg0: tensor<13x21x3xf32>) -> tensor<?x?x?xf32> {
397-
// expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}}
398-
%0 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 0> }: (tensor<13x21x3xf32>) -> tensor<?x?x?xf32>
399-
return %0 : tensor<?x?x?xf32>
400-
}
401-
402-
// -----
403-
404-
func.func @test_transpose_invalid_permutation_negative(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
405-
// expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}}
406-
%1 = tosa.transpose %arg0 {perms = array<i32: -1, 0> }: (tensor<3x2xi32>) -> tensor<*xi32>
407-
return %1 : tensor<*xi32>
408-
}
409-
410-
// -----
411-
412-
func.func @test_transpose_invalid_permutation_tensor_above_range(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
413-
// expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}}
414-
%1 = tosa.transpose %arg0 {perms = array<i32: 2, 0> }: (tensor<3x2xi32>) -> tensor<*xi32>
415-
return %1 : tensor<*xi32>
416-
}
417-
418-
// -----
419-
420-
func.func @test_transpose_invalid_permutation_types(%arg0: tensor<3x2xi32>) -> tensor<3x4xi32> {
421-
// expected-error@+1 {{'tosa.transpose' op expected output tensor dim 0 to match input dim 1 with value of 2}}
422-
%1 = tosa.transpose %arg0 {perms = array<i32: 1, 0> }: (tensor<3x2xi32>) -> tensor<3x4xi32>
423-
return %1 : tensor<3x4xi32>
424-
}
425-
426-
// -----
427-
428-
func.func @test_transpose_invalid_permutation_types_dynamic_dim_ok(%arg0: tensor<2x?xi32>) -> tensor<3x4xi32> {
429-
// expected-error@+1 {{'tosa.transpose' op expected output tensor dim 1 to match input dim 0 with value of 2}}
430-
%1 = tosa.transpose %arg0 {perms = array<i32: 1, 0> }: (tensor<2x?xi32>) -> tensor<3x4xi32>
431-
return %1 : tensor<3x4xi32>
432-
}
433-
434-
// -----
435-
436-
func.func @test_transpose_element_type_mismatch(%arg0: tensor<2x3xi32>) -> tensor<3x2xf32> {
437-
// expected-error@+1 {{'tosa.transpose' op failed to verify that all of {input1, output} have same element type}}
438-
%1 = tosa.transpose %arg0 {perms = array<i32: 1, 0>} : (tensor<2x3xi32>) -> tensor<3x2xf32>
439-
return %1 : tensor<3x2xf32>
440-
}
441-
442-
// -----
443-
444371
func.func @test_reduce_sum_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
445372
// expected-error@+2 {{failed to infer returned types}}
446373
// expected-error@+1 {{'tosa.reduce_sum' op inferred type(s) 'tensor<1x3x4x5xf32>' are incompatible with return type(s) of operation 'tensor<1x3x4x5xi32>'}}
@@ -783,37 +710,6 @@ func.func @test_tile_io_rank_mismatch() {
783710
return
784711
}
785712

786-
// -----
787-
788-
// CHECK-LABEL: @test_invalid_constant_permutation
789-
func.func @test_invalid_constant_permutation() {
790-
%0 = tensor.empty() : tensor<3x4x5xi32>
791-
// expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}}
792-
%2 = tosa.transpose %0 {perms = array<i32: 3, 0, 1>}: (tensor<3x4x5xi32>) -> tensor<3x4x5xi32>
793-
return
794-
}
795-
796-
// -----
797-
798-
// CHECK-LABEL: test_rank_size_constant_permutation
799-
func.func @test_rank_size_constant_permutation() {
800-
%0 = arith.constant 6 : index
801-
%2 = tensor.empty(%0) : tensor<?x27xi64>
802-
// expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}}
803-
%3 = tosa.transpose %2 {perms = array<i32: 0, 2>}: (tensor<?x27xi64>) -> tensor<?x27xi64>
804-
return
805-
}
806-
807-
// -----
808-
809-
// CHECK-LABEL: test_large_constant_permutation
810-
func.func @test_large_constant_permutation() {
811-
%0 = arith.constant 6 : index
812-
%2 = tensor.empty(%0) : tensor<?x27xi64>
813-
// expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}}
814-
%3 = tosa.transpose %2 {perms = array<i32: 1185677355, 332462212>}: (tensor<?x27xi64>) -> tensor<?x27xi64>
815-
return
816-
}
817713

818714
// -----
819715

@@ -2061,14 +1957,6 @@ func.func @test_scalar_tile(%arg0: tensor<f32>) -> tensor<*xf32> {
20611957

20621958
// -----
20631959

2064-
func.func @test_scalar_output_transpose(%arg0: tensor<*xf32>) -> tensor<f32> {
2065-
// expected-error@+1 {{'tosa.transpose' op result #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
2066-
%1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<*xf32>) -> tensor<f32>
2067-
return %1 : tensor<f32>
2068-
}
2069-
2070-
// -----
2071-
20721960
// CHECK-LABEL: test_add_i1
20731961
func.func @test_add_i1(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
20741962
// expected-error@+1 {{'tosa.add' op illegal: operand/result data types not supported}}

mlir/test/Dialect/Tosa/verifier.mlir

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
//--------------------------------------------------------------------------------------------------
2+
// Test expected errors generated by verifier checks.
3+
//--------------------------------------------------------------------------------------------------
4+
5+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
6+
7+
// -----
8+
9+
func.func @test_transpose_io_rank_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21x1xf32> {
10+
// expected-error@+1 {{'tosa.transpose' op expected input tensor rank to equal result tensor rank}}
11+
%0 = tosa.transpose %arg0 {perms = array<i32: 2, 1, 0>}: (tensor<13x21x3xf32>) -> tensor<3x13x21x1xf32>
12+
return %0 : tensor<3x13x21x1xf32>
13+
}
14+
15+
// -----
16+
17+
func.func @test_transpose_rank0_perms() {
18+
%14 = tensor.empty() : tensor<5x27xi64>
19+
// expected-error@+1 {{'tosa.transpose' op expected perms attribute to have size 2 (input rank) but got size 0}}
20+
%72 = tosa.transpose %14 {perms = array<i32> }: (tensor<5x27xi64>) -> tensor<?x?xi64>
21+
return
22+
}
23+
24+
// -----
25+
26+
func.func @test_transpose_invalid_perms_size(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> {
27+
// expected-error@+1 {{'tosa.transpose' op expected perms attribute to have size 3 (input rank) but got size 7}}
28+
%0 = tosa.transpose %arg0 {perms = array<i32: 6, 5, 4, 3, 2, 1, 0> }: (tensor<13x21x3xf32>) -> tensor<3x13x21xf32>
29+
return %0 : tensor<3x13x21xf32>
30+
}
31+
32+
// -----
33+
34+
func.func @test_transpose_invalid_permutation_tensor(%arg0: tensor<13x21x3xf32>) -> tensor<?x?x?xf32> {
35+
// expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}}
36+
%0 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 0> }: (tensor<13x21x3xf32>) -> tensor<?x?x?xf32>
37+
return %0 : tensor<?x?x?xf32>
38+
}
39+
40+
// -----
41+
42+
func.func @test_transpose_invalid_permutation_negative(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
43+
// expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}}
44+
%1 = tosa.transpose %arg0 {perms = array<i32: -1, 0> }: (tensor<3x2xi32>) -> tensor<*xi32>
45+
return %1 : tensor<*xi32>
46+
}
47+
48+
// -----
49+
50+
func.func @test_transpose_invalid_permutation_tensor_above_range(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
51+
// expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}}
52+
%1 = tosa.transpose %arg0 {perms = array<i32: 2, 0> }: (tensor<3x2xi32>) -> tensor<*xi32>
53+
return %1 : tensor<*xi32>
54+
}
55+
56+
// -----
57+
58+
func.func @test_transpose_invalid_num_elements(%arg0: tensor<3x2xi32>) -> tensor<3x4xi32> {
59+
// expected-error@+1 {{'tosa.transpose' op expected input1 and output to have same numbers of elements, got 6 and 12}}
60+
%1 = tosa.transpose %arg0 {perms = array<i32: 1, 0> }: (tensor<3x2xi32>) -> tensor<3x4xi32>
61+
return %1 : tensor<3x4xi32>
62+
}
63+
64+
// -----
65+
66+
func.func @test_transpose_invalid_permutation_types(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
67+
// expected-error@+1 {{'tosa.transpose' op expected output tensor dim 0 to match input dim 1 with value of 2}}
68+
%1 = tosa.transpose %arg0 {perms = array<i32: 1, 0> }: (tensor<3x2xi32>) -> tensor<3x2xi32>
69+
return %1 : tensor<3x2xi32>
70+
}
71+
72+
// -----
73+
74+
func.func @test_transpose_invalid_permutation_types_dynamic_dim_ok(%arg0: tensor<2x?xi32>) -> tensor<3x4xi32> {
75+
// expected-error@+1 {{'tosa.transpose' op expected output tensor dim 1 to match input dim 0 with value of 2}}
76+
%1 = tosa.transpose %arg0 {perms = array<i32: 1, 0> }: (tensor<2x?xi32>) -> tensor<3x4xi32>
77+
return %1 : tensor<3x4xi32>
78+
}
79+
80+
// -----
81+
82+
func.func @test_transpose_element_type_mismatch(%arg0: tensor<2x3xi32>) -> tensor<3x2xf32> {
83+
// expected-error@+1 {{'tosa.transpose' op failed to verify that all of {input1, output} have same element type}}
84+
%1 = tosa.transpose %arg0 {perms = array<i32: 1, 0>} : (tensor<2x3xi32>) -> tensor<3x2xf32>
85+
return %1 : tensor<3x2xf32>
86+
}
87+
88+
// -----
89+
90+
// CHECK-LABEL: @test_invalid_constant_permutation
91+
func.func @test_invalid_constant_permutation() {
92+
%0 = tensor.empty() : tensor<3x4x5xi32>
93+
// expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}}
94+
%2 = tosa.transpose %0 {perms = array<i32: 3, 0, 1>}: (tensor<3x4x5xi32>) -> tensor<3x4x5xi32>
95+
return
96+
}
97+
98+
// -----
99+
100+
// CHECK-LABEL: test_rank_size_constant_permutation
101+
func.func @test_rank_size_constant_permutation() {
102+
%0 = arith.constant 6 : index
103+
%2 = tensor.empty(%0) : tensor<?x27xi64>
104+
// expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}}
105+
%3 = tosa.transpose %2 {perms = array<i32: 0, 2>}: (tensor<?x27xi64>) -> tensor<?x27xi64>
106+
return
107+
}
108+
109+
// -----
110+
111+
// CHECK-LABEL: test_large_constant_permutation
112+
func.func @test_large_constant_permutation() {
113+
%0 = arith.constant 6 : index
114+
%2 = tensor.empty(%0) : tensor<?x27xi64>
115+
// expected-error@+1 {{'tosa.transpose' op expected valid permutation indices}}
116+
%3 = tosa.transpose %2 {perms = array<i32: 1185677355, 332462212>}: (tensor<?x27xi64>) -> tensor<?x27xi64>
117+
return
118+
}
119+
120+
// -----
121+
122+
func.func @test_scalar_output_transpose(%arg0: tensor<*xf32>) -> tensor<f32> {
123+
// expected-error@+1 {{'tosa.transpose' op result #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
124+
%1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<*xf32>) -> tensor<f32>
125+
return %1 : tensor<f32>
126+
}

0 commit comments

Comments
 (0)