Skip to content

Commit da4d191

Browse files
tatwaichongeric-k256
authored andcommitted
[mlir][tosa] operation printing syntax prettification
The initial patch defines the printing format of tosa operations in declarative and C++ manners to make alignment against more readable syntax in other dialects. The general change to assembly output is shown below. from %out = "tosa.op"(%input1, %input2, ...) : (type1, type2, ...) -> (out_type) to %out = tosa.op %input1, %input2, ... : (type1, type2, ...) -> out_type There is a significant structural printing change to tosa control-flow operations, `cond_if` and `while_loop`, aiming to provide more concise and intuitive syntax. Note that we leave tosa.const unchanged. As this op can be attached with quantization information, may need more tweaks to distinguish plain integer type from quantized type for printing the value in a concise form. Differential Revision: https://reviews.llvm.org/D155231
1 parent 7a2dfe2 commit da4d191

20 files changed

+1370
-1174
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def Tosa_ExplicitValuePadOpQuantInfoBuilder : OpBuilder<
208208
}]>;
209209

210210
//===----------------------------------------------------------------------===//
211-
// TOSA Operator.
211+
// TOSA Operator Class.
212212
//===----------------------------------------------------------------------===//
213213

214214
class Tosa_Op<string mnemonic, list<Trait> traits = []> :
@@ -221,6 +221,20 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
221221
["inferReturnTypeComponents"]>,
222222
ResultsBroadcastableShape,
223223
Pure])> {
224+
let assemblyFormat =
225+
"operands attr-dict `:` functional-type(operands, results)";
226+
}
227+
228+
class Tosa_InferTensorTypeOp<string mnemonic, list<Trait> traits = []>
229+
: Tosa_Op<mnemonic, !listconcat(traits, [InferTensorTypeAdaptor, Pure])> {
230+
let assemblyFormat =
231+
"operands attr-dict `:` functional-type(operands, results)";
232+
}
233+
234+
class Tosa_InferShapedTypeOp<string mnemonic, list<Trait> traits = []>
235+
: Tosa_Op<mnemonic, !listconcat(traits, [InferShapedTypeOpAdaptor, Pure])> {
236+
let assemblyFormat =
237+
"operands attr-dict `:` functional-type(operands, results)";
224238
}
225239

226240
#endif // TOSA_OP_BASE

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 107 additions & 42 deletions
Large diffs are not rendered by default.

mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def Tosa_ApplyScaleOp :
5454
let extraClassDeclaration = [{
5555
std::optional<SmallVector<int64_t, 4>> getShapeForUnroll();
5656
}];
57+
58+
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
5759
}
5860

5961
//===----------------------------------------------------------------------===//
@@ -73,6 +75,8 @@ def Tosa_YieldOp : Tosa_Op<"yield", [
7375
let arguments = (ins
7476
Variadic<Tosa_Tensor>:$inputs
7577
);
78+
79+
let assemblyFormat = "$inputs attr-dict `:` type($inputs)";
7680
}
7781

7882
#endif // TOSA_UTIL_OPS

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1493,6 +1493,134 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
14931493
return std::nullopt;
14941494
}
14951495

1496+
// parse and print of IfOp refer to the implementation of SCF dialect.
1497+
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
1498+
// Create the regions for 'then'.
1499+
result.regions.reserve(2);
1500+
Region *thenRegion = result.addRegion();
1501+
Region *elseRegion = result.addRegion();
1502+
1503+
auto &builder = parser.getBuilder();
1504+
OpAsmParser::UnresolvedOperand cond;
1505+
// Create a i1 tensor type for the boolean condition.
1506+
Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
1507+
if (parser.parseOperand(cond) ||
1508+
parser.resolveOperand(cond, i1Type, result.operands))
1509+
return failure();
1510+
// Parse optional results type list.
1511+
if (parser.parseOptionalArrowTypeList(result.types))
1512+
return failure();
1513+
// Parse the 'then' region.
1514+
if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
1515+
return failure();
1516+
1517+
// If we find an 'else' keyword then parse the 'else' region.
1518+
if (!parser.parseOptionalKeyword("else")) {
1519+
if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
1520+
return failure();
1521+
}
1522+
1523+
// Parse the optional attribute list.
1524+
if (parser.parseOptionalAttrDict(result.attributes))
1525+
return failure();
1526+
return success();
1527+
}
1528+
1529+
void IfOp::print(OpAsmPrinter &p) {
1530+
bool printBlockTerminators = false;
1531+
1532+
p << " " << getCond();
1533+
if (!getResults().empty()) {
1534+
p << " -> (" << getResultTypes() << ")";
1535+
// Print yield explicitly if the op defines values.
1536+
printBlockTerminators = true;
1537+
}
1538+
p << ' ';
1539+
p.printRegion(getThenBranch(),
1540+
/*printEntryBlockArgs=*/false,
1541+
/*printBlockTerminators=*/printBlockTerminators);
1542+
1543+
// Print the 'else' regions if it exists and has a block.
1544+
auto &elseRegion = getElseBranch();
1545+
if (!elseRegion.empty()) {
1546+
p << " else ";
1547+
p.printRegion(elseRegion,
1548+
/*printEntryBlockArgs=*/false,
1549+
/*printBlockTerminators=*/printBlockTerminators);
1550+
}
1551+
1552+
p.printOptionalAttrDict((*this)->getAttrs());
1553+
}
1554+
1555+
// parse and print of WhileOp refer to the implementation of SCF dialect.
1556+
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
1557+
SmallVector<OpAsmParser::Argument, 4> regionArgs;
1558+
SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
1559+
Region *cond = result.addRegion();
1560+
Region *body = result.addRegion();
1561+
1562+
OptionalParseResult listResult =
1563+
parser.parseOptionalAssignmentList(regionArgs, operands);
1564+
if (listResult.has_value() && failed(listResult.value()))
1565+
return failure();
1566+
1567+
FunctionType functionType;
1568+
SMLoc typeLoc = parser.getCurrentLocation();
1569+
if (failed(parser.parseColonType(functionType)))
1570+
return failure();
1571+
1572+
result.addTypes(functionType.getResults());
1573+
1574+
if (functionType.getNumInputs() != operands.size()) {
1575+
return parser.emitError(typeLoc)
1576+
<< "expected as many input types as operands "
1577+
<< "(expected " << operands.size() << " got "
1578+
<< functionType.getNumInputs() << ")";
1579+
}
1580+
1581+
// Resolve input operands.
1582+
if (failed(parser.resolveOperands(operands, functionType.getInputs(),
1583+
parser.getCurrentLocation(),
1584+
result.operands)))
1585+
return failure();
1586+
1587+
// Propagate the types into the region arguments.
1588+
for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
1589+
regionArgs[i].type = functionType.getInput(i);
1590+
1591+
return failure(parser.parseRegion(*cond, regionArgs) ||
1592+
parser.parseKeyword("do") || parser.parseRegion(*body) ||
1593+
parser.parseOptionalAttrDictWithKeyword(result.attributes));
1594+
}
1595+
1596+
static void printInitializationList(OpAsmPrinter &parser,
1597+
Block::BlockArgListType blocksArgs,
1598+
ValueRange initializers,
1599+
StringRef prefix = "") {
1600+
assert(blocksArgs.size() == initializers.size() &&
1601+
"expected same length of arguments and initializers");
1602+
if (initializers.empty())
1603+
return;
1604+
1605+
parser << prefix << '(';
1606+
llvm::interleaveComma(
1607+
llvm::zip(blocksArgs, initializers), parser,
1608+
[&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
1609+
parser << ")";
1610+
}
1611+
1612+
void WhileOp::print(OpAsmPrinter &parser) {
1613+
printInitializationList(parser, getCond().front().getArguments(), getInputs(),
1614+
" ");
1615+
parser << " : ";
1616+
parser.printFunctionalType(getInputs().getTypes(), getResults().getTypes());
1617+
parser << ' ';
1618+
parser.printRegion(getCond(), /*printEntryBlockArgs=*/false);
1619+
parser << " do ";
1620+
parser.printRegion(getBody());
1621+
parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
1622+
}
1623+
14961624
//===----------------------------------------------------------------------===//
14971625
// TOSA Attribute Definitions.
14981626
//===----------------------------------------------------------------------===//

mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ func.func @const_test() -> (tensor<i32>) {
1313
// -----
1414

1515
// CHECK-LABEL: @apply_scale_test_i32
16-
// SCALE: "tosa.apply_scale"
16+
// SCALE: tosa.apply_scale
1717
func.func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
1818
// CHECK-DAG: %[[S32:.+]] = arith.extui %arg2 : i8 to i32
1919
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32
@@ -67,24 +67,24 @@ func.func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
6767
// CHECK-DAG: %[[LOWALIGN:.+]] = arith.select %[[OVER31]], %[[C0]], %[[LOR]]
6868
// CHECK-DAG: %[[RESULT:.+]] = arith.addi %[[LOWALIGN]], %[[HIALIGN]]
6969
// CHECK: return %[[RESULT]]
70-
%res = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (i32, i32, i8) -> i32
70+
%res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i32, i32, i8) -> i32
7171
return %res : i32
7272
}
7373

7474
// -----
7575

7676
// CHECK-LABEL: @apply_scale_test_vector
77-
// SCALE: "tosa.apply_scale"
77+
// SCALE: tosa.apply_scale
7878
func.func @apply_scale_test_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>, %arg2 : vector<4xi8>) -> (vector<4xi32>) {
7979
// CHECK-NOT: "tosa.apply_scale"
80-
%res = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32>
80+
%res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32>
8181
return %res : vector<4xi32>
8282
}
8383

8484
// -----
8585

8686
// CHECK-LABEL: @apply_scale_test_i48
87-
// SCALE: "tosa.apply_scale"
87+
// SCALE: tosa.apply_scale
8888
func.func @apply_scale_test_i48(%arg0 : i48, %arg1 : i32, %arg2 : i8) -> (i32) {
8989
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i48
9090
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i64
@@ -115,6 +115,6 @@ func.func @apply_scale_test_i48(%arg0 : i48, %arg1 : i32, %arg2 : i8) -> (i32) {
115115
// CHECK-DAG: %[[SHR:.+]] = arith.shrsi %[[RES64]], %[[S64]]
116116
// CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i64 to i32
117117
// CHECK: return %[[TRUNC]]
118-
%res = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (i48, i32, i8) -> i32
118+
%res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i48, i32, i8) -> i32
119119
return %res : i32
120120
}

0 commit comments

Comments
 (0)