Skip to content

Commit 8286743

Browse files
authored
[flang][openacc] Allow acc routine at the top level (llvm#69936)
Some compilers allow the `$acc routine(<name>)` to be placed at the program unit level. To be compatible, this patch enables the use of acc routine at this level. These acc routine directives must have a name.
1 parent 93f8e52 commit 8286743

15 files changed

+150
-28
lines changed

flang/docs/OpenACC.md

+1
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@ local:
2323
warning instead of an error as other compiler accepts it.
2424
* The `if` clause accepts scalar integer expression in addition to scalar
2525
logical expression.
26+
* `!$acc routine` directive can be placed at the top level.

flang/include/flang/Lower/OpenACC.h

+6
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ namespace Fortran {
3737
namespace parser {
3838
struct OpenACCConstruct;
3939
struct OpenACCDeclarativeConstruct;
40+
struct OpenACCRoutineConstruct;
4041
} // namespace parser
4142

4243
namespace semantics {
@@ -71,6 +72,11 @@ void genOpenACCDeclarativeConstruct(AbstractConverter &,
7172
StatementContext &,
7273
const parser::OpenACCDeclarativeConstruct &,
7374
AccRoutineInfoMappingList &);
75+
void genOpenACCRoutineConstruct(AbstractConverter &,
76+
Fortran::semantics::SemanticsContext &,
77+
mlir::ModuleOp &,
78+
const parser::OpenACCRoutineConstruct &,
79+
AccRoutineInfoMappingList &);
7480

7581
void finalizeOpenACCRoutineAttachment(mlir::ModuleOp &,
7682
AccRoutineInfoMappingList &);

flang/include/flang/Lower/PFTBuilder.h

+14-2
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ using Constructs =
135135

136136
using Directives =
137137
std::tuple<parser::CompilerDirective, parser::OpenACCConstruct,
138+
parser::OpenACCRoutineConstruct,
138139
parser::OpenACCDeclarativeConstruct, parser::OpenMPConstruct,
139140
parser::OpenMPDeclarativeConstruct, parser::OmpEndLoopDirective>;
140141

@@ -360,7 +361,8 @@ using ProgramVariant =
360361
ReferenceVariant<parser::MainProgram, parser::FunctionSubprogram,
361362
parser::SubroutineSubprogram, parser::Module,
362363
parser::Submodule, parser::SeparateModuleSubprogram,
363-
parser::BlockData, parser::CompilerDirective>;
364+
parser::BlockData, parser::CompilerDirective,
365+
parser::OpenACCRoutineConstruct>;
364366
/// A program is a list of program units.
365367
/// These units can be function like, module like, or block data.
366368
struct ProgramUnit : ProgramVariant {
@@ -763,10 +765,20 @@ struct CompilerDirectiveUnit : public ProgramUnit {
763765
CompilerDirectiveUnit(const CompilerDirectiveUnit &) = delete;
764766
};
765767

768+
// Top level OpenACC routine directives
769+
struct OpenACCDirectiveUnit : public ProgramUnit {
770+
OpenACCDirectiveUnit(const parser::OpenACCRoutineConstruct &directive,
771+
const PftNode &parent)
772+
: ProgramUnit{directive, parent}, routine{directive} {};
773+
OpenACCDirectiveUnit(OpenACCDirectiveUnit &&) = default;
774+
OpenACCDirectiveUnit(const OpenACCDirectiveUnit &) = delete;
775+
const parser::OpenACCRoutineConstruct &routine;
776+
};
777+
766778
/// A Program is the top-level root of the PFT.
767779
struct Program {
768780
using Units = std::variant<FunctionLikeUnit, ModuleLikeUnit, BlockDataUnit,
769-
CompilerDirectiveUnit>;
781+
CompilerDirectiveUnit, OpenACCDirectiveUnit>;
770782

771783
Program(semantics::CommonBlockList &&commonBlocks)
772784
: commonBlocks{std::move(commonBlocks)} {}

flang/include/flang/Parser/parse-tree.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ struct PauseStmt;
262262
struct OpenACCConstruct;
263263
struct AccEndCombinedDirective;
264264
struct OpenACCDeclarativeConstruct;
265+
struct OpenACCRoutineConstruct;
265266
struct OpenMPConstruct;
266267
struct OpenMPDeclarativeConstruct;
267268
struct OmpEndLoopDirective;
@@ -558,7 +559,8 @@ struct ProgramUnit {
558559
common::Indirection<FunctionSubprogram>,
559560
common::Indirection<SubroutineSubprogram>, common::Indirection<Module>,
560561
common::Indirection<Submodule>, common::Indirection<BlockData>,
561-
common::Indirection<CompilerDirective>>
562+
common::Indirection<CompilerDirective>,
563+
common::Indirection<OpenACCRoutineConstruct>>
562564
u;
563565
};
564566

flang/lib/Lower/Bridge.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
316316
globalOmpRequiresSymbol = b.symTab.symbol();
317317
},
318318
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
319+
[&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {},
319320
},
320321
u);
321322
}
@@ -328,6 +329,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
328329
[&](Fortran::lower::pft::ModuleLikeUnit &m) { lowerMod(m); },
329330
[&](Fortran::lower::pft::BlockDataUnit &b) {},
330331
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
332+
[&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {
333+
builder = new fir::FirOpBuilder(bridge.getModule(),
334+
bridge.getKindMap());
335+
Fortran::lower::genOpenACCRoutineConstruct(
336+
*this, bridge.getSemanticsContext(), bridge.getModule(),
337+
d.routine, accRoutineInfos);
338+
builder = nullptr;
339+
},
331340
},
332341
u);
333342
}
@@ -2362,6 +2371,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
23622371
genFIR(e);
23632372
}
23642373

2374+
void genFIR(const Fortran::parser::OpenACCRoutineConstruct &acc) {
2375+
// Handled by genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &)
2376+
}
2377+
23652378
void genFIR(const Fortran::parser::OpenMPConstruct &omp) {
23662379
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
23672380
localSymbols.pushScope();

flang/lib/Lower/OpenACC.cpp

+11-11
Original file line numberDiff line numberDiff line change
@@ -3153,29 +3153,26 @@ static void attachRoutineInfo(mlir::func::FuncOp func,
31533153
mlir::acc::RoutineInfoAttr::get(func.getContext(), routines));
31543154
}
31553155

3156-
static void
3157-
genACC(Fortran::lower::AbstractConverter &converter,
3158-
Fortran::semantics::SemanticsContext &semanticsContext,
3159-
const Fortran::parser::OpenACCRoutineConstruct &routineConstruct,
3160-
Fortran::lower::AccRoutineInfoMappingList &accRoutineInfos) {
3156+
void Fortran::lower::genOpenACCRoutineConstruct(
3157+
Fortran::lower::AbstractConverter &converter,
3158+
Fortran::semantics::SemanticsContext &semanticsContext, mlir::ModuleOp &mod,
3159+
const Fortran::parser::OpenACCRoutineConstruct &routineConstruct,
3160+
Fortran::lower::AccRoutineInfoMappingList &accRoutineInfos) {
31613161
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
31623162
mlir::Location loc = converter.genLocation(routineConstruct.source);
31633163
std::optional<Fortran::parser::Name> name =
31643164
std::get<std::optional<Fortran::parser::Name>>(routineConstruct.t);
31653165
const auto &clauses =
31663166
std::get<Fortran::parser::AccClauseList>(routineConstruct.t);
3167-
3168-
mlir::ModuleOp mod = builder.getModule();
31693167
mlir::func::FuncOp funcOp;
31703168
std::string funcName;
31713169
if (name) {
31723170
funcName = converter.mangleName(*name->symbol);
3173-
funcOp = builder.getNamedFunction(funcName);
3171+
funcOp = builder.getNamedFunction(mod, funcName);
31743172
} else {
31753173
funcOp = builder.getFunction();
31763174
funcName = funcOp.getName();
31773175
}
3178-
31793176
bool hasSeq = false, hasGang = false, hasWorker = false, hasVector = false,
31803177
hasNohost = false;
31813178
std::optional<std::string> bindName = std::nullopt;
@@ -3391,8 +3388,11 @@ void Fortran::lower::genOpenACCDeclarativeConstruct(
33913388
},
33923389
[&](const Fortran::parser::OpenACCRoutineConstruct
33933390
&routineConstruct) {
3394-
genACC(converter, semanticsContext, routineConstruct,
3395-
accRoutineInfos);
3391+
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
3392+
mlir::ModuleOp mod = builder.getModule();
3393+
Fortran::lower::genOpenACCRoutineConstruct(
3394+
converter, semanticsContext, mod, routineConstruct,
3395+
accRoutineInfos);
33963396
},
33973397
},
33983398
accDeclConstruct.u);

flang/lib/Lower/PFTBuilder.cpp

+24
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,17 @@ class PFTBuilder {
241241
return enterConstructOrDirective(directive);
242242
}
243243

244+
bool Pre(const parser::OpenACCRoutineConstruct &directive) {
245+
assert(pftParentStack.size() > 0 &&
246+
"At least the Program must be a parent");
247+
if (pftParentStack.back().isA<lower::pft::Program>()) {
248+
addUnit(
249+
lower::pft::OpenACCDirectiveUnit(directive, pftParentStack.back()));
250+
return false;
251+
}
252+
return enterConstructOrDirective(directive);
253+
}
254+
244255
private:
245256
/// Initialize a new module-like unit and make it the builder's focus.
246257
template <typename A>
@@ -1133,6 +1144,9 @@ class PFTDumper {
11331144
[&](const lower::pft::CompilerDirectiveUnit &unit) {
11341145
dumpCompilerDirectiveUnit(outputStream, unit);
11351146
},
1147+
[&](const lower::pft::OpenACCDirectiveUnit &unit) {
1148+
dumpOpenACCDirectiveUnit(outputStream, unit);
1149+
},
11361150
},
11371151
unit);
11381152
}
@@ -1280,6 +1294,16 @@ class PFTDumper {
12801294
outputStream << "\nEnd CompilerDirective\n\n";
12811295
}
12821296

1297+
void
1298+
dumpOpenACCDirectiveUnit(llvm::raw_ostream &outputStream,
1299+
const lower::pft::OpenACCDirectiveUnit &directive) {
1300+
outputStream << getNodeIndex(directive) << " ";
1301+
outputStream << "OpenACCDirective: !$acc ";
1302+
outputStream << directive.get<Fortran::parser::OpenACCRoutineConstruct>()
1303+
.source.ToString();
1304+
outputStream << "\nEnd OpenACCDirective\n\n";
1305+
}
1306+
12831307
template <typename T>
12841308
std::size_t getNodeIndex(const T &node) {
12851309
auto addr = static_cast<const void *>(&node);

flang/lib/Parser/program-parsers.cpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ static constexpr auto normalProgramUnit{StartNewSubprogram{} >> programUnit /
4646
static constexpr auto globalCompilerDirective{
4747
construct<ProgramUnit>(indirect(compilerDirective))};
4848

49+
static constexpr auto globalOpenACCCompilerDirective{
50+
construct<ProgramUnit>(indirect(skipStuffBeforeStatement >>
51+
"!$ACC "_sptok >> Parser<OpenACCRoutineConstruct>{}))};
52+
4953
// R501 program -> program-unit [program-unit]...
5054
// This is the top-level production for the Fortran language.
5155
// F'2018 6.3.1 defines a program unit as a sequence of one or more lines,
@@ -58,7 +62,8 @@ TYPE_PARSER(
5862
"nonstandard usage: empty source file"_port_en_US,
5963
skipStuffBeforeStatement >> !nextCh >>
6064
pure<std::list<ProgramUnit>>()) ||
61-
some(globalCompilerDirective || normalProgramUnit) /
65+
some(globalCompilerDirective || globalOpenACCCompilerDirective ||
66+
normalProgramUnit) /
6267
skipStuffBeforeStatement))
6368

6469
// R504 specification-part ->

flang/lib/Semantics/program-tree.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,10 @@ ProgramTree ProgramTree::Build(const parser::CompilerDirective &) {
200200
DIE("ProgramTree::Build() called for CompilerDirective");
201201
}
202202

203+
ProgramTree ProgramTree::Build(const parser::OpenACCRoutineConstruct &) {
204+
DIE("ProgramTree::Build() called for OpenACCRoutineConstruct");
205+
}
206+
203207
const parser::ParentIdentifier &ProgramTree::GetParentId() const {
204208
const auto *stmt{
205209
std::get<const parser::Statement<parser::SubmoduleStmt> *>(stmt_)};

flang/lib/Semantics/program-tree.h

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class ProgramTree {
4343
static ProgramTree Build(const parser::Submodule &);
4444
static ProgramTree Build(const parser::BlockData &);
4545
static ProgramTree Build(const parser::CompilerDirective &);
46+
static ProgramTree Build(const parser::OpenACCRoutineConstruct &);
4647

4748
ENUM_CLASS(Kind, // kind of node
4849
Program, Function, Subroutine, MpSubprogram, Module, Submodule, BlockData)

flang/lib/Semantics/resolve-directives.cpp

+19-10
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ template <typename T> class DirectiveAttributeVisitor {
6161
? std::nullopt
6262
: std::make_optional<DirContext>(dirContext_.back());
6363
}
64+
void PushContext(const parser::CharBlock &source, T dir, Scope &scope) {
65+
dirContext_.emplace_back(source, dir, scope);
66+
}
6467
void PushContext(const parser::CharBlock &source, T dir) {
6568
dirContext_.emplace_back(source, dir, context_.FindScope(source));
6669
}
@@ -115,8 +118,8 @@ template <typename T> class DirectiveAttributeVisitor {
115118

116119
class AccAttributeVisitor : DirectiveAttributeVisitor<llvm::acc::Directive> {
117120
public:
118-
explicit AccAttributeVisitor(SemanticsContext &context)
119-
: DirectiveAttributeVisitor(context) {}
121+
explicit AccAttributeVisitor(SemanticsContext &context, Scope *topScope)
122+
: DirectiveAttributeVisitor(context), topScope_(topScope) {}
120123

121124
template <typename A> void Walk(const A &x) { parser::Walk(x, *this); }
122125
template <typename A> bool Pre(const A &) { return true; }
@@ -281,6 +284,7 @@ class AccAttributeVisitor : DirectiveAttributeVisitor<llvm::acc::Directive> {
281284
const llvm::acc::Clause clause, const parser::AccObjectList &objectList);
282285
void AddRoutineInfoToSymbol(
283286
Symbol &, const parser::OpenACCRoutineConstruct &);
287+
Scope *topScope_;
284288
};
285289

286290
// Data-sharing and Data-mapping attributes for data-refs in OpenMP construct
@@ -802,10 +806,6 @@ bool AccAttributeVisitor::Pre(const parser::OpenACCDeclarativeConstruct &x) {
802806
const auto &declDir{
803807
std::get<parser::AccDeclarativeDirective>(declConstruct->t)};
804808
PushContext(declDir.source, llvm::acc::Directive::ACCD_declare);
805-
} else if (const auto *routineConstruct{
806-
std::get_if<parser::OpenACCRoutineConstruct>(&x.u)}) {
807-
const auto &verbatim{std::get<parser::Verbatim>(routineConstruct->t)};
808-
PushContext(verbatim.source, llvm::acc::Directive::ACCD_routine);
809809
}
810810
ClearDataSharingAttributeObjects();
811811
return true;
@@ -994,6 +994,13 @@ void AccAttributeVisitor::AddRoutineInfoToSymbol(
994994
}
995995

996996
bool AccAttributeVisitor::Pre(const parser::OpenACCRoutineConstruct &x) {
997+
const auto &verbatim{std::get<parser::Verbatim>(x.t)};
998+
if (topScope_) {
999+
PushContext(
1000+
verbatim.source, llvm::acc::Directive::ACCD_routine, *topScope_);
1001+
} else {
1002+
PushContext(verbatim.source, llvm::acc::Directive::ACCD_routine);
1003+
}
9971004
const auto &optName{std::get<std::optional<parser::Name>>(x.t)};
9981005
if (optName) {
9991006
if (Symbol *sym = ResolveFctName(*optName)) {
@@ -1005,7 +1012,9 @@ bool AccAttributeVisitor::Pre(const parser::OpenACCRoutineConstruct &x) {
10051012
(*optName).source);
10061013
}
10071014
} else {
1008-
AddRoutineInfoToSymbol(*currScope().symbol(), x);
1015+
if (currScope().symbol()) {
1016+
AddRoutineInfoToSymbol(*currScope().symbol(), x);
1017+
}
10091018
}
10101019
return true;
10111020
}
@@ -2190,10 +2199,10 @@ void OmpAttributeVisitor::CheckMultipleAppearances(
21902199
}
21912200
}
21922201

2193-
void ResolveAccParts(
2194-
SemanticsContext &context, const parser::ProgramUnit &node) {
2202+
void ResolveAccParts(SemanticsContext &context, const parser::ProgramUnit &node,
2203+
Scope *topScope) {
21952204
if (context.IsEnabled(common::LanguageFeature::OpenACC)) {
2196-
AccAttributeVisitor{context}.Walk(node);
2205+
AccAttributeVisitor{context, topScope}.Walk(node);
21972206
}
21982207
}
21992208

flang/lib/Semantics/resolve-directives.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@ struct ProgramUnit;
1616
} // namespace Fortran::parser
1717

1818
namespace Fortran::semantics {
19-
19+
class Scope;
2020
class SemanticsContext;
2121

2222
// Name resolution for OpenACC and OpenMP directives
23-
void ResolveAccParts(SemanticsContext &, const parser::ProgramUnit &);
23+
void ResolveAccParts(
24+
SemanticsContext &, const parser::ProgramUnit &, Scope *topScope = {});
2425
void ResolveOmpParts(SemanticsContext &, const parser::ProgramUnit &);
2526
void ResolveOmpTopLevelParts(SemanticsContext &, const parser::Program &);
2627

flang/lib/Semantics/resolve-names.cpp

+7-1
Original file line numberDiff line numberDiff line change
@@ -8323,6 +8323,11 @@ bool ResolveNamesVisitor::Pre(const parser::ProgramUnit &x) {
83238323
// TODO: global directives
83248324
return true;
83258325
}
8326+
if (std::holds_alternative<
8327+
common::Indirection<parser::OpenACCRoutineConstruct>>(x.u)) {
8328+
ResolveAccParts(context(), x, &topScope_);
8329+
return false;
8330+
}
83268331
auto root{ProgramTree::Build(x)};
83278332
SetScope(topScope_);
83288333
ResolveSpecificationParts(root);
@@ -8335,7 +8340,8 @@ bool ResolveNamesVisitor::Pre(const parser::ProgramUnit &x) {
83358340

83368341
template <typename A> std::set<SourceName> GetUses(const A &x) {
83378342
std::set<SourceName> uses;
8338-
if constexpr (!std::is_same_v<A, parser::CompilerDirective>) {
8343+
if constexpr (!std::is_same_v<A, parser::CompilerDirective> &&
8344+
!std::is_same_v<A, parser::OpenACCRoutineConstruct>) {
83398345
const auto &spec{std::get<parser::SpecificationPart>(x.t)};
83408346
const auto &unitUses{std::get<
83418347
std::list<parser::Statement<common::Indirection<parser::UseStmt>>>>(
+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
! This test checks lowering of OpenACC routine directive.
2+
3+
! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s
4+
! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
5+
6+
subroutine sub1(a, n)
7+
integer :: n
8+
real :: a(n)
9+
end subroutine sub1
10+
11+
!$acc routine(sub1)
12+
13+
program test
14+
integer, parameter :: N = 10
15+
real :: a(N)
16+
call sub1(a, N)
17+
end program
18+
19+
! CHECK-LABEL: acc.routine @acc_routine_0 func(@_QPsub1)
20+
21+
! CHECK: func.func @_QPsub1(%ar{{.*}}: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "a"}, %arg1: !fir.ref<i32> {fir.bindc_name = "n"}) attributes {acc.routine_info = #acc.routine_info<[@acc_routine_0]>}

0 commit comments

Comments
 (0)