Skip to content

Commit c1e4e01

Browse files
committed
[mlir][OpenMP] Added assemblyFormat for SectionsOp
This patch adds assemblyFormat for omp.sections operation. Some existing functions have been altered to fit the custom directive in assemblyFormat. This has led to their callsites to get modified too, but those will be removed in later patches, when other operations get their assemblyFormat. All operations were not changed in one patch for ease of review. Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D120176
1 parent c1f17b0 commit c1e4e01

File tree

4 files changed

+49
-85
lines changed

4 files changed

+49
-85
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,20 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments]> {
188188

189189
let regions = (region SizedRegion<1>:$region);
190190

191-
let hasCustomAssemblyFormat = 1;
191+
let assemblyFormat = [{
192+
oilist( `reduction` `(`
193+
custom<ReductionVarList>(
194+
$reduction_vars, type($reduction_vars), $reductions
195+
) `)`
196+
| `allocate` `(`
197+
custom<AllocateAndAllocator>(
198+
$allocate_vars, type($allocate_vars),
199+
$allocators_vars, type($allocators_vars)
200+
) `)`
201+
| `nowait`
202+
) $region attr-dict
203+
}];
204+
192205
let hasVerifier = 1;
193206
}
194207

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 25 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
7777

7878
/// Parse an allocate clause with allocators and a list of operands with types.
7979
///
80-
/// allocate ::= `allocate` `(` allocate-operand-list `)`
8180
/// allocate-operand-list :: = allocate-operand |
8281
/// allocator-operand `,` allocate-operand-list
8382
/// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
@@ -300,39 +299,35 @@ static void printScheduleClause(OpAsmPrinter &p, ClauseScheduleKind sched,
300299
// Parser, printer and verifier for ReductionVarList
301300
//===----------------------------------------------------------------------===//
302301

303-
/// reduction ::= `reduction` `(` reduction-entry-list `)`
304302
/// reduction-entry-list ::= reduction-entry
305303
/// | reduction-entry-list `,` reduction-entry
306304
/// reduction-entry ::= symbol-ref `->` ssa-id `:` type
307-
static ParseResult
308-
parseReductionVarList(OpAsmParser &parser,
309-
SmallVectorImpl<SymbolRefAttr> &symbols,
310-
SmallVectorImpl<OpAsmParser::OperandType> &operands,
311-
SmallVectorImpl<Type> &types) {
312-
if (failed(parser.parseLParen()))
313-
return failure();
314-
305+
static ParseResult parseReductionVarList(
306+
OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &operands,
307+
SmallVectorImpl<Type> &types, ArrayAttr &redcuctionSymbols) {
308+
SmallVector<SymbolRefAttr> reductionVec;
315309
do {
316-
if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() ||
317-
parser.parseOperand(operands.emplace_back()) ||
310+
if (parser.parseAttribute(reductionVec.emplace_back()) ||
311+
parser.parseArrow() || parser.parseOperand(operands.emplace_back()) ||
318312
parser.parseColonType(types.emplace_back()))
319313
return failure();
320314
} while (succeeded(parser.parseOptionalComma()));
321-
return parser.parseRParen();
315+
SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
316+
redcuctionSymbols = ArrayAttr::get(parser.getContext(), reductions);
317+
return success();
322318
}
323319

324320
/// Print Reduction clause
325-
static void printReductionVarList(OpAsmPrinter &p,
326-
Optional<ArrayAttr> reductions,
327-
OperandRange reductionVars) {
328-
p << "reduction(";
321+
static void printReductionVarList(OpAsmPrinter &p, Operation *op,
322+
OperandRange reductionVars,
323+
TypeRange reductionTypes,
324+
Optional<ArrayAttr> reductions) {
329325
for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
330326
if (i != 0)
331327
p << ", ";
332328
p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
333329
<< reductionVars[i].getType();
334330
}
335-
p << ") ";
336331
}
337332

338333
/// Verifies Reduction Clause
@@ -552,7 +547,7 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
552547
SmallVector<OpAsmParser::OperandType> allocates, allocators;
553548
SmallVector<Type> allocateTypes, allocatorTypes;
554549

555-
SmallVector<SymbolRefAttr> reductionSymbols;
550+
ArrayAttr reductions;
556551
SmallVector<OpAsmParser::OperandType> reductionVars;
557552
SmallVector<Type> reductionVarTypes;
558553

@@ -639,9 +634,10 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
639634
"proc_bind_val", "proc bind"))
640635
return failure();
641636
} else if (clauseKeyword == "reduction") {
642-
if (checkAllowed(reductionClause) ||
643-
parseReductionVarList(parser, reductionSymbols, reductionVars,
644-
reductionVarTypes))
637+
if (checkAllowed(reductionClause) || parser.parseLParen() ||
638+
parseReductionVarList(parser, reductionVars, reductionVarTypes,
639+
reductions) ||
640+
parser.parseRParen())
645641
return failure();
646642
clauseSegments[pos[reductionClause]] = reductionVars.size();
647643
} else if (clauseKeyword == "nowait") {
@@ -746,11 +742,7 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
746742
if (failed(parser.resolveOperands(reductionVars, reductionVarTypes,
747743
parser.getNameLoc(), result.operands)))
748744
return failure();
749-
750-
SmallVector<Attribute> reductions(reductionSymbols.begin(),
751-
reductionSymbols.end());
752-
result.addAttribute("reductions",
753-
parser.getBuilder().getArrayAttr(reductions));
745+
result.addAttribute("reductions", reductions);
754746
}
755747

756748
// Add linear parameters
@@ -805,53 +797,9 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
805797
}
806798

807799
//===----------------------------------------------------------------------===//
808-
// Parser, printer and verifier for SectionsOp
800+
// Verifier for SectionsOp
809801
//===----------------------------------------------------------------------===//
810802

811-
/// Parses an OpenMP Sections operation
812-
///
813-
/// sections ::= `omp.sections` clause-list
814-
/// clause-list ::= clause clause-list | empty
815-
/// clause ::= reduction | allocate | nowait
816-
ParseResult SectionsOp::parse(OpAsmParser &parser, OperationState &result) {
817-
SmallVector<ClauseType> clauses = {reductionClause, allocateClause,
818-
nowaitClause};
819-
820-
SmallVector<int> segments;
821-
822-
if (failed(parseClauses(parser, result, clauses, segments)))
823-
return failure();
824-
825-
result.addAttribute("operand_segment_sizes",
826-
parser.getBuilder().getI32VectorAttr(segments));
827-
828-
// Now parse the body.
829-
Region *body = result.addRegion();
830-
if (parser.parseRegion(*body))
831-
return failure();
832-
return success();
833-
}
834-
835-
void SectionsOp::print(OpAsmPrinter &p) {
836-
p << " ";
837-
838-
if (!reduction_vars().empty())
839-
printReductionVarList(p, reductions(), reduction_vars());
840-
841-
if (!allocate_vars().empty()) {
842-
printAllocateAndAllocator(p << "allocate(", *this, allocate_vars(),
843-
allocate_vars().getTypes(), allocators_vars(),
844-
allocators_vars().getTypes());
845-
p << ")";
846-
}
847-
848-
if (nowait())
849-
p << "nowait";
850-
851-
p << ' ';
852-
p.printRegion(region());
853-
}
854-
855803
LogicalResult SectionsOp::verify() {
856804
if (allocate_vars().size() != allocators_vars().size())
857805
return emitError(
@@ -960,8 +908,11 @@ void WsLoopOp::print(OpAsmPrinter &p) {
960908
if (auto order = order_val())
961909
p << "order(" << stringifyClauseOrderKind(*order) << ") ";
962910

963-
if (!reduction_vars().empty())
964-
printReductionVarList(p, reductions(), reduction_vars());
911+
if (!reduction_vars().empty()) {
912+
printReductionVarList(p << "reduction(", *this, reduction_vars(),
913+
reduction_vars().getTypes(), reductions());
914+
p << ")";
915+
}
965916

966917
p << ' ';
967918
p.printRegion(region(), /*printEntryBlockArgs=*/false);

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -793,7 +793,7 @@ func @omp_sections(%data_var : memref<i32>) -> () {
793793
// -----
794794

795795
func @omp_sections(%cond : i1) {
796-
// expected-error @below {{if is not a valid clause for the omp.sections operation}}
796+
// expected-error @below {{expected '{' to begin a region}}
797797
omp.sections if(%cond) {
798798
omp.terminator
799799
}
@@ -803,7 +803,7 @@ func @omp_sections(%cond : i1) {
803803
// -----
804804

805805
func @omp_sections() {
806-
// expected-error @below {{num_threads is not a valid clause for the omp.sections operation}}
806+
// expected-error @below {{expected '{' to begin a region}}
807807
omp.sections num_threads(10) {
808808
omp.terminator
809809
}
@@ -813,7 +813,7 @@ func @omp_sections() {
813813
// -----
814814

815815
func @omp_sections() {
816-
// expected-error @below {{proc_bind is not a valid clause for the omp.sections operation}}
816+
// expected-error @below {{expected '{' to begin a region}}
817817
omp.sections proc_bind(close) {
818818
omp.terminator
819819
}
@@ -823,7 +823,7 @@ func @omp_sections() {
823823
// -----
824824

825825
func @omp_sections(%data_var : memref<i32>, %linear_var : i32) {
826-
// expected-error @below {{linear is not a valid clause for the omp.sections operation}}
826+
// expected-error @below {{expected '{' to begin a region}}
827827
omp.sections linear(%data_var = %linear_var : memref<i32>) {
828828
omp.terminator
829829
}
@@ -833,7 +833,7 @@ func @omp_sections(%data_var : memref<i32>, %linear_var : i32) {
833833
// -----
834834

835835
func @omp_sections() {
836-
// expected-error @below {{schedule is not a valid clause for the omp.sections operation}}
836+
// expected-error @below {{expected '{' to begin a region}}
837837
omp.sections schedule(static, none) {
838838
omp.terminator
839839
}
@@ -843,7 +843,7 @@ func @omp_sections() {
843843
// -----
844844

845845
func @omp_sections() {
846-
// expected-error @below {{collapse is not a valid clause for the omp.sections operation}}
846+
// expected-error @below {{expected '{' to begin a region}}
847847
omp.sections collapse(3) {
848848
omp.terminator
849849
}
@@ -853,7 +853,7 @@ func @omp_sections() {
853853
// -----
854854

855855
func @omp_sections() {
856-
// expected-error @below {{ordered is not a valid clause for the omp.sections operation}}
856+
// expected-error @below {{expected '{' to begin a region}}
857857
omp.sections ordered(2) {
858858
omp.terminator
859859
}
@@ -863,7 +863,7 @@ func @omp_sections() {
863863
// -----
864864

865865
func @omp_sections() {
866-
// expected-error @below {{order is not a valid clause for the omp.sections operation}}
866+
// expected-error @below {{expected '{' to begin a region}}
867867
omp.sections order(concurrent) {
868868
omp.terminator
869869
}

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -624,13 +624,13 @@ func @omp_sectionsop(%data_var1 : memref<i32>, %data_var2 : memref<i32>,
624624
"omp.sections" (%data_var1, %data_var1) ({
625625
// CHECK: omp.terminator
626626
omp.terminator
627-
}) {operand_segment_sizes = dense<[0,1,1]> : vector<3xi32>} : (memref<i32>, memref<i32>) -> ()
627+
}) {allocate, operand_segment_sizes = dense<[0,1,1]> : vector<3xi32>} : (memref<i32>, memref<i32>) -> ()
628628

629629
// CHECK: omp.sections reduction(@add_f32 -> %{{.*}} : !llvm.ptr<f32>)
630630
"omp.sections" (%redn_var) ({
631631
// CHECK: omp.terminator
632632
omp.terminator
633-
}) {operand_segment_sizes = dense<[1,0,0]> : vector<3xi32>, reductions=[@add_f32]} : (!llvm.ptr<f32>) -> ()
633+
}) {reduction, operand_segment_sizes = dense<[1,0,0]> : vector<3xi32>, reductions=[@add_f32]} : (!llvm.ptr<f32>) -> ()
634634

635635
// CHECK: omp.sections nowait {
636636
omp.sections nowait {

0 commit comments

Comments
 (0)