Skip to content

Commit ec46e03

Browse files
sjarusrsuderman
authored andcommitted
[mlir][tosa] TOSA MLIR dialect update to v0.22, part 1
Incremental set of updates to align to TOSA v0.22 spec - modify gather, resize - add scatter - remove aint8 type Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D99390
1 parent 5f59f40 commit ec46e03

File tree

3 files changed

+62
-28
lines changed

3 files changed

+62
-28
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1363,17 +1363,38 @@ def Tosa_GatherOp : Tosa_Op<"gather", [NoSideEffect]> {
13631363

13641364
let description = [{
13651365
Generate a tensor for which each element in the output is a subtensor of the
1366-
values tensor along the given axis, based on the value of indices.
1366+
values tensor based on the value of indices.
13671367
}];
13681368

13691369
let arguments = (ins
1370-
Tosa_Int32Or64Tensor:$indices,
1371-
Tosa_Tensor1Dto4D:$values,
1372-
I32Attr:$axis
1370+
Tosa_Tensor3D:$values,
1371+
2DTensorOf<[Tosa_Int32]>:$indices
13731372
);
13741373

13751374
let results = (outs
1376-
Tosa_Tensor1Dto4D:$output
1375+
Tosa_Tensor3D:$output
1376+
);
1377+
}
1378+
1379+
//===----------------------------------------------------------------------===//
1380+
// Operator: scatter
1381+
//===----------------------------------------------------------------------===//
1382+
def Tosa_ScatterOp : Tosa_Op<"scatter", [NoSideEffect]> {
1383+
let summary = "Scatter operation,";
1384+
1385+
let description = [{
1386+
The values_out tensor is set to the values_in tensor with data modified as follows:
1387+
data from the input tensor is inserted at the positions specified by the indices tensor.
1388+
}];
1389+
1390+
let arguments = (ins
1391+
Tosa_Tensor3D:$values_in,
1392+
2DTensorOf<[Tosa_Int32]>:$indices,
1393+
Tosa_Tensor3D:$input
1394+
);
1395+
1396+
let results = (outs
1397+
Tosa_Tensor3D:$values_out
13771398
);
13781399
}
13791400

@@ -1402,6 +1423,8 @@ def Tosa_ResizeOp : Tosa_Op<"resize", [NoSideEffect]> {
14021423
Tosa_IntArrayAttr2:$stride,
14031424
Tosa_IntArrayAttr2:$offset,
14041425
I32Attr:$shift,
1426+
Tosa_Fp32ArrayAttr2:$stride_fp,
1427+
Tosa_Fp32ArrayAttr2:$offset_fp,
14051428
Tosa_ResizeTypeAttr:$mode
14061429
);
14071430

@@ -1462,20 +1485,20 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [NoSideEffect]> {
14621485
let description = [{
14631486
Rescale quantized values into a new domain. Supported rescalings are:
14641487
Mode Input Output
1465-
signed 8 to 8 aint8 aint8
1466-
signed 8 to 16 aint8 int16
1467-
signed 8 to 32 aint8 int32
1468-
signed 16 to 8 int16 aint8
1488+
signed 8 to 8 int8 int8
1489+
signed 8 to 16 int8 int16
1490+
signed 8 to 32 int8 int32
1491+
signed 16 to 8 int16 int8
14691492
signed 16 to 16 int16 int16
14701493
signed 16 to 32 int16 int32
1471-
signed 32 to 8 int32 aint8
1494+
signed 32 to 8 int32 int8
14721495
signed 32 to 16 int32 int16
14731496
signed 32 to 32 int32 int32
1474-
signed 48 to 8 int48 aint8
1497+
signed 48 to 8 int48 int8
14751498
signed 48 to 16 int48 int16
14761499
signed 48 to 32 int48 int32
1477-
unsigned 8 to signed 8 uint8 aint8
1478-
signed 8 to unsigned 8 aint8 uint8
1500+
unsigned 8 to signed 8 uint8 int8
1501+
signed 8 to unsigned 8 int8 uint8
14791502
}];
14801503

14811504
let arguments = (ins

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,12 @@ def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
6666
//===----------------------------------------------------------------------===//
6767
// Name Symmetry Grouping Sign
6868
//===----------------------------------------------------------------------===//
69-
// aint8 : asymmetric per tensor, signed
7069
// uint8 : asymmetric per tensor , unsigned
7170
// int4 : symmetric per channel, signed
7271
// int8 : symmetric per tensor/per channel, signed
7372
// int16 : symmetric per tensor, signed
7473
//===----------------------------------------------------------------------===//
75-
def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"aint8", [8], 1>,
76-
Tosa_QuantizedType<"uint8", [8], 0>,
74+
def Tosa_QuantizedInt : AnyTypeOf<[ Tosa_QuantizedType<"uint8", [8], 0>,
7775
Tosa_QuantizedType<"int4", [4, 0], 1>,
7876
Tosa_QuantizedType<"int8", [8, 0], 1>,
7977
Tosa_QuantizedType<"int16", [16, 0], 1>]>;
@@ -114,6 +112,7 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
114112
// Must be listed rank.
115113
def Tosa_Tensor1D : 1DTensorOf<[Tosa_AnyNumber]>;
116114
def Tosa_Tensor2D : 2DTensorOf<[Tosa_AnyNumber]>;
115+
def Tosa_Tensor3D : 3DTensorOf<[Tosa_AnyNumber]>;
117116
def Tosa_Tensor4D : 4DTensorOf<[Tosa_AnyNumber]>;
118117
def Tosa_Tensor5D : TensorRankOf<[Tosa_AnyNumber], [5]>;
119118
def Tosa_Tensor6D : TensorRankOf<[Tosa_AnyNumber], [6]>;
@@ -149,6 +148,12 @@ class ArrayMaxCt<int n> : AttrConstraint<
149148
CPred<"$_self.cast<::mlir::ArrayAttr>().size() <= " # n>,
150149
"with at least " # n # " elements">;
151150

151+
def Tosa_Fp32ArrayAttr2 : Confined<F32ArrayAttr, [ArrayCount<2>]>;
152+
def Tosa_Fp32ArrayAttr3 : Confined<F32ArrayAttr, [ArrayCount<3>]>;
153+
def Tosa_Fp32ArrayAttr4 : Confined<F32ArrayAttr, [ArrayCount<4>]>;
154+
def Tosa_Fp32ArrayAttr5 : Confined<F32ArrayAttr, [ArrayCount<5>]>;
155+
def Tosa_Fp32ArrayAttr6 : Confined<F32ArrayAttr, [ArrayCount<6>]>;
156+
152157
def Tosa_IntArrayAttr2 : Confined<I64ArrayAttr, [ArrayCount<2>]>;
153158
def Tosa_IntArrayAttr3 : Confined<I64ArrayAttr, [ArrayCount<3>]>;
154159
def Tosa_IntArrayAttr4 : Confined<I64ArrayAttr, [ArrayCount<4>]>;

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -406,18 +406,24 @@ func @test_transpose(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> {
406406

407407
// -----
408408
// CHECK-LABEL: gather
409-
func @test_gather(%arg0: tensor<13x21x3xi32>, %arg1: tensor<26xi32>) -> tensor<26x21x3xi32> {
410-
%0 = "tosa.gather"(%arg0, %arg1) {axis = 0 : i32, batch_dims = 0 : i64} : (tensor<13x21x3xi32>, tensor<26xi32>) -> tensor<26x21x3xi32>
411-
return %0 : tensor<26x21x3xi32>
412-
}
413-
414-
// Test TBD
415-
// DISABLED-CHECK-LABEL: resize
416-
//func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
417-
// %0 = "tosa.const"() {value = dense<64> : tensor<2xi32>} : () -> tensor<2xi32>
418-
// %1 = "tosa.resize"(%arg0, %0) {align_corners = false, half_pixel_centers = true} : (tensor<1x32x32x8xf32>, tensor<2xi32>) -> tensor<1x64x64x8xf32>
419-
// return %1 : tensor<1x64x64x8xf32>
420-
//}
409+
func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x26x3xf32> {
410+
%0 = "tosa.gather"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x26x3xf32>
411+
return %0 : tensor<13x26x3xf32>
412+
}
413+
414+
// -----
415+
// CHECK-LABEL: scatter
416+
func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
417+
%0 = "tosa.scatter"(%arg0, %arg1, %arg2) : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
418+
return %0 : tensor<13x21x3xf32>
419+
}
420+
421+
// -----
422+
// CHECK-LABEL: resize
423+
func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
424+
%1 = "tosa.resize"(%arg0) {output_size = [64, 64], stride = [1024, 1024], offset = [0, 0], shift = 10 : i32, stride_fp = [0.0 : f32, 0.0 : f32], offset_fp = [0.0 : f32, 0.0 : f32], mode = "BILINEAR"} : (tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32>
425+
return %1 : tensor<1x64x64x8xf32>
426+
}
421427

422428
// -----
423429
// CHECK-LABEL: cast

0 commit comments

Comments
 (0)