14
14
15
15
#include " mlir/Dialect/Arith/IR/Arith.h"
16
16
#include " mlir/Dialect/OpenACC/OpenACC.h"
17
+ #include " llvm/ADT/TypeSwitch.h"
17
18
namespace clang {
18
19
// Simple type-trait to see if the first template arg is one of the list, so we
19
20
// can tell whether to `if-constexpr` a bunch of stuff.
@@ -36,6 +37,72 @@ template <typename ToTest> constexpr bool isCombinedType = false;
36
37
template <typename T>
37
38
constexpr bool isCombinedType<CombinedConstructClauseInfo<T>> = true ;
38
39
40
+ namespace {
41
+ struct DataOperandInfo {
42
+ mlir::Location beginLoc;
43
+ mlir::Value varValue;
44
+ llvm::StringRef name;
45
+ };
46
+
47
+ inline mlir::Value emitOpenACCIntExpr (CIRGen::CIRGenFunction &cgf,
48
+ CIRGen::CIRGenBuilderTy &builder,
49
+ const Expr *intExpr) {
50
+ mlir::Value expr = cgf.emitScalarExpr (intExpr);
51
+ mlir::Location exprLoc = cgf.cgm .getLoc (intExpr->getBeginLoc ());
52
+
53
+ mlir::IntegerType targetType = mlir::IntegerType::get (
54
+ &cgf.getMLIRContext (), cgf.getContext ().getIntWidth (intExpr->getType ()),
55
+ intExpr->getType ()->isSignedIntegerOrEnumerationType ()
56
+ ? mlir::IntegerType::SignednessSemantics::Signed
57
+ : mlir::IntegerType::SignednessSemantics::Unsigned);
58
+
59
+ auto conversionOp = builder.create <mlir::UnrealizedConversionCastOp>(
60
+ exprLoc, targetType, expr);
61
+ return conversionOp.getResult (0 );
62
+ }
63
+
64
+ // A helper function that gets the information from an operand to a data
65
+ // clause, so that it can be used to emit the data operations.
66
+ inline DataOperandInfo getDataOperandInfo (CIRGen::CIRGenFunction &cgf,
67
+ CIRGen::CIRGenBuilderTy &builder,
68
+ OpenACCDirectiveKind dk,
69
+ const Expr *e) {
70
+ // TODO: OpenACC: Cache was different enough as to need a separate
71
+ // `ActOnCacheVar`, so we are going to need to do some investigations here
72
+ // when it comes to implement this for cache.
73
+ if (dk == OpenACCDirectiveKind::Cache) {
74
+ cgf.cgm .errorNYI (e->getSourceRange (),
75
+ " OpenACC data operand for 'cache' directive" );
76
+ return {cgf.cgm .getLoc (e->getBeginLoc ()), {}, {}};
77
+ }
78
+
79
+ const Expr *curVarExpr = e->IgnoreParenImpCasts ();
80
+
81
+ mlir::Location exprLoc = cgf.cgm .getLoc (curVarExpr->getBeginLoc ());
82
+
83
+ // TODO: OpenACC: Assemble the list of bounds.
84
+ if (isa<ArraySectionExpr, ArraySubscriptExpr>(curVarExpr)) {
85
+ cgf.cgm .errorNYI (curVarExpr->getSourceRange (),
86
+ " OpenACC data clause array subscript/section" );
87
+ return {exprLoc, {}, {}};
88
+ }
89
+
90
+ // TODO: OpenACC: if this is a member expr, emit the VarPtrPtr correctly.
91
+ if (isa<MemberExpr>(curVarExpr)) {
92
+ cgf.cgm .errorNYI (curVarExpr->getSourceRange (),
93
+ " OpenACC Data clause member expr" );
94
+ return {exprLoc, {}, {}};
95
+ }
96
+
97
+ // Sema has made sure that only 4 types of things can get here, array
98
+ // subscript, array section, member expr, or DRE to a var decl (or the former
99
+ // 3 wrapping a var-decl), so we should be able to assume this is right.
100
+ const auto *dre = cast<DeclRefExpr>(curVarExpr);
101
+ const auto *vd = cast<VarDecl>(dre->getFoundDecl ()->getCanonicalDecl ());
102
+ return {exprLoc, cgf.emitDeclRefLValue (dre).getPointer (), vd->getName ()};
103
+ }
104
+ } // namespace
105
+
39
106
template <typename OpTy>
40
107
class OpenACCClauseCIREmitter final
41
108
: public OpenACCClauseVisitor<OpenACCClauseCIREmitter<OpTy>> {
@@ -54,6 +121,11 @@ class OpenACCClauseCIREmitter final
54
121
SourceLocation dirLoc;
55
122
56
123
llvm::SmallVector<mlir::acc::DeviceType> lastDeviceTypeValues;
124
+ // Keep track of the async-clause so that we can shortcut updating the data
125
+ // operands async clauses.
126
+ bool hasAsyncClause = false ;
127
+ // Keep track of the data operands so that we can update their async clauses.
128
+ llvm::SmallVector<mlir::Operation *> dataOperands;
57
129
58
130
void setLastDeviceTypeClause (const OpenACCDeviceTypeClause &clause) {
59
131
lastDeviceTypeValues.clear ();
@@ -69,19 +141,8 @@ class OpenACCClauseCIREmitter final
69
141
cgf.cgm .errorNYI (c.getSourceRange (), " OpenACC Clause" , c.getClauseKind ());
70
142
}
71
143
72
- mlir::Value createIntExpr (const Expr *intExpr) {
73
- mlir::Value expr = cgf.emitScalarExpr (intExpr);
74
- mlir::Location exprLoc = cgf.cgm .getLoc (intExpr->getBeginLoc ());
75
-
76
- mlir::IntegerType targetType = mlir::IntegerType::get (
77
- &cgf.getMLIRContext (), cgf.getContext ().getIntWidth (intExpr->getType ()),
78
- intExpr->getType ()->isSignedIntegerOrEnumerationType ()
79
- ? mlir::IntegerType::SignednessSemantics::Signed
80
- : mlir::IntegerType::SignednessSemantics::Unsigned);
81
-
82
- auto conversionOp = builder.create <mlir::UnrealizedConversionCastOp>(
83
- exprLoc, targetType, expr);
84
- return conversionOp.getResult (0 );
144
+ mlir::Value emitOpenACCIntExpr (const Expr *intExpr) {
145
+ return clang::emitOpenACCIntExpr (cgf, builder, intExpr);
85
146
}
86
147
87
148
// 'condition' as an OpenACC grammar production is used for 'if' and (some
@@ -157,6 +218,104 @@ class OpenACCClauseCIREmitter final
157
218
computeEmitter.Visit (&c);
158
219
}
159
220
221
+ template <typename BeforeOpTy, typename AfterOpTy>
222
+ void addDataOperand (const Expr *varOperand, mlir::acc::DataClause dataClause,
223
+ bool structured, bool implicit) {
224
+ DataOperandInfo opInfo =
225
+ getDataOperandInfo (cgf, builder, dirKind, varOperand);
226
+ mlir::ValueRange bounds;
227
+
228
+ // TODO: OpenACC: we should comprehend the 'modifier-list' here for the data
229
+ // operand. At the moment, we don't have a uniform way to assign these
230
+ // properly, and the dialect cannot represent anything other than 'readonly'
231
+ // and 'zero' on copyin/copyout/create, so for now, we skip it.
232
+
233
+ auto beforeOp =
234
+ builder.create <BeforeOpTy>(opInfo.beginLoc , opInfo.varValue , structured,
235
+ implicit, opInfo.name , bounds);
236
+ operation.getDataClauseOperandsMutable ().append (beforeOp.getResult ());
237
+
238
+ AfterOpTy afterOp;
239
+ {
240
+ mlir::OpBuilder::InsertionGuard guardCase (builder);
241
+ builder.setInsertionPointAfter (operation);
242
+ afterOp = builder.create <AfterOpTy>(opInfo.beginLoc , beforeOp.getResult (),
243
+ opInfo.varValue , structured, implicit,
244
+ opInfo.name , bounds);
245
+ }
246
+
247
+ // Set the 'rest' of the info for both operations.
248
+ beforeOp.setDataClause (dataClause);
249
+ afterOp.setDataClause (dataClause);
250
+
251
+ // Make sure we record these, so 'async' values can be updated later.
252
+ dataOperands.push_back (beforeOp.getOperation ());
253
+ dataOperands.push_back (afterOp.getOperation ());
254
+ }
255
+
256
+ // Helper function that covers for the fact that we don't have this function
257
+ // on all operation types.
258
+ mlir::ArrayAttr getAsyncOnlyAttr () {
259
+ if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp,
260
+ mlir::acc::KernelsOp, mlir::acc::DataOp>)
261
+ return operation.getAsyncOnlyAttr ();
262
+
263
+ // Note: 'wait' has async as well, but it cannot have data clauses, so we
264
+ // don't have to handle them here.
265
+
266
+ llvm_unreachable (" getting asyncOnly when clause not valid on operation?" );
267
+ }
268
+
269
+ // Helper function that covers for the fact that we don't have this function
270
+ // on all operation types.
271
+ mlir::ArrayAttr getAsyncOperandsDeviceTypeAttr () {
272
+ if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp,
273
+ mlir::acc::KernelsOp, mlir::acc::DataOp>)
274
+ return operation.getAsyncOperandsDeviceTypeAttr ();
275
+
276
+ // Note: 'wait' has async as well, but it cannot have data clauses, so we
277
+ // don't have to handle them here.
278
+
279
+ llvm_unreachable (
280
+ " getting asyncOperandsDeviceType when clause not valid on operation?" );
281
+ }
282
+
283
+ // Helper function that covers for the fact that we don't have this function
284
+ // on all operation types.
285
+ mlir::OperandRange getAsyncOperands () {
286
+ if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp,
287
+ mlir::acc::KernelsOp, mlir::acc::DataOp>)
288
+ return operation.getAsyncOperands ();
289
+
290
+ // Note: 'wait' has async as well, but it cannot have data clauses, so we
291
+ // don't have to handle them here.
292
+
293
+ llvm_unreachable (
294
+ " getting asyncOperandsDeviceType when clause not valid on operation?" );
295
+ }
296
+
297
+ // The 'data' clauses all require that we add the 'async' values from the
298
+ // operation to them. We've collected the data operands along the way, so use
299
+ // that list to get the current 'async' values.
300
+ void updateDataOperandAsyncValues () {
301
+ if (!hasAsyncClause || dataOperands.empty ())
302
+ return ;
303
+
304
+ // TODO: OpenACC: Handle this correctly for combined constructs.
305
+
306
+ for (mlir::Operation *dataOp : dataOperands) {
307
+ llvm::TypeSwitch<mlir::Operation *, void >(dataOp)
308
+ .Case <ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto op) {
309
+ op.setAsyncOnlyAttr (getAsyncOnlyAttr ());
310
+ op.setAsyncOperandsDeviceTypeAttr (getAsyncOperandsDeviceTypeAttr ());
311
+ op.getAsyncOperandsMutable ().assign (getAsyncOperands ());
312
+ })
313
+ .Default ([&](mlir::Operation *) {
314
+ llvm_unreachable (" Not a data operation?" );
315
+ });
316
+ }
317
+ }
318
+
160
319
public:
161
320
OpenACCClauseCIREmitter (OpTy &operation, CIRGen::CIRGenFunction &cgf,
162
321
CIRGen::CIRGenBuilderTy &builder,
@@ -168,6 +327,14 @@ class OpenACCClauseCIREmitter final
168
327
clauseNotImplemented (clause);
169
328
}
170
329
330
+ // The entry point for the CIR emitter. All users should use this rather than
331
+ // 'visitClauseList', as this also handles the things that have to happen
332
+ // 'after' the clauses are all visited.
333
+ void emitClauses (ArrayRef<const OpenACCClause *> clauses) {
334
+ this ->VisitClauseList (clauses);
335
+ updateDataOperandAsyncValues ();
336
+ }
337
+
171
338
void VisitDefaultClause (const OpenACCDefaultClause &clause) {
172
339
// This type-trait checks if 'op'(the first arg) is one of the mlir::acc
173
340
// operations listed in the rest of the arguments.
@@ -227,7 +394,7 @@ class OpenACCClauseCIREmitter final
227
394
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp,
228
395
mlir::acc::KernelsOp>) {
229
396
operation.addNumWorkersOperand (builder.getContext (),
230
- createIntExpr (clause.getIntExpr ()),
397
+ emitOpenACCIntExpr (clause.getIntExpr ()),
231
398
lastDeviceTypeValues);
232
399
} else if constexpr (isCombinedType<OpTy>) {
233
400
applyToComputeOp (clause);
@@ -240,7 +407,7 @@ class OpenACCClauseCIREmitter final
240
407
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp,
241
408
mlir::acc::KernelsOp>) {
242
409
operation.addVectorLengthOperand (builder.getContext (),
243
- createIntExpr (clause.getIntExpr ()),
410
+ emitOpenACCIntExpr (clause.getIntExpr ()),
244
411
lastDeviceTypeValues);
245
412
} else if constexpr (isCombinedType<OpTy>) {
246
413
applyToComputeOp (clause);
@@ -250,22 +417,34 @@ class OpenACCClauseCIREmitter final
250
417
}
251
418
252
419
void VisitAsyncClause (const OpenACCAsyncClause &clause) {
420
+ hasAsyncClause = true ;
253
421
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp,
254
422
mlir::acc::KernelsOp, mlir::acc::DataOp>) {
255
423
if (!clause.hasIntExpr ())
256
424
operation.addAsyncOnly (builder.getContext (), lastDeviceTypeValues);
257
- else
258
- operation.addAsyncOperand (builder.getContext (),
259
- createIntExpr (clause.getIntExpr ()),
425
+ else {
426
+
427
+ mlir::Value intExpr;
428
+ {
429
+ // Async int exprs can be referenced by the data operands, which means
430
+ // that the int-exprs have to appear before them. IF there is a data
431
+ // operand already, set the insertion point to 'before' it.
432
+ mlir::OpBuilder::InsertionGuard guardCase (builder);
433
+ if (!dataOperands.empty ())
434
+ builder.setInsertionPoint (dataOperands.front ());
435
+ intExpr = emitOpenACCIntExpr (clause.getIntExpr ());
436
+ }
437
+ operation.addAsyncOperand (builder.getContext (), intExpr,
260
438
lastDeviceTypeValues);
439
+ }
261
440
} else if constexpr (isOneOfTypes<OpTy, mlir::acc::WaitOp>) {
262
441
// Wait doesn't have a device_type, so its handling here is slightly
263
442
// different.
264
443
if (!clause.hasIntExpr ())
265
444
operation.setAsync (true );
266
445
else
267
446
operation.getAsyncOperandMutable ().append (
268
- createIntExpr (clause.getIntExpr ()));
447
+ emitOpenACCIntExpr (clause.getIntExpr ()));
269
448
} else if constexpr (isCombinedType<OpTy>) {
270
449
applyToComputeOp (clause);
271
450
} else {
@@ -321,7 +500,7 @@ class OpenACCClauseCIREmitter final
321
500
if constexpr (isOneOfTypes<OpTy, mlir::acc::InitOp, mlir::acc::ShutdownOp,
322
501
mlir::acc::SetOp>) {
323
502
operation.getDeviceNumMutable ().append (
324
- createIntExpr (clause.getIntExpr ()));
503
+ emitOpenACCIntExpr (clause.getIntExpr ()));
325
504
} else {
326
505
llvm_unreachable (
327
506
" init, shutdown, set, are only valid device_num constructs" );
@@ -333,7 +512,7 @@ class OpenACCClauseCIREmitter final
333
512
mlir::acc::KernelsOp>) {
334
513
llvm::SmallVector<mlir::Value> values;
335
514
for (const Expr *E : clause.getIntExprs ())
336
- values.push_back (createIntExpr (E));
515
+ values.push_back (emitOpenACCIntExpr (E));
337
516
338
517
operation.addNumGangsOperands (builder.getContext (), values,
339
518
lastDeviceTypeValues);
@@ -352,9 +531,9 @@ class OpenACCClauseCIREmitter final
352
531
} else {
353
532
llvm::SmallVector<mlir::Value> values;
354
533
if (clause.hasDevNumExpr ())
355
- values.push_back (createIntExpr (clause.getDevNumExpr ()));
534
+ values.push_back (emitOpenACCIntExpr (clause.getDevNumExpr ()));
356
535
for (const Expr *E : clause.getQueueIdExprs ())
357
- values.push_back (createIntExpr (E));
536
+ values.push_back (emitOpenACCIntExpr (E));
358
537
operation.addWaitOperands (builder.getContext (), clause.hasDevNumExpr (),
359
538
values, lastDeviceTypeValues);
360
539
}
@@ -370,7 +549,7 @@ class OpenACCClauseCIREmitter final
370
549
void VisitDefaultAsyncClause (const OpenACCDefaultAsyncClause &clause) {
371
550
if constexpr (isOneOfTypes<OpTy, mlir::acc::SetOp>) {
372
551
operation.getDefaultAsyncMutable ().append (
373
- createIntExpr (clause.getIntExpr ()));
552
+ emitOpenACCIntExpr (clause.getIntExpr ()));
374
553
} else {
375
554
llvm_unreachable (" set, is only valid device_num constructs" );
376
555
}
@@ -460,7 +639,7 @@ class OpenACCClauseCIREmitter final
460
639
if constexpr (isOneOfTypes<OpTy, mlir::acc::LoopOp>) {
461
640
if (clause.hasIntExpr ())
462
641
operation.addWorkerNumOperand (builder.getContext (),
463
- createIntExpr (clause.getIntExpr ()),
642
+ emitOpenACCIntExpr (clause.getIntExpr ()),
464
643
lastDeviceTypeValues);
465
644
else
466
645
operation.addEmptyWorker (builder.getContext (), lastDeviceTypeValues);
@@ -478,7 +657,7 @@ class OpenACCClauseCIREmitter final
478
657
if constexpr (isOneOfTypes<OpTy, mlir::acc::LoopOp>) {
479
658
if (clause.hasIntExpr ())
480
659
operation.addVectorOperand (builder.getContext (),
481
- createIntExpr (clause.getIntExpr ()),
660
+ emitOpenACCIntExpr (clause.getIntExpr ()),
482
661
lastDeviceTypeValues);
483
662
else
484
663
operation.addEmptyVector (builder.getContext (), lastDeviceTypeValues);
@@ -514,7 +693,7 @@ class OpenACCClauseCIREmitter final
514
693
} else if (isa<OpenACCAsteriskSizeExpr>(expr)) {
515
694
values.push_back (createConstantInt (exprLoc, 64 , -1 ));
516
695
} else {
517
- values.push_back (createIntExpr (expr));
696
+ values.push_back (emitOpenACCIntExpr (expr));
518
697
}
519
698
}
520
699
@@ -527,6 +706,20 @@ class OpenACCClauseCIREmitter final
527
706
llvm_unreachable (" Unknown construct kind in VisitGangClause" );
528
707
}
529
708
}
709
+
710
+ void VisitCopyClause (const OpenACCCopyClause &clause) {
711
+ if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp,
712
+ mlir::acc::KernelsOp>) {
713
+ for (auto var : clause.getVarList ())
714
+ addDataOperand<mlir::acc::CopyinOp, mlir::acc::CopyoutOp>(
715
+ var, mlir::acc::DataClause::acc_copy, /* structured=*/ true ,
716
+ /* implicit=*/ false );
717
+ } else {
718
+ // TODO: When we've implemented this for everything, switch this to an
719
+ // unreachable. data, declare, combined constructs remain.
720
+ return clauseNotImplemented (clause);
721
+ }
722
+ }
530
723
};
531
724
532
725
template <typename OpTy>
0 commit comments