Skip to content

Commit 032f83b

Browse files
authored
[MLIR][OpenMP] Enable BlockArgOpenMPOpInterface accessing operands (#130769)
This patch makes additions to the `BlockArgOpenMPOpInterface` to simplify its use by letting it handle the matching between operands and their associated entry block arguments. Most significantly, the following is now possible: ```c++ SmallVector<std::pair<Value, BlockArgument>> pairs; cast<BlockArgOpenMPOpInterface>(op).getBlockArgsPairs(pairs); for (auto [var, arg] : pairs) { // var points to the operand (outside value) and arg points to the entry // block argument associated to that value. } ``` This is achieved by making the interface define and use `getXyzVars()` methods, which by default return empty `OperandRange`s and are overriden by getters automatically produced for the `Variadic<...> $xyz_vars` tablegen argument of the corresponding clause. These definitions can then be simplified, since they no longer need to manually define `numXyzBlockArgs` functions as a result. A side-effect of this is that all ops implementing this interface will now publicly define `getXyzVars()` functions for all entry block argument-generating clauses, even if they don't actually accept all clauses. However, these would just return empty ranges, so it shouldn't cause issues. This change uncovered some incorrect definitions of class declarations related to the `ReductionClauseInterface`, and the `OpenMP_DetachClause` incorrectly implementing the `BlockArgOpenMPOpInterface`, so these issues are also addressed.
1 parent c851ee3 commit 032f83b

File tree

5 files changed

+122
-67
lines changed

5 files changed

+122
-67
lines changed

mlir/docs/Dialects/OpenMPDialect/_index.md

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -352,12 +352,29 @@ let assemblyFormat = clausesAssemblyFormat # [{
352352
```
353353

354354
The `BlockArgOpenMPOpInterface` has been introduced to simplify the addition and
355-
handling of these kinds of clauses. It holds `num<ClauseName>BlockArgs()`
356-
functions that by default return 0, to be overriden by each clause through the
357-
`extraClassDeclaration` property. Based on these functions and the expected
358-
alphabetical sorting between entry block argument-defining clauses, it
359-
implements `get<ClauseName>BlockArgs()` functions that are the intended method
360-
of accessing clause-defined block arguments.
355+
handling of these kinds of clauses. Adding it to an operation directly, or
356+
indirectly through a clause, results in the addition of overridable
357+
`get<ClauseName>Vars()` and `num<ClauseName>BlockArgs()` public functions for
358+
all entry block argument-generating clauses. By default, the reported number of
359+
block arguments defined by a clause will correspond to the number of operands
360+
taken by the operation for that clause. This list of operands will be empty by
361+
default, and will automatically be overriden by getters of the corresponding
362+
`Variadic<...> $<clause_name>_vars` argument of the same clause's definition.
363+
364+
In addition to these methods added to the actual operations, the
365+
`BlockArgOpenMPOpInterface` itself defines a set of methods based on the
366+
previous ones and on the convention that entry block arguments for multiple
367+
clauses are sorted alphabetically by clause name. These are listed below, and
368+
they represent the main way in which clause-defined block arguments should be
369+
accessed:
370+
- `get<ClauseName>BlockArgsStart()`: Returns the index within the list of
371+
entry block arguments where the first element defined by the given clause
372+
should be located.
373+
- `get<ClauseName>BlockArgs()`: Returns the list of entry block arguments
374+
defined by the given clause.
375+
- `getBlockArgsPairs()`: Returns a list of pairs where the first element is
376+
the outside value, or operand, and the second element is the corresponding
377+
entry block argument.
361378

362379
## Loop-Associated Directives
363380

mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td

Lines changed: 5 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -470,12 +470,6 @@ class OpenMP_HasDeviceAddrClauseSkip<
470470
Variadic<OpenMP_PointerLikeType>:$has_device_addr_vars
471471
);
472472

473-
let extraClassDeclaration = [{
474-
unsigned numHasDeviceAddrBlockArgs() {
475-
return getHasDeviceAddrVars().size();
476-
}
477-
}];
478-
479473
let description = [{
480474
The optional `has_device_addr_vars` indicates that list items already have
481475
device addresses, so they may be directly accessed from the target device.
@@ -565,12 +559,6 @@ class OpenMP_HostEvalClauseSkip<
565559
Variadic<AnyType>:$host_eval_vars
566560
);
567561

568-
let extraClassDeclaration = [{
569-
unsigned numHostEvalBlockArgs() {
570-
return getHostEvalVars().size();
571-
}
572-
}];
573-
574562
let description = [{
575563
The optional `host_eval_vars` holds values defined outside of the region of
576564
the `IsolatedFromAbove` operation for which a corresponding entry block
@@ -629,12 +617,10 @@ class OpenMP_InReductionClauseSkip<
629617

630618
let extraClassDeclaration = [{
631619
/// Returns the reduction variables.
632-
SmallVector<Value> getReductionVars() {
620+
SmallVector<Value> getAllReductionVars() {
633621
return SmallVector<Value>(getInReductionVars().begin(),
634622
getInReductionVars().end());
635623
}
636-
637-
unsigned numInReductionBlockArgs() { return getInReductionVars().size(); }
638624
}];
639625

640626
// Description varies depending on the operation. Assembly format not defined
@@ -749,6 +735,9 @@ class OpenMP_MapClauseSkip<
749735
Variadic<OpenMP_PointerLikeType>:$map_vars
750736
);
751737

738+
// This assembly format should only be used by operations where `map` does not
739+
// define entry block arguments. Otherwise, it must be printed and parsed
740+
// together with the corresponding region.
752741
let optAssemblyFormat = [{
753742
`map_entries` `(` $map_vars `:` type($map_vars) `)`
754743
}];
@@ -1060,8 +1049,6 @@ class OpenMP_DetachClauseSkip<
10601049
: OpenMP_Clause<traits, arguments, assemblyFormat, description,
10611050
extraClassDeclaration> {
10621051

1063-
let traits = [BlockArgOpenMPOpInterface];
1064-
10651052
let arguments = (ins Optional<OpenMP_PointerLikeType>:$event_handle);
10661053

10671054
let optAssemblyFormat = [{
@@ -1126,10 +1113,6 @@ class OpenMP_PrivateClauseSkip<
11261113
OptionalAttr<SymbolRefArrayAttr>:$private_syms
11271114
);
11281115

1129-
let extraClassDeclaration = [{
1130-
unsigned numPrivateBlockArgs() { return getPrivateVars().size(); }
1131-
}];
1132-
11331116
// TODO: Add description.
11341117
// Assembly format not defined because this clause must be processed together
11351118
// with the first region of the operation, as it defines entry block
@@ -1186,7 +1169,6 @@ class OpenMP_ReductionClauseSkip<
11861169
let extraClassDeclaration = [{
11871170
/// Returns the number of reduction variables.
11881171
unsigned getNumReductionVars() { return getReductionVars().size(); }
1189-
unsigned numReductionBlockArgs() { return getReductionVars().size(); }
11901172
}];
11911173

11921174
// Description varies depending on the operation.
@@ -1316,14 +1298,10 @@ class OpenMP_TaskReductionClauseSkip<
13161298

13171299
let extraClassDeclaration = [{
13181300
/// Returns the reduction variables.
1319-
SmallVector<Value> getReductionVars() {
1301+
SmallVector<Value> getAllReductionVars() {
13201302
return SmallVector<Value>(getTaskReductionVars().begin(),
13211303
getTaskReductionVars().end());
13221304
}
1323-
1324-
unsigned numTaskReductionBlockArgs() {
1325-
return getTaskReductionVars().size();
1326-
}
13271305
}];
13281306

13291307
let description = [{
@@ -1413,12 +1391,6 @@ class OpenMP_UseDeviceAddrClauseSkip<
14131391
Variadic<OpenMP_PointerLikeType>:$use_device_addr_vars
14141392
);
14151393

1416-
let extraClassDeclaration = [{
1417-
unsigned numUseDeviceAddrBlockArgs() {
1418-
return getUseDeviceAddrVars().size();
1419-
}
1420-
}];
1421-
14221394
let description = [{
14231395
The optional `use_device_addr_vars` specifies the address of the objects in
14241396
the device data environment.
@@ -1448,12 +1420,6 @@ class OpenMP_UseDevicePtrClauseSkip<
14481420
Variadic<OpenMP_PointerLikeType>:$use_device_ptr_vars
14491421
);
14501422

1451-
let extraClassDeclaration = [{
1452-
unsigned numUseDevicePtrBlockArgs() {
1453-
return getUseDevicePtrVars().size();
1454-
}
1455-
}];
1456-
14571423
let description = [{
14581424
The optional `use_device_ptr_vars` specifies the device pointers to the
14591425
corresponding list items in the device data environment.

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -284,8 +284,8 @@ def SectionOp : OpenMP_Op<"section", traits = [
284284
// Override BlockArgOpenMPOpInterface methods based on the parent
285285
// omp.sections operation. Only forward-declare here because SectionsOp is
286286
// not completely defined at this point.
287-
unsigned numPrivateBlockArgs();
288-
unsigned numReductionBlockArgs();
287+
OperandRange getPrivateVars();
288+
OperandRange getReductionVars();
289289
}] # clausesExtraClassDeclaration;
290290
let assemblyFormat = "$region attr-dict";
291291
}
@@ -824,11 +824,6 @@ def TaskloopOp : OpenMP_Op<"taskloop", traits = [
824824
/// Returns the reduction variables
825825
SmallVector<Value> getAllReductionVars();
826826

827-
// Define BlockArgOpenMPOpInterface methods here because they are not
828-
// inherited from the respective clauses.
829-
unsigned numInReductionBlockArgs() { return getInReductionVars().size(); }
830-
unsigned numReductionBlockArgs() { return getReductionVars().size(); }
831-
832827
void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
833828
}] # clausesExtraClassDeclaration;
834829

@@ -1151,6 +1146,14 @@ def TargetDataOp: OpenMP_Op<"target_data", traits = [
11511146
OpBuilder<(ins CArg<"const TargetDataOperands &">:$clauses)>
11521147
];
11531148

1149+
let extraClassDeclaration = [{
1150+
// Override BlockArgOpenMPOpInterface method because `map` clauses have no
1151+
// associated entry block arguments in this operation.
1152+
unsigned numMapBlockArgs() {
1153+
return 0;
1154+
}
1155+
}] # clausesExtraClassDeclaration;
1156+
11541157
let assemblyFormat = clausesAssemblyFormat # [{
11551158
custom<UseDeviceAddrUseDevicePtrRegion>(
11561159
$region, $use_device_addr_vars, type($use_device_addr_vars),
@@ -1185,6 +1188,14 @@ def TargetEnterDataOp: OpenMP_Op<"target_enter_data", traits = [
11851188
OpBuilder<(ins CArg<"const TargetEnterExitUpdateDataOperands &">:$clauses)>
11861189
];
11871190

1191+
let extraClassDeclaration = [{
1192+
// Override BlockArgOpenMPOpInterface method because `map` clauses have no
1193+
// associated entry block arguments in this operation.
1194+
unsigned numMapBlockArgs() {
1195+
return 0;
1196+
}
1197+
}] # clausesExtraClassDeclaration;
1198+
11881199
let hasVerifier = 1;
11891200
}
11901201

@@ -1213,6 +1224,14 @@ def TargetExitDataOp: OpenMP_Op<"target_exit_data", traits = [
12131224
OpBuilder<(ins CArg<"const TargetEnterExitUpdateDataOperands &">:$clauses)>
12141225
];
12151226

1227+
let extraClassDeclaration = [{
1228+
// Override BlockArgOpenMPOpInterface method because `map` clauses have no
1229+
// associated entry block arguments in this operation.
1230+
unsigned numMapBlockArgs() {
1231+
return 0;
1232+
}
1233+
}] # clausesExtraClassDeclaration;
1234+
12161235
let hasVerifier = 1;
12171236
}
12181237

@@ -1249,6 +1268,14 @@ def TargetUpdateOp: OpenMP_Op<"target_update", traits = [
12491268
OpBuilder<(ins CArg<"const TargetEnterExitUpdateDataOperands &">:$clauses)>
12501269
];
12511270

1271+
let extraClassDeclaration = [{
1272+
// Override BlockArgOpenMPOpInterface method because `map` clauses have no
1273+
// associated entry block arguments in this operation.
1274+
unsigned numMapBlockArgs() {
1275+
return 0;
1276+
}
1277+
}] # clausesExtraClassDeclaration;
1278+
12521279
let hasVerifier = 1;
12531280
}
12541281

@@ -1292,8 +1319,6 @@ def TargetOp : OpenMP_Op<"target", traits = [
12921319
];
12931320

12941321
let extraClassDeclaration = [{
1295-
unsigned numMapBlockArgs() { return getMapVars().size(); }
1296-
12971322
mlir::Value getMappedValueForPrivateVar(unsigned privVarIdx) {
12981323
std::optional<DenseI64ArrayAttr> privateMapIdices = getPrivateMapsAttr();
12991324

@@ -1818,6 +1843,14 @@ def DeclareMapperInfoOp : OpenMP_Op<"declare_mapper.info", [
18181843
OpBuilder<(ins CArg<"const DeclareMapperInfoOperands &">:$clauses)>
18191844
];
18201845

1846+
let extraClassDeclaration = [{
1847+
// Override BlockArgOpenMPOpInterface method because `map` clauses have no
1848+
// associated entry block arguments in this operation.
1849+
unsigned numMapBlockArgs() {
1850+
return 0;
1851+
}
1852+
}] # clausesExtraClassDeclaration;
1853+
18211854
let hasVerifier = 1;
18221855
}
18231856

mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,41 @@ include "mlir/IR/OpBase.td"
2323
// arguments corresponding to each of these clauses.
2424
class BlockArgOpenMPClause<string clauseNameSnake, string clauseNameCamel,
2525
BlockArgOpenMPClause previousClause> {
26-
// Default-implemented method to be overriden by the corresponding clause.
26+
// Default-implemented method, overriden by the corresponding clause. It
27+
// returns the range of operands passed to the operation associated to the
28+
// clause.
29+
//
30+
// For the override to work, the clause tablegen definition must contain a
31+
// `Variadic<...> $clause_name_vars` argument.
2732
//
2833
// Usage example:
2934
//
3035
// ```c++
31-
// auto iface = cast<BlockArgOpenMPOpInterface>(op);
32-
// unsigned numInReductionArgs = iface.numInReductionBlockArgs();
36+
// OperandRange reductionVars = op.getReductionVars();
37+
// ```
38+
InterfaceMethod varsMethod = InterfaceMethod<
39+
"Get operation operands associated to `" # clauseNameSnake # "`.",
40+
"::mlir::OperandRange", "get" # clauseNameCamel # "Vars", (ins), [{}], [{
41+
return {0, 0};
42+
}]
43+
>;
44+
45+
// It returns the number of entry block arguments introduced by the given
46+
// clause.
47+
//
48+
// By default, it will be the number of operands corresponding to that clause,
49+
// but it can be overriden by operations where this might not be the case
50+
// (e.g. `map` clause in `omp.target_update`).
51+
//
52+
// Usage example:
53+
//
54+
// ```c++
55+
// unsigned numInReductionArgs = op.numInReductionBlockArgs();
3356
// ```
3457
InterfaceMethod numArgsMethod = InterfaceMethod<
3558
"Get number of block arguments defined by `" # clauseNameSnake # "`.",
36-
"unsigned", "num" # clauseNameCamel # "BlockArgs", (ins), [{}], [{
37-
return 0;
38-
}]
59+
"unsigned", "num" # clauseNameCamel # "BlockArgs", (ins), [{}],
60+
"return $_op." # varsMethod.name # "().size();"
3961
>;
4062

4163
// Unified access method for the start index of clause-associated entry block
@@ -52,7 +74,7 @@ class BlockArgOpenMPClause<string clauseNameSnake, string clauseNameCamel,
5274
"unsigned", "get" # clauseNameCamel # "BlockArgsStart", (ins),
5375
!if(!initialized(previousClause), [{
5476
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
55-
}] # "return iface." # previousClause.startMethod.name # "() + $_op."
77+
}] # "return iface." # previousClause.startMethod.name # "() + iface."
5678
# previousClause.numArgsMethod.name # "();",
5779
"return 0;"
5880
)
@@ -72,7 +94,7 @@ class BlockArgOpenMPClause<string clauseNameSnake, string clauseNameCamel,
7294
"get" # clauseNameCamel # "BlockArgs", (ins), [{
7395
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
7496
return $_op->getRegion(0).getArguments().slice(
75-
}] # "iface." # startMethod.name # "(), $_op." # numArgsMethod.name # "());"
97+
}] # "iface." # startMethod.name # "(), iface." # numArgsMethod.name # "());"
7698
>;
7799
}
78100

@@ -109,9 +131,26 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
109131
BlockArgUseDeviceAddrClause, BlockArgUseDevicePtrClause ];
110132

111133
let methods = !listconcat(
134+
!foreach(clause, clauses, clause.varsMethod),
112135
!foreach(clause, clauses, clause.numArgsMethod),
113136
!foreach(clause, clauses, clause.startMethod),
114-
!foreach(clause, clauses, clause.blockArgsMethod)
137+
!foreach(clause, clauses, clause.blockArgsMethod),
138+
[
139+
InterfaceMethod<
140+
"Populate a vector of pairs representing the matching between operands "
141+
"and entry block arguments.", "void", "getBlockArgsPairs",
142+
(ins "::llvm::SmallVectorImpl<std::pair<::mlir::Value, ::mlir::BlockArgument>> &" : $pairs),
143+
[{
144+
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
145+
}] # !interleave(!foreach(clause, clauses, [{
146+
}] # "if (iface." # clause.numArgsMethod.name # "() > 0) {" # [{
147+
}] # " for (auto [var, arg] : ::llvm::zip_equal(" #
148+
"iface." # clause.varsMethod.name # "()," #
149+
"iface." # clause.blockArgsMethod.name # "()))" # [{
150+
pairs.emplace_back(var, arg);
151+
} }]), "\n")
152+
>
153+
]
115154
);
116155

117156
let verify = [{

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2248,12 +2248,12 @@ LogicalResult TeamsOp::verify() {
22482248
// SectionOp
22492249
//===----------------------------------------------------------------------===//
22502250

2251-
unsigned SectionOp::numPrivateBlockArgs() {
2252-
return getParentOp().numPrivateBlockArgs();
2251+
OperandRange SectionOp::getPrivateVars() {
2252+
return getParentOp().getPrivateVars();
22532253
}
22542254

2255-
unsigned SectionOp::numReductionBlockArgs() {
2256-
return getParentOp().numReductionBlockArgs();
2255+
OperandRange SectionOp::getReductionVars() {
2256+
return getParentOp().getReductionVars();
22572257
}
22582258

22592259
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)