Skip to content

Commit e944455

Browse files
committed
[flang][openacc] Lower parallel construct
This patch upstream the lowering of Parallel construct that was initially done in flang-compiler#460. Reviewed By: jeanPerier Differential Revision: https://reviews.llvm.org/D88917
1 parent 726a6e8 commit e944455

File tree

1 file changed

+208
-1
lines changed

1 file changed

+208
-1
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 208 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,26 @@ static void genObjectList(const Fortran::parser::AccObjectList &objectList,
4949
}
5050
}
5151

52+
template <typename Clause>
53+
static void
54+
genObjectListWithModifier(const Clause *x,
55+
Fortran::lower::AbstractConverter &converter,
56+
Fortran::parser::AccDataModifier::Modifier mod,
57+
SmallVectorImpl<Value> &operandsWithModifier,
58+
SmallVectorImpl<Value> &operands) {
59+
const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v;
60+
const Fortran::parser::AccObjectList &accObjectList =
61+
std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
62+
const auto &modifier =
63+
std::get<std::optional<Fortran::parser::AccDataModifier>>(
64+
listWithModifier.t);
65+
if (modifier && (*modifier).v == mod) {
66+
genObjectList(accObjectList, converter, operandsWithModifier);
67+
} else {
68+
genObjectList(accObjectList, converter, operands);
69+
}
70+
}
71+
5272
static void addOperands(SmallVectorImpl<Value> &operands,
5373
SmallVectorImpl<int32_t> &operandSegments,
5474
const SmallVectorImpl<Value> &clauseOperands) {
@@ -228,6 +248,193 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
228248
}
229249
}
230250

251+
static void
252+
genACCParallelOp(Fortran::lower::AbstractConverter &converter,
253+
const Fortran::parser::AccClauseList &accClauseList) {
254+
mlir::Value async;
255+
mlir::Value numGangs;
256+
mlir::Value numWorkers;
257+
mlir::Value vectorLength;
258+
mlir::Value ifCond;
259+
mlir::Value selfCond;
260+
SmallVector<Value, 2> waitOperands, reductionOperands, copyOperands,
261+
copyinOperands, copyinReadonlyOperands, copyoutOperands,
262+
copyoutZeroOperands, createOperands, createZeroOperands, noCreateOperands,
263+
presentOperands, devicePtrOperands, attachOperands, privateOperands,
264+
firstprivateOperands;
265+
266+
// Async, wait and self clause have optional values but can be present with
267+
// no value as well. When there is no value, the op has an attribute to
268+
// represent the clause.
269+
bool addAsyncAttr = false;
270+
bool addWaitAttr = false;
271+
bool addSelfAttr = false;
272+
273+
auto &firOpBuilder = converter.getFirOpBuilder();
274+
auto currentLocation = converter.getCurrentLocation();
275+
276+
// Lower clauses values mapped to operands.
277+
// Keep track of each group of operands separatly as clauses can appear
278+
// more than once.
279+
for (const auto &clause : accClauseList.v) {
280+
if (const auto *asyncClause =
281+
std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
282+
const auto &asyncClauseValue = asyncClause->v;
283+
if (asyncClauseValue) { // async has a value.
284+
async = fir::getBase(converter.genExprValue(
285+
*Fortran::semantics::GetExpr(*asyncClauseValue)));
286+
} else {
287+
addAsyncAttr = true;
288+
}
289+
} else if (const auto *waitClause =
290+
std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
291+
const auto &waitClauseValue = waitClause->v;
292+
if (waitClauseValue) { // wait has a value.
293+
const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
294+
const std::list<Fortran::parser::ScalarIntExpr> &waitList =
295+
std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
296+
for (const Fortran::parser::ScalarIntExpr &value : waitList) {
297+
Value v = fir::getBase(
298+
converter.genExprValue(*Fortran::semantics::GetExpr(value)));
299+
waitOperands.push_back(v);
300+
}
301+
} else {
302+
addWaitAttr = true;
303+
}
304+
} else if (const auto *numGangsClause =
305+
std::get_if<Fortran::parser::AccClause::NumGangs>(
306+
&clause.u)) {
307+
numGangs = fir::getBase(converter.genExprValue(
308+
*Fortran::semantics::GetExpr(numGangsClause->v)));
309+
} else if (const auto *numWorkersClause =
310+
std::get_if<Fortran::parser::AccClause::NumWorkers>(
311+
&clause.u)) {
312+
numWorkers = fir::getBase(converter.genExprValue(
313+
*Fortran::semantics::GetExpr(numWorkersClause->v)));
314+
} else if (const auto *vectorLengthClause =
315+
std::get_if<Fortran::parser::AccClause::VectorLength>(
316+
&clause.u)) {
317+
vectorLength = fir::getBase(converter.genExprValue(
318+
*Fortran::semantics::GetExpr(vectorLengthClause->v)));
319+
} else if (const auto *ifClause =
320+
std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
321+
Value cond = fir::getBase(
322+
converter.genExprValue(*Fortran::semantics::GetExpr(ifClause->v)));
323+
ifCond = firOpBuilder.createConvert(currentLocation,
324+
firOpBuilder.getI1Type(), cond);
325+
} else if (const auto *selfClause =
326+
std::get_if<Fortran::parser::AccClause::Self>(&clause.u)) {
327+
if (selfClause->v) {
328+
Value cond = fir::getBase(converter.genExprValue(
329+
*Fortran::semantics::GetExpr(*(selfClause->v))));
330+
selfCond = firOpBuilder.createConvert(currentLocation,
331+
firOpBuilder.getI1Type(), cond);
332+
} else {
333+
addSelfAttr = true;
334+
}
335+
} else if (const auto *copyClause =
336+
std::get_if<Fortran::parser::AccClause::Copy>(&clause.u)) {
337+
genObjectList(copyClause->v, converter, copyOperands);
338+
} else if (const auto *copyinClause =
339+
std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
340+
genObjectListWithModifier<Fortran::parser::AccClause::Copyin>(
341+
copyinClause, converter,
342+
Fortran::parser::AccDataModifier::Modifier::ReadOnly,
343+
copyinReadonlyOperands, copyinOperands);
344+
} else if (const auto *copyoutClause =
345+
std::get_if<Fortran::parser::AccClause::Copyout>(
346+
&clause.u)) {
347+
genObjectListWithModifier<Fortran::parser::AccClause::Copyout>(
348+
copyoutClause, converter,
349+
Fortran::parser::AccDataModifier::Modifier::Zero, copyoutZeroOperands,
350+
copyoutOperands);
351+
} else if (const auto *createClause =
352+
std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
353+
genObjectListWithModifier<Fortran::parser::AccClause::Create>(
354+
createClause, converter,
355+
Fortran::parser::AccDataModifier::Modifier::Zero, createZeroOperands,
356+
createOperands);
357+
} else if (const auto *noCreateClause =
358+
std::get_if<Fortran::parser::AccClause::NoCreate>(
359+
&clause.u)) {
360+
genObjectList(noCreateClause->v, converter, noCreateOperands);
361+
} else if (const auto *presentClause =
362+
std::get_if<Fortran::parser::AccClause::Present>(
363+
&clause.u)) {
364+
genObjectList(presentClause->v, converter, presentOperands);
365+
} else if (const auto *devicePtrClause =
366+
std::get_if<Fortran::parser::AccClause::Deviceptr>(
367+
&clause.u)) {
368+
genObjectList(devicePtrClause->v, converter, devicePtrOperands);
369+
} else if (const auto *attachClause =
370+
std::get_if<Fortran::parser::AccClause::Attach>(&clause.u)) {
371+
genObjectList(attachClause->v, converter, attachOperands);
372+
} else if (const auto *privateClause =
373+
std::get_if<Fortran::parser::AccClause::Private>(
374+
&clause.u)) {
375+
genObjectList(privateClause->v, converter, privateOperands);
376+
} else if (const auto *firstprivateClause =
377+
std::get_if<Fortran::parser::AccClause::Firstprivate>(
378+
&clause.u)) {
379+
genObjectList(firstprivateClause->v, converter, firstprivateOperands);
380+
}
381+
}
382+
383+
// Prepare the operand segement size attribute and the operands value range.
384+
SmallVector<Value, 8> operands;
385+
SmallVector<int32_t, 8> operandSegments;
386+
addOperand(operands, operandSegments, async);
387+
addOperands(operands, operandSegments, waitOperands);
388+
addOperand(operands, operandSegments, numGangs);
389+
addOperand(operands, operandSegments, numWorkers);
390+
addOperand(operands, operandSegments, vectorLength);
391+
addOperand(operands, operandSegments, ifCond);
392+
addOperand(operands, operandSegments, selfCond);
393+
addOperands(operands, operandSegments, reductionOperands);
394+
addOperands(operands, operandSegments, copyOperands);
395+
addOperands(operands, operandSegments, copyinOperands);
396+
addOperands(operands, operandSegments, copyinReadonlyOperands);
397+
addOperands(operands, operandSegments, copyoutOperands);
398+
addOperands(operands, operandSegments, copyoutZeroOperands);
399+
addOperands(operands, operandSegments, createOperands);
400+
addOperands(operands, operandSegments, createZeroOperands);
401+
addOperands(operands, operandSegments, noCreateOperands);
402+
addOperands(operands, operandSegments, presentOperands);
403+
addOperands(operands, operandSegments, devicePtrOperands);
404+
addOperands(operands, operandSegments, attachOperands);
405+
addOperands(operands, operandSegments, privateOperands);
406+
addOperands(operands, operandSegments, firstprivateOperands);
407+
408+
auto parallelOp = createRegionOp<mlir::acc::ParallelOp, mlir::acc::YieldOp>(
409+
firOpBuilder, currentLocation, operands, operandSegments);
410+
411+
if (addAsyncAttr)
412+
parallelOp.setAttr(mlir::acc::ParallelOp::getAsyncAttrName(),
413+
firOpBuilder.getUnitAttr());
414+
if (addWaitAttr)
415+
parallelOp.setAttr(mlir::acc::ParallelOp::getWaitAttrName(),
416+
firOpBuilder.getUnitAttr());
417+
if (addSelfAttr)
418+
parallelOp.setAttr(mlir::acc::ParallelOp::getSelfAttrName(),
419+
firOpBuilder.getUnitAttr());
420+
}
421+
422+
static void
423+
genACC(Fortran::lower::AbstractConverter &converter,
424+
Fortran::lower::pft::Evaluation &eval,
425+
const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
426+
const auto &beginBlockDirective =
427+
std::get<Fortran::parser::AccBeginBlockDirective>(blockConstruct.t);
428+
const auto &blockDirective =
429+
std::get<Fortran::parser::AccBlockDirective>(beginBlockDirective.t);
430+
const auto &accClauseList =
431+
std::get<Fortran::parser::AccClauseList>(beginBlockDirective.t);
432+
433+
if (blockDirective.v == llvm::acc::ACCD_parallel) {
434+
genACCParallelOp(converter, accClauseList);
435+
}
436+
}
437+
231438
void Fortran::lower::genOpenACCConstruct(
232439
Fortran::lower::AbstractConverter &converter,
233440
Fortran::lower::pft::Evaluation &eval,
@@ -236,7 +443,7 @@ void Fortran::lower::genOpenACCConstruct(
236443
std::visit(
237444
common::visitors{
238445
[&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
239-
TODO();
446+
genACC(converter, eval, blockConstruct);
240447
},
241448
[&](const Fortran::parser::OpenACCCombinedConstruct
242449
&combinedConstruct) { TODO(); },

0 commit comments

Comments
 (0)