Skip to content

Commit db4c94f

Browse files
authored
[OpenACC][CIR] Implement beginning of 'copy' lowering for compute constructs (llvm#140304)
This is a partial implementation of the 'copy' lowering. It is missing three things, which are coming in future patches: (1) it does not handle subscripts/subarrays for emission as variables; (2) it does not handle member expressions for emission as variables; (3) it does not handle the modifier-list. Items 1 and 2 are deferred because of their complexity and should be split off into a separate patch. Item 3 is deferred because it isn't clear how the IR is going to handle this, and I'd like to make sure it gets done 'all at once' when the IR is updated to handle these, so I'm pushing that off to the future. This DOES, however, handle the complexity of having an acc.copyin and acc.copyout, plus the additional complexity of the 'async' clause.
1 parent f3f63ce commit db4c94f

File tree

4 files changed

+437
-33
lines changed

4 files changed

+437
-33
lines changed

clang/lib/CIR/CodeGen/CIRGenOpenACCClause.h

Lines changed: 220 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "mlir/Dialect/Arith/IR/Arith.h"
1616
#include "mlir/Dialect/OpenACC/OpenACC.h"
17+
#include "llvm/ADT/TypeSwitch.h"
1718
namespace clang {
1819
// Simple type-trait to see if the first template arg is one of the list, so we
1920
// can tell whether to `if-constexpr` a bunch of stuff.
@@ -36,6 +37,72 @@ template <typename ToTest> constexpr bool isCombinedType = false;
3637
template <typename T>
3738
constexpr bool isCombinedType<CombinedConstructClauseInfo<T>> = true;
3839

40+
namespace {
41+
struct DataOperandInfo {
42+
mlir::Location beginLoc;
43+
mlir::Value varValue;
44+
llvm::StringRef name;
45+
};
46+
47+
inline mlir::Value emitOpenACCIntExpr(CIRGen::CIRGenFunction &cgf,
48+
CIRGen::CIRGenBuilderTy &builder,
49+
const Expr *intExpr) {
50+
mlir::Value expr = cgf.emitScalarExpr(intExpr);
51+
mlir::Location exprLoc = cgf.cgm.getLoc(intExpr->getBeginLoc());
52+
53+
mlir::IntegerType targetType = mlir::IntegerType::get(
54+
&cgf.getMLIRContext(), cgf.getContext().getIntWidth(intExpr->getType()),
55+
intExpr->getType()->isSignedIntegerOrEnumerationType()
56+
? mlir::IntegerType::SignednessSemantics::Signed
57+
: mlir::IntegerType::SignednessSemantics::Unsigned);
58+
59+
auto conversionOp = builder.create<mlir::UnrealizedConversionCastOp>(
60+
exprLoc, targetType, expr);
61+
return conversionOp.getResult(0);
62+
}
63+
64+
// A helper function that gets the information from an operand to a data
65+
// clause, so that it can be used to emit the data operations.
66+
inline DataOperandInfo getDataOperandInfo(CIRGen::CIRGenFunction &cgf,
67+
CIRGen::CIRGenBuilderTy &builder,
68+
OpenACCDirectiveKind dk,
69+
const Expr *e) {
70+
// TODO: OpenACC: Cache was different enough as to need a separate
71+
// `ActOnCacheVar`, so we are going to need to do some investigations here
72+
// when it comes to implement this for cache.
73+
if (dk == OpenACCDirectiveKind::Cache) {
74+
cgf.cgm.errorNYI(e->getSourceRange(),
75+
"OpenACC data operand for 'cache' directive");
76+
return {cgf.cgm.getLoc(e->getBeginLoc()), {}, {}};
77+
}
78+
79+
const Expr *curVarExpr = e->IgnoreParenImpCasts();
80+
81+
mlir::Location exprLoc = cgf.cgm.getLoc(curVarExpr->getBeginLoc());
82+
83+
// TODO: OpenACC: Assemble the list of bounds.
84+
if (isa<ArraySectionExpr, ArraySubscriptExpr>(curVarExpr)) {
85+
cgf.cgm.errorNYI(curVarExpr->getSourceRange(),
86+
"OpenACC data clause array subscript/section");
87+
return {exprLoc, {}, {}};
88+
}
89+
90+
// TODO: OpenACC: if this is a member expr, emit the VarPtrPtr correctly.
91+
if (isa<MemberExpr>(curVarExpr)) {
92+
cgf.cgm.errorNYI(curVarExpr->getSourceRange(),
93+
"OpenACC Data clause member expr");
94+
return {exprLoc, {}, {}};
95+
}
96+
97+
// Sema has made sure that only 4 types of things can get here, array
98+
// subscript, array section, member expr, or DRE to a var decl (or the former
99+
// 3 wrapping a var-decl), so we should be able to assume this is right.
100+
const auto *dre = cast<DeclRefExpr>(curVarExpr);
101+
const auto *vd = cast<VarDecl>(dre->getFoundDecl()->getCanonicalDecl());
102+
return {exprLoc, cgf.emitDeclRefLValue(dre).getPointer(), vd->getName()};
103+
}
104+
} // namespace
105+
39106
template <typename OpTy>
40107
class OpenACCClauseCIREmitter final
41108
: public OpenACCClauseVisitor<OpenACCClauseCIREmitter<OpTy>> {
@@ -54,6 +121,11 @@ class OpenACCClauseCIREmitter final
54121
SourceLocation dirLoc;
55122

56123
llvm::SmallVector<mlir::acc::DeviceType> lastDeviceTypeValues;
124+
// Keep track of the async-clause so that we can shortcut updating the data
125+
// operands async clauses.
126+
bool hasAsyncClause = false;
127+
// Keep track of the data operands so that we can update their async clauses.
128+
llvm::SmallVector<mlir::Operation *> dataOperands;
57129

58130
void setLastDeviceTypeClause(const OpenACCDeviceTypeClause &clause) {
59131
lastDeviceTypeValues.clear();
@@ -69,19 +141,8 @@ class OpenACCClauseCIREmitter final
69141
cgf.cgm.errorNYI(c.getSourceRange(), "OpenACC Clause", c.getClauseKind());
70142
}
71143

72-
mlir::Value createIntExpr(const Expr *intExpr) {
73-
mlir::Value expr = cgf.emitScalarExpr(intExpr);
74-
mlir::Location exprLoc = cgf.cgm.getLoc(intExpr->getBeginLoc());
75-
76-
mlir::IntegerType targetType = mlir::IntegerType::get(
77-
&cgf.getMLIRContext(), cgf.getContext().getIntWidth(intExpr->getType()),
78-
intExpr->getType()->isSignedIntegerOrEnumerationType()
79-
? mlir::IntegerType::SignednessSemantics::Signed
80-
: mlir::IntegerType::SignednessSemantics::Unsigned);
81-
82-
auto conversionOp = builder.create<mlir::UnrealizedConversionCastOp>(
83-
exprLoc, targetType, expr);
84-
return conversionOp.getResult(0);
144+
// Convenience overload that forwards to the namespace-scope helper of the same
// name, supplying this emitter's function state and builder.
mlir::Value emitOpenACCIntExpr(const Expr *intExpr) {
  mlir::Value converted = clang::emitOpenACCIntExpr(cgf, builder, intExpr);
  return converted;
}
86147

87148
// 'condition' as an OpenACC grammar production is used for 'if' and (some
@@ -157,6 +218,104 @@ class OpenACCClauseCIREmitter final
157218
computeEmitter.Visit(&c);
158219
}
159220

221+
template <typename BeforeOpTy, typename AfterOpTy>
222+
void addDataOperand(const Expr *varOperand, mlir::acc::DataClause dataClause,
223+
bool structured, bool implicit) {
224+
DataOperandInfo opInfo =
225+
getDataOperandInfo(cgf, builder, dirKind, varOperand);
226+
mlir::ValueRange bounds;
227+
228+
// TODO: OpenACC: we should comprehend the 'modifier-list' here for the data
229+
// operand. At the moment, we don't have a uniform way to assign these
230+
// properly, and the dialect cannot represent anything other than 'readonly'
231+
// and 'zero' on copyin/copyout/create, so for now, we skip it.
232+
233+
auto beforeOp =
234+
builder.create<BeforeOpTy>(opInfo.beginLoc, opInfo.varValue, structured,
235+
implicit, opInfo.name, bounds);
236+
operation.getDataClauseOperandsMutable().append(beforeOp.getResult());
237+
238+
AfterOpTy afterOp;
239+
{
240+
mlir::OpBuilder::InsertionGuard guardCase(builder);
241+
builder.setInsertionPointAfter(operation);
242+
afterOp = builder.create<AfterOpTy>(opInfo.beginLoc, beforeOp.getResult(),
243+
opInfo.varValue, structured, implicit,
244+
opInfo.name, bounds);
245+
}
246+
247+
// Set the 'rest' of the info for both operations.
248+
beforeOp.setDataClause(dataClause);
249+
afterOp.setDataClause(dataClause);
250+
251+
// Make sure we record these, so 'async' values can be updated later.
252+
dataOperands.push_back(beforeOp.getOperation());
253+
dataOperands.push_back(afterOp.getOperation());
254+
}
255+
256+
// Helper function that covers for the fact that we don't have this function
257+
// on all operation types.
258+
mlir::ArrayAttr getAsyncOnlyAttr() {
259+
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp,
260+
mlir::acc::KernelsOp, mlir::acc::DataOp>)
261+
return operation.getAsyncOnlyAttr();
262+
263+
// Note: 'wait' has async as well, but it cannot have data clauses, so we
264+
// don't have to handle them here.
265+
266+
llvm_unreachable("getting asyncOnly when clause not valid on operation?");
267+
}
268+
269+
// Helper function that covers for the fact that we don't have this function
270+
// on all operation types.
271+
mlir::ArrayAttr getAsyncOperandsDeviceTypeAttr() {
272+
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp,
273+
mlir::acc::KernelsOp, mlir::acc::DataOp>)
274+
return operation.getAsyncOperandsDeviceTypeAttr();
275+
276+
// Note: 'wait' has async as well, but it cannot have data clauses, so we
277+
// don't have to handle them here.
278+
279+
llvm_unreachable(
280+
"getting asyncOperandsDeviceType when clause not valid on operation?");
281+
}
282+
283+
// Helper function that covers for the fact that we don't have this function
284+
// on all operation types.
285+
mlir::OperandRange getAsyncOperands() {
286+
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp,
287+
mlir::acc::KernelsOp, mlir::acc::DataOp>)
288+
return operation.getAsyncOperands();
289+
290+
// Note: 'wait' has async as well, but it cannot have data clauses, so we
291+
// don't have to handle them here.
292+
293+
llvm_unreachable(
294+
"getting asyncOperandsDeviceType when clause not valid on operation?");
295+
}
296+
297+
// The 'data' clauses all require that we add the 'async' values from the
298+
// operation to them. We've collected the data operands along the way, so use
299+
// that list to get the current 'async' values.
300+
void updateDataOperandAsyncValues() {
301+
if (!hasAsyncClause || dataOperands.empty())
302+
return;
303+
304+
// TODO: OpenACC: Handle this correctly for combined constructs.
305+
306+
for (mlir::Operation *dataOp : dataOperands) {
307+
llvm::TypeSwitch<mlir::Operation *, void>(dataOp)
308+
.Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto op) {
309+
op.setAsyncOnlyAttr(getAsyncOnlyAttr());
310+
op.setAsyncOperandsDeviceTypeAttr(getAsyncOperandsDeviceTypeAttr());
311+
op.getAsyncOperandsMutable().assign(getAsyncOperands());
312+
})
313+
.Default([&](mlir::Operation *) {
314+
llvm_unreachable("Not a data operation?");
315+
});
316+
}
317+
}
318+
160319
public:
161320
OpenACCClauseCIREmitter(OpTy &operation, CIRGen::CIRGenFunction &cgf,
162321
CIRGen::CIRGenBuilderTy &builder,
@@ -168,6 +327,14 @@ class OpenACCClauseCIREmitter final
168327
clauseNotImplemented(clause);
169328
}
170329

330+
// The entry point for the CIR emitter. All users should use this rather than
331+
// 'visitClauseList', as this also handles the things that have to happen
332+
// 'after' the clauses are all visited.
333+
void emitClauses(ArrayRef<const OpenACCClause *> clauses) {
334+
this->VisitClauseList(clauses);
335+
updateDataOperandAsyncValues();
336+
}
337+
171338
void VisitDefaultClause(const OpenACCDefaultClause &clause) {
172339
// This type-trait checks if 'op'(the first arg) is one of the mlir::acc
173340
// operations listed in the rest of the arguments.
@@ -227,7 +394,7 @@ class OpenACCClauseCIREmitter final
227394
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp,
228395
mlir::acc::KernelsOp>) {
229396
operation.addNumWorkersOperand(builder.getContext(),
230-
createIntExpr(clause.getIntExpr()),
397+
emitOpenACCIntExpr(clause.getIntExpr()),
231398
lastDeviceTypeValues);
232399
} else if constexpr (isCombinedType<OpTy>) {
233400
applyToComputeOp(clause);
@@ -240,7 +407,7 @@ class OpenACCClauseCIREmitter final
240407
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp,
241408
mlir::acc::KernelsOp>) {
242409
operation.addVectorLengthOperand(builder.getContext(),
243-
createIntExpr(clause.getIntExpr()),
410+
emitOpenACCIntExpr(clause.getIntExpr()),
244411
lastDeviceTypeValues);
245412
} else if constexpr (isCombinedType<OpTy>) {
246413
applyToComputeOp(clause);
@@ -250,22 +417,34 @@ class OpenACCClauseCIREmitter final
250417
}
251418

252419
void VisitAsyncClause(const OpenACCAsyncClause &clause) {
420+
hasAsyncClause = true;
253421
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp,
254422
mlir::acc::KernelsOp, mlir::acc::DataOp>) {
255423
if (!clause.hasIntExpr())
256424
operation.addAsyncOnly(builder.getContext(), lastDeviceTypeValues);
257-
else
258-
operation.addAsyncOperand(builder.getContext(),
259-
createIntExpr(clause.getIntExpr()),
425+
else {
426+
427+
mlir::Value intExpr;
428+
{
429+
// Async int exprs can be referenced by the data operands, which means
430+
// that the int-exprs have to appear before them. IF there is a data
431+
// operand already, set the insertion point to 'before' it.
432+
mlir::OpBuilder::InsertionGuard guardCase(builder);
433+
if (!dataOperands.empty())
434+
builder.setInsertionPoint(dataOperands.front());
435+
intExpr = emitOpenACCIntExpr(clause.getIntExpr());
436+
}
437+
operation.addAsyncOperand(builder.getContext(), intExpr,
260438
lastDeviceTypeValues);
439+
}
261440
} else if constexpr (isOneOfTypes<OpTy, mlir::acc::WaitOp>) {
262441
// Wait doesn't have a device_type, so its handling here is slightly
263442
// different.
264443
if (!clause.hasIntExpr())
265444
operation.setAsync(true);
266445
else
267446
operation.getAsyncOperandMutable().append(
268-
createIntExpr(clause.getIntExpr()));
447+
emitOpenACCIntExpr(clause.getIntExpr()));
269448
} else if constexpr (isCombinedType<OpTy>) {
270449
applyToComputeOp(clause);
271450
} else {
@@ -321,7 +500,7 @@ class OpenACCClauseCIREmitter final
321500
if constexpr (isOneOfTypes<OpTy, mlir::acc::InitOp, mlir::acc::ShutdownOp,
322501
mlir::acc::SetOp>) {
323502
operation.getDeviceNumMutable().append(
324-
createIntExpr(clause.getIntExpr()));
503+
emitOpenACCIntExpr(clause.getIntExpr()));
325504
} else {
326505
llvm_unreachable(
327506
"init, shutdown, set, are only valid device_num constructs");
@@ -333,7 +512,7 @@ class OpenACCClauseCIREmitter final
333512
mlir::acc::KernelsOp>) {
334513
llvm::SmallVector<mlir::Value> values;
335514
for (const Expr *E : clause.getIntExprs())
336-
values.push_back(createIntExpr(E));
515+
values.push_back(emitOpenACCIntExpr(E));
337516

338517
operation.addNumGangsOperands(builder.getContext(), values,
339518
lastDeviceTypeValues);
@@ -352,9 +531,9 @@ class OpenACCClauseCIREmitter final
352531
} else {
353532
llvm::SmallVector<mlir::Value> values;
354533
if (clause.hasDevNumExpr())
355-
values.push_back(createIntExpr(clause.getDevNumExpr()));
534+
values.push_back(emitOpenACCIntExpr(clause.getDevNumExpr()));
356535
for (const Expr *E : clause.getQueueIdExprs())
357-
values.push_back(createIntExpr(E));
536+
values.push_back(emitOpenACCIntExpr(E));
358537
operation.addWaitOperands(builder.getContext(), clause.hasDevNumExpr(),
359538
values, lastDeviceTypeValues);
360539
}
@@ -370,7 +549,7 @@ class OpenACCClauseCIREmitter final
370549
void VisitDefaultAsyncClause(const OpenACCDefaultAsyncClause &clause) {
371550
if constexpr (isOneOfTypes<OpTy, mlir::acc::SetOp>) {
372551
operation.getDefaultAsyncMutable().append(
373-
createIntExpr(clause.getIntExpr()));
552+
emitOpenACCIntExpr(clause.getIntExpr()));
374553
} else {
375554
llvm_unreachable("set, is only valid device_num constructs");
376555
}
@@ -460,7 +639,7 @@ class OpenACCClauseCIREmitter final
460639
if constexpr (isOneOfTypes<OpTy, mlir::acc::LoopOp>) {
461640
if (clause.hasIntExpr())
462641
operation.addWorkerNumOperand(builder.getContext(),
463-
createIntExpr(clause.getIntExpr()),
642+
emitOpenACCIntExpr(clause.getIntExpr()),
464643
lastDeviceTypeValues);
465644
else
466645
operation.addEmptyWorker(builder.getContext(), lastDeviceTypeValues);
@@ -478,7 +657,7 @@ class OpenACCClauseCIREmitter final
478657
if constexpr (isOneOfTypes<OpTy, mlir::acc::LoopOp>) {
479658
if (clause.hasIntExpr())
480659
operation.addVectorOperand(builder.getContext(),
481-
createIntExpr(clause.getIntExpr()),
660+
emitOpenACCIntExpr(clause.getIntExpr()),
482661
lastDeviceTypeValues);
483662
else
484663
operation.addEmptyVector(builder.getContext(), lastDeviceTypeValues);
@@ -514,7 +693,7 @@ class OpenACCClauseCIREmitter final
514693
} else if (isa<OpenACCAsteriskSizeExpr>(expr)) {
515694
values.push_back(createConstantInt(exprLoc, 64, -1));
516695
} else {
517-
values.push_back(createIntExpr(expr));
696+
values.push_back(emitOpenACCIntExpr(expr));
518697
}
519698
}
520699

@@ -527,6 +706,20 @@ class OpenACCClauseCIREmitter final
527706
llvm_unreachable("Unknown construct kind in VisitGangClause");
528707
}
529708
}
709+
710+
void VisitCopyClause(const OpenACCCopyClause &clause) {
711+
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp,
712+
mlir::acc::KernelsOp>) {
713+
for (auto var : clause.getVarList())
714+
addDataOperand<mlir::acc::CopyinOp, mlir::acc::CopyoutOp>(
715+
var, mlir::acc::DataClause::acc_copy, /*structured=*/true,
716+
/*implicit=*/false);
717+
} else {
718+
// TODO: When we've implemented this for everything, switch this to an
719+
// unreachable. data, declare, combined constructs remain.
720+
return clauseNotImplemented(clause);
721+
}
722+
}
530723
};
531724

532725
template <typename OpTy>

0 commit comments

Comments
 (0)