Skip to content

Commit 8d34a32

Browse files
committed
Support taskwait with dependencies in/out/inout
Closes llvm#15
1 parent dacbd94 commit 8d34a32

File tree

13 files changed

+297
-163
lines changed

13 files changed

+297
-163
lines changed

clang/include/clang/AST/StmtOmpSs.h

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -236,32 +236,39 @@ class OSSTaskwaitDirective : public OSSExecutableDirective {
236236
///
237237
/// \param StartLoc Starting location of the directive kind.
238238
/// \param EndLoc Ending location of the directive.
239+
/// \param NumClauses Number of clauses.
239240
///
240-
OSSTaskwaitDirective(SourceLocation StartLoc, SourceLocation EndLoc)
241+
OSSTaskwaitDirective(SourceLocation StartLoc, SourceLocation EndLoc, unsigned NumClauses)
241242
: OSSExecutableDirective(this, OSSTaskwaitDirectiveClass, OSSD_taskwait,
242-
StartLoc, EndLoc, 0, 0) {}
243+
StartLoc, EndLoc, NumClauses, 0) {}
243244

244245
/// Build an empty directive.
245246
///
246-
explicit OSSTaskwaitDirective()
247+
/// \param NumClauses Number of clauses.
248+
///
249+
explicit OSSTaskwaitDirective(unsigned NumClauses)
247250
: OSSExecutableDirective(this, OSSTaskwaitDirectiveClass, OSSD_taskwait,
248-
SourceLocation(), SourceLocation(), 0, 0) {}
251+
SourceLocation(), SourceLocation(), NumClauses, 0) {}
249252

250253
public:
251254
/// Creates directive.
252255
///
253256
/// \param C AST context.
254257
/// \param StartLoc Starting location of the directive kind.
255258
/// \param EndLoc Ending Location of the directive.
259+
/// \param Clauses List of clauses.
256260
///
257261
static OSSTaskwaitDirective *
258-
Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc);
262+
Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
263+
ArrayRef<OSSClause *> Clauses);
259264

260265
/// Creates an empty directive.
261266
///
262267
/// \param C AST context.
268+
/// \param NumClauses Number of clauses.
263269
///
264-
static OSSTaskwaitDirective *CreateEmpty(const ASTContext &C, EmptyShell);
270+
static OSSTaskwaitDirective *CreateEmpty(const ASTContext &C, unsigned NumClauses,
271+
EmptyShell);
265272

266273
static bool classof(const Stmt *T) {
267274
return T->getStmtClass() == OSSTaskwaitDirectiveClass;

clang/include/clang/Basic/OmpSsKinds.def

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
#ifndef OMPSS_TASK_CLAUSE
2525
# define OMPSS_TASK_CLAUSE(Name)
2626
#endif
27+
#ifndef OMPSS_TASKWAIT_CLAUSE
28+
# define OMPSS_TASKWAIT_CLAUSE(Name)
29+
#endif
2730
#ifndef OMPSS_DECLARE_TASK_CLAUSE
2831
# define OMPSS_DECLARE_TASK_CLAUSE(Name)
2932
#endif
@@ -105,6 +108,12 @@ OMPSS_TASK_CLAUSE(weakinout)
105108
OMPSS_TASK_CLAUSE(weakcommutative)
106109
OMPSS_TASK_CLAUSE(weakreduction)
107110

111+
// Clauses allowed for OmpSs directive 'taskwait'.
112+
OMPSS_TASKWAIT_CLAUSE(depend)
113+
OMPSS_TASKWAIT_CLAUSE(in)
114+
OMPSS_TASKWAIT_CLAUSE(out)
115+
OMPSS_TASKWAIT_CLAUSE(inout)
116+
108117
// Clauses allowed for OmpSs directive 'task' declaration/outline.
109118
OMPSS_DECLARE_TASK_CLAUSE(if)
110119
OMPSS_DECLARE_TASK_CLAUSE(final)
@@ -128,4 +137,5 @@ OMPSS_DECLARE_TASK_CLAUSE(weakcommutative)
128137
#undef OMPSS_CLAUSE_ALIAS
129138
#undef OMPSS_CLAUSE
130139
#undef OMPSS_DECLARE_TASK_CLAUSE
140+
#undef OMPSS_TASKWAIT_CLAUSE
131141
#undef OMPSS_TASK_CLAUSE

clang/include/clang/Sema/Sema.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10882,7 +10882,8 @@ class Sema final {
1088210882
OmpSsDirectiveKind Kind, Stmt *AStmt, SourceLocation StartLoc, SourceLocation EndLoc);
1088310883

1088410884
/// Called on well-formed '\#pragma oss taskwait'.
10885-
StmtResult ActOnOmpSsTaskwaitDirective(SourceLocation StartLoc,
10885+
StmtResult ActOnOmpSsTaskwaitDirective(ArrayRef<OSSClause *> Clauses,
10886+
SourceLocation StartLoc,
1088610887
SourceLocation EndLoc);
1088710888

1088810889
/// Called on well-formed '\#pragma omp task' after parsing of the

clang/lib/AST/StmtOmpSs.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,23 @@ void OSSExecutableDirective::setClauses(ArrayRef<OSSClause *> Clauses) {
2323
std::copy(Clauses.begin(), Clauses.end(), getClauses().begin());
2424
}
2525

26-
OSSTaskwaitDirective *OSSTaskwaitDirective::Create(const ASTContext &C,
27-
SourceLocation StartLoc,
28-
SourceLocation EndLoc) {
29-
void *Mem = C.Allocate(sizeof(OSSTaskwaitDirective));
30-
OSSTaskwaitDirective *Dir = new (Mem) OSSTaskwaitDirective(StartLoc, EndLoc);
26+
OSSTaskwaitDirective *
27+
OSSTaskwaitDirective::Create(const ASTContext &C, SourceLocation StartLoc,
28+
SourceLocation EndLoc, ArrayRef<OSSClause *> Clauses) {
29+
unsigned Size = llvm::alignTo(sizeof(OSSTaskwaitDirective), alignof(OSSClause *));
30+
void *Mem = C.Allocate(Size + sizeof(OSSClause *) * Clauses.size());
31+
OSSTaskwaitDirective *Dir =
32+
new (Mem) OSSTaskwaitDirective(StartLoc, EndLoc, Clauses.size());
33+
Dir->setClauses(Clauses);
3134
return Dir;
3235
}
3336

3437
OSSTaskwaitDirective *OSSTaskwaitDirective::CreateEmpty(const ASTContext &C,
38+
unsigned NumClauses,
3539
EmptyShell) {
36-
void *Mem = C.Allocate(sizeof(OSSTaskwaitDirective));
37-
return new (Mem) OSSTaskwaitDirective();
40+
unsigned Size = llvm::alignTo(sizeof(OSSTaskwaitDirective), alignof(OSSClause *));
41+
void *Mem = C.Allocate(Size + sizeof(OSSClause *) * NumClauses);
42+
return new (Mem) OSSTaskwaitDirective(NumClauses);
3843
}
3944

4045
OSSTaskDirective *

clang/lib/Basic/OmpSsKinds.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,15 @@ bool clang::isAllowedClauseForDirective(OmpSsDirectiveKind DKind,
181181
}
182182
break;
183183
case OSSD_taskwait:
184+
switch (CKind) {
185+
#define OMPSS_TASKWAIT_CLAUSE(Name) \
186+
case OSSC_##Name: \
187+
return true;
188+
#include "clang/Basic/OmpSsKinds.def"
189+
default:
190+
break;
191+
}
192+
break;
184193
case OSSD_declare_reduction:
185194
case OSSD_unknown:
186195
break;

clang/lib/CodeGen/CGOmpSsRuntime.cpp

Lines changed: 123 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1819,16 +1819,48 @@ void CGOmpSsRuntime::EmitReduction(
18191819
}
18201820

18211821
void CGOmpSsRuntime::emitTaskwaitCall(CodeGenFunction &CGF,
1822-
SourceLocation Loc) {
1823-
llvm::Function *Callee = CGM.getIntrinsic(llvm::Intrinsic::directive_marker);
1824-
CGF.Builder.CreateCall(
1825-
Callee, {},
1826-
{
1827-
llvm::OperandBundleDef(
1828-
std::string(getBundleStr(OSSB_directive)),
1829-
llvm::ConstantDataArray::getString(
1830-
CGM.getLLVMContext(), getBundleStr(OSSB_taskwait)))
1831-
});
1822+
SourceLocation Loc,
1823+
const OSSTaskDataTy &Data) {
1824+
if (Data.empty()) {
1825+
// Regular taskwait
1826+
llvm::Function *Callee = CGM.getIntrinsic(llvm::Intrinsic::directive_marker);
1827+
CGF.Builder.CreateCall(
1828+
Callee, {},
1829+
{
1830+
llvm::OperandBundleDef(
1831+
std::string(getBundleStr(OSSB_directive)),
1832+
llvm::ConstantDataArray::getString(
1833+
CGM.getLLVMContext(), getBundleStr(OSSB_taskwait)))
1834+
});
1835+
} else {
1836+
// taskwait with deps -> task with deps if(0)
1837+
llvm::Function *EntryCallee = CGM.getIntrinsic(llvm::Intrinsic::directive_region_entry);
1838+
llvm::Function *ExitCallee = CGM.getIntrinsic(llvm::Intrinsic::directive_region_exit);
1839+
SmallVector<llvm::OperandBundleDef, 8> TaskInfo;
1840+
TaskInfo.emplace_back(
1841+
getBundleStr(OSSB_directive),
1842+
llvm::ConstantDataArray::getString(CGM.getLLVMContext(), getBundleStr(OSSB_task)));
1843+
1844+
// Add if(0) flag
1845+
llvm::Type *Int1Ty = CGF.ConvertType(CGF.getContext().BoolTy);
1846+
TaskInfo.emplace_back(getBundleStr(OSSB_if), llvm::ConstantInt::getSigned(Int1Ty, 0));
1847+
1848+
// Push Task Stack
1849+
TaskStack.push_back(TaskContext());
1850+
CaptureMapStack.push_back(CaptureMapTy());
1851+
1852+
InTaskEmission = true;
1853+
EmitTaskData(CGF, Data, TaskInfo);
1854+
InTaskEmission = false;
1855+
1856+
llvm::Instruction *Result =
1857+
CGF.Builder.CreateCall(EntryCallee, {}, llvm::makeArrayRef(TaskInfo));
1858+
CGF.Builder.CreateCall(ExitCallee, Result);
1859+
1860+
// Pop Task Stack
1861+
TaskStack.pop_back();
1862+
CaptureMapStack.pop_back();
1863+
}
18321864
}
18331865

18341866
// We're in task body context once we set InsertPt
@@ -1909,6 +1941,85 @@ static void EmitIfUsed(CodeGenFunction &CGF, llvm::BasicBlock *BB) {
19091941
delete BB;
19101942
}
19111943

1944+
void CGOmpSsRuntime::EmitTaskData(
1945+
CodeGenFunction &CGF,
1946+
const OSSTaskDataTy &Data,
1947+
SmallVectorImpl<llvm::OperandBundleDef> &TaskInfo) {
1948+
1949+
SmallVector<llvm::Value*, 4> CapturedList;
1950+
for (const Expr *E : Data.DSAs.Shareds) {
1951+
EmitDSAShared(CGF, E, TaskInfo, CapturedList);
1952+
}
1953+
for (const OSSDSAPrivateDataTy &PDataTy : Data.DSAs.Privates) {
1954+
EmitDSAPrivate(CGF, PDataTy, TaskInfo, CapturedList);
1955+
}
1956+
for (const OSSDSAFirstprivateDataTy &FpDataTy : Data.DSAs.Firstprivates) {
1957+
EmitDSAFirstprivate(CGF, FpDataTy, TaskInfo, CapturedList);
1958+
}
1959+
1960+
if (Data.Cost) {
1961+
llvm::Value *V = CGF.EmitScalarExpr(Data.Cost);
1962+
CapturedList.push_back(V);
1963+
TaskInfo.emplace_back(getBundleStr(OSSB_cost), V);
1964+
}
1965+
if (Data.Priority) {
1966+
llvm::Value *V = CGF.EmitScalarExpr(Data.Priority);
1967+
CapturedList.push_back(V);
1968+
TaskInfo.emplace_back(getBundleStr(OSSB_priority), V);
1969+
}
1970+
1971+
if (!CapturedList.empty())
1972+
TaskInfo.emplace_back(getBundleStr(OSSB_captured), CapturedList);
1973+
1974+
for (const OSSDepDataTy &Dep : Data.Deps.Ins) {
1975+
EmitDependency(getBundleStr(OSSB_in), CGF, Dep, TaskInfo);
1976+
}
1977+
for (const OSSDepDataTy &Dep : Data.Deps.Outs) {
1978+
EmitDependency(getBundleStr(OSSB_out), CGF, Dep, TaskInfo);
1979+
}
1980+
for (const OSSDepDataTy &Dep : Data.Deps.Inouts) {
1981+
EmitDependency(getBundleStr(OSSB_inout), CGF, Dep, TaskInfo);
1982+
}
1983+
for (const OSSDepDataTy &Dep : Data.Deps.Concurrents) {
1984+
EmitDependency(getBundleStr(OSSB_concurrent), CGF, Dep, TaskInfo);
1985+
}
1986+
for (const OSSDepDataTy &Dep : Data.Deps.Commutatives) {
1987+
EmitDependency(getBundleStr(OSSB_commutative), CGF, Dep, TaskInfo);
1988+
}
1989+
for (const OSSDepDataTy &Dep : Data.Deps.WeakIns) {
1990+
EmitDependency(getBundleStr(OSSB_weakin), CGF, Dep, TaskInfo);
1991+
}
1992+
for (const OSSDepDataTy &Dep : Data.Deps.WeakOuts) {
1993+
EmitDependency(getBundleStr(OSSB_weakout), CGF, Dep, TaskInfo);
1994+
}
1995+
for (const OSSDepDataTy &Dep : Data.Deps.WeakInouts) {
1996+
EmitDependency(getBundleStr(OSSB_weakinout), CGF, Dep, TaskInfo);
1997+
}
1998+
for (const OSSDepDataTy &Dep : Data.Deps.WeakConcurrents) {
1999+
EmitDependency(getBundleStr(OSSB_weakconcurrent), CGF, Dep, TaskInfo);
2000+
}
2001+
for (const OSSDepDataTy &Dep : Data.Deps.WeakCommutatives) {
2002+
EmitDependency(getBundleStr(OSSB_weakcommutative), CGF, Dep, TaskInfo);
2003+
}
2004+
for (const OSSReductionDataTy &Red : Data.Reductions.RedList) {
2005+
EmitReduction(getBundleStr(OSSB_reduction),
2006+
getBundleStr(OSSB_redinit),
2007+
getBundleStr(OSSB_redcomb),
2008+
CGF, Red, TaskInfo);
2009+
}
2010+
for (const OSSReductionDataTy &Red : Data.Reductions.WeakRedList) {
2011+
EmitReduction(getBundleStr(OSSB_weakreduction),
2012+
getBundleStr(OSSB_redinit),
2013+
getBundleStr(OSSB_redcomb),
2014+
CGF, Red, TaskInfo);
2015+
}
2016+
2017+
if (Data.If)
2018+
TaskInfo.emplace_back(getBundleStr(OSSB_if), CGF.EvaluateExprAsBool(Data.If));
2019+
if (Data.Final)
2020+
TaskInfo.emplace_back(getBundleStr(OSSB_final), CGF.EvaluateExprAsBool(Data.Final));
2021+
}
2022+
19122023
RValue CGOmpSsRuntime::emitTaskFunction(CodeGenFunction &CGF,
19132024
const FunctionDecl *FD,
19142025
const CallExpr *CE,
@@ -2215,90 +2326,17 @@ void CGOmpSsRuntime::emitTaskCall(CodeGenFunction &CGF,
22152326
getBundleStr(OSSB_directive),
22162327
llvm::ConstantDataArray::getString(CGM.getLLVMContext(), getBundleStr(OSSB_task)));
22172328

2218-
SmallVector<llvm::Value*, 4> CapturedList;
2219-
2329+
// Push Task Stack
22202330
TaskStack.push_back(TaskContext());
22212331
CaptureMapStack.push_back(CaptureMapTy());
22222332

22232333
InTaskEmission = true;
2224-
for (const Expr *E : Data.DSAs.Shareds) {
2225-
EmitDSAShared(CGF, E, TaskInfo, CapturedList);
2226-
}
2227-
for (const OSSDSAPrivateDataTy &PDataTy : Data.DSAs.Privates) {
2228-
EmitDSAPrivate(CGF, PDataTy, TaskInfo, CapturedList);
2229-
}
2230-
for (const OSSDSAFirstprivateDataTy &FpDataTy : Data.DSAs.Firstprivates) {
2231-
EmitDSAFirstprivate(CGF, FpDataTy, TaskInfo, CapturedList);
2232-
}
2233-
2234-
if (Data.Cost) {
2235-
llvm::Value *V = CGF.EmitScalarExpr(Data.Cost);
2236-
CapturedList.push_back(V);
2237-
TaskInfo.emplace_back(getBundleStr(OSSB_cost), V);
2238-
}
2239-
if (Data.Priority) {
2240-
llvm::Value *V = CGF.EmitScalarExpr(Data.Priority);
2241-
CapturedList.push_back(V);
2242-
TaskInfo.emplace_back(getBundleStr(OSSB_priority), V);
2243-
}
2244-
2245-
if (!CapturedList.empty())
2246-
TaskInfo.emplace_back(getBundleStr(OSSB_captured), CapturedList);
2247-
2248-
for (const OSSDepDataTy &Dep : Data.Deps.Ins) {
2249-
EmitDependency(getBundleStr(OSSB_in), CGF, Dep, TaskInfo);
2250-
}
2251-
for (const OSSDepDataTy &Dep : Data.Deps.Outs) {
2252-
EmitDependency(getBundleStr(OSSB_out), CGF, Dep, TaskInfo);
2253-
}
2254-
for (const OSSDepDataTy &Dep : Data.Deps.Inouts) {
2255-
EmitDependency(getBundleStr(OSSB_inout), CGF, Dep, TaskInfo);
2256-
}
2257-
for (const OSSDepDataTy &Dep : Data.Deps.Concurrents) {
2258-
EmitDependency(getBundleStr(OSSB_concurrent), CGF, Dep, TaskInfo);
2259-
}
2260-
for (const OSSDepDataTy &Dep : Data.Deps.Commutatives) {
2261-
EmitDependency(getBundleStr(OSSB_commutative), CGF, Dep, TaskInfo);
2262-
}
2263-
for (const OSSDepDataTy &Dep : Data.Deps.WeakIns) {
2264-
EmitDependency(getBundleStr(OSSB_weakin), CGF, Dep, TaskInfo);
2265-
}
2266-
for (const OSSDepDataTy &Dep : Data.Deps.WeakOuts) {
2267-
EmitDependency(getBundleStr(OSSB_weakout), CGF, Dep, TaskInfo);
2268-
}
2269-
for (const OSSDepDataTy &Dep : Data.Deps.WeakInouts) {
2270-
EmitDependency(getBundleStr(OSSB_weakinout), CGF, Dep, TaskInfo);
2271-
}
2272-
for (const OSSDepDataTy &Dep : Data.Deps.WeakConcurrents) {
2273-
EmitDependency(getBundleStr(OSSB_weakconcurrent), CGF, Dep, TaskInfo);
2274-
}
2275-
for (const OSSDepDataTy &Dep : Data.Deps.WeakCommutatives) {
2276-
EmitDependency(getBundleStr(OSSB_weakcommutative), CGF, Dep, TaskInfo);
2277-
}
2278-
for (const OSSReductionDataTy &Red : Data.Reductions.RedList) {
2279-
EmitReduction(getBundleStr(OSSB_reduction),
2280-
getBundleStr(OSSB_redinit),
2281-
getBundleStr(OSSB_redcomb),
2282-
CGF, Red, TaskInfo);
2283-
}
2284-
for (const OSSReductionDataTy &Red : Data.Reductions.WeakRedList) {
2285-
EmitReduction(getBundleStr(OSSB_weakreduction),
2286-
getBundleStr(OSSB_redinit),
2287-
getBundleStr(OSSB_redcomb),
2288-
CGF, Red, TaskInfo);
2289-
}
2290-
2291-
if (Data.If)
2292-
TaskInfo.emplace_back(getBundleStr(OSSB_if), CGF.EvaluateExprAsBool(Data.If));
2293-
if (Data.Final)
2294-
TaskInfo.emplace_back(getBundleStr(OSSB_final), CGF.EvaluateExprAsBool(Data.Final));
2295-
2334+
EmitTaskData(CGF, Data, TaskInfo);
22962335
InTaskEmission = false;
22972336

22982337
llvm::Instruction *Result =
22992338
CGF.Builder.CreateCall(EntryCallee, {}, llvm::makeArrayRef(TaskInfo));
23002339

2301-
// Push Task Stack
23022340
llvm::Value *Undef = llvm::UndefValue::get(CGF.Int32Ty);
23032341
llvm::Instruction *TaskAllocaInsertPt = new llvm::BitCastInst(Undef, CGF.Int32Ty, "taskallocapt", Result->getParent());
23042342
setTaskInsertPt(TaskAllocaInsertPt);

0 commit comments

Comments
 (0)