Skip to content

Commit d7c44a5

Browse files
committed
[mlir][tosa] Fix tosa.mul to use tosa.apply_scale
Multiply-shift requires wider compute types or CPU specific code to avoid premature truncation, apply_shift fixes this issue Also, Tosa's mul op supports different input / output types. Added path that sign-extends input values to int-32 values before multiplying. Differential Revision: https://reviews.llvm.org/D99011
1 parent 5727df2 commit d7c44a5

File tree

2 files changed

+69
-24
lines changed

2 files changed

+69
-24
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,12 +115,39 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
115115
}
116116

117117
if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
118-
auto mul =
119-
rewriter.create<mlir::MulIOp>(loc, resultTypes, args[0], args[1]);
120-
auto constant =
121-
rewriter.create<mlir::ConstantOp>(loc, elementTy, op->getAttr("shift"));
122-
return rewriter.create<mlir::SignedShiftRightOp>(loc, resultTypes, mul,
123-
constant);
118+
Value a = args[0];
119+
Value b = args[1];
120+
auto shift =
121+
op->getAttr("shift").cast<IntegerAttr>().getValue().getSExtValue();
122+
if (shift > 0) {
123+
auto shiftConst =
124+
rewriter.create<ConstantIntOp>(loc, shift, /*bitwidth=*/8);
125+
if (!a.getType().isInteger(32))
126+
a = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), a);
127+
128+
if (!b.getType().isInteger(32))
129+
b = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), b);
130+
131+
auto result = rewriter.create<tosa::ApplyScaleOp>(
132+
loc, rewriter.getI32Type(), a, b, shiftConst,
133+
rewriter.getBoolAttr(false));
134+
135+
if (elementTy.isInteger(32))
136+
return result;
137+
138+
return rewriter.create<TruncateIOp>(loc, elementTy, result);
139+
}
140+
141+
int aWidth = a.getType().getIntOrFloatBitWidth();
142+
int bWidth = b.getType().getIntOrFloatBitWidth();
143+
int cWidth = resultTypes[0].getIntOrFloatBitWidth();
144+
145+
if (aWidth < cWidth)
146+
a = rewriter.create<SignExtendIOp>(loc, resultTypes[0], a);
147+
if (bWidth < cWidth)
148+
b = rewriter.create<SignExtendIOp>(loc, resultTypes[0], b);
149+
150+
return rewriter.create<mlir::MulIOp>(loc, resultTypes, a, b);
124151
}
125152

126153
// tosa::NegateOp

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,19 @@ func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
214214

215215
// -----
216216

217+
// CHECK-LABEL: @test_simple_i16
218+
func @test_simple_i16(%arg0: tensor<1xi16>) -> () {
219+
// CHECK: linalg.generic
220+
// CHECK: sext
221+
// CHECK: sext
222+
// CHECK: muli
223+
%0 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>
224+
225+
return
226+
}
227+
228+
// -----
229+
217230
// CHECK-LABEL: @test_simple_i32
218231
func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
219232
// CHECK: linalg.generic
@@ -228,82 +241,87 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
228241
// CHECK: muli
229242
%2 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
230243

244+
// CHECK: linalg.generic
245+
// CHECK: constant 2
246+
// CHECK: apply_scale
247+
%3 = "tosa.mul"(%arg0, %arg0) {shift = 2 : i32} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
248+
231249
// CHECK: linalg.generic
232250
// CHECK: muli
233-
%3 = "tosa.negate"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
251+
%4 = "tosa.negate"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
234252

235253
// CHECK: linalg.generic
236254
// CHECK: and
237-
%4 = "tosa.bitwise_and"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
255+
%5 = "tosa.bitwise_and"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
238256

239257
// CHECK: linalg.generic
240258
// CHECK: or
241-
%5 = "tosa.bitwise_or"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
259+
%6 = "tosa.bitwise_or"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
242260

243261
// CHECK: linalg.generic
244262
// CHECK: xor
245-
%6 = "tosa.bitwise_xor"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
263+
%7 = "tosa.bitwise_xor"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
246264

247265
// CHECK: linalg.generic
248266
// CHECK: shift_left
249-
%7 = "tosa.logical_left_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
267+
%8 = "tosa.logical_left_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
250268

251269
// CHECK: linalg.generic
252270
// CHECK: shift_right_unsigned
253-
%8 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
271+
%9 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
254272

255273
// CHECK: linalg.generic
256274
// CHECK: cmpi
257-
%9 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
275+
%10 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
258276

259277
// CHECK: linalg.generic
260278
// CHECK: cmpi
261-
%10 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
279+
%11 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
262280

263281
// CHECK: linalg.generic
264282
// CHECK: select
265-
%11 = "tosa.select"(%9, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
283+
%12 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
266284

267285
// CHECK: linalg.generic
268286
// CHECK: cmpi
269287
// CHECK: select
270-
%12 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
288+
%13 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
271289

272290
// CHECK: linalg.generic
273291
// CHECK: cmpi
274292
// CHECK: select
275-
%13 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
293+
%14 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
276294

277295
// CHECK: linalg.generic
278296
// CHECK: cmpi
279297
// CHECK: select
280-
%14 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
298+
%15 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
281299

282300
// CHECK: linalg.generic
283301
// CHECK: cmpi
284302
// CHECK: select
285-
%15 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
303+
%16 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
286304

287305
// CHECK: linalg.generic
288306
// CHECK: trunci
289-
%16 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
307+
%17 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
290308

291309
// CHECK: linalg.generic
292310
// CHECK: yield
293-
%17 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32>
311+
%18 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32>
294312

295313
// CHECK: linalg.generic
296314
// CHECK: sexti
297-
%18 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
315+
%19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
298316

299317
// CHECK: linalg.generic
300318
// CHECK: constant 0
301319
// CHECK: cmpi
302-
%19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
320+
%20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
303321

304322
// CHECK: linalg.generic
305323
// CHECK: sitofp
306-
%20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
324+
%21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
307325

308326
return
309327
}

0 commit comments

Comments
 (0)