Skip to content

Commit baaf0c9

Browse files
authored
[CodeGen] Support start/stop in CodeGenPassBuilder (llvm#70912)
Add `-start/stop-before/after` support for CodeGenPassBuilder. Part of llvm#69879.
1 parent 9d1dada commit baaf0c9

File tree

5 files changed

+206
-21
lines changed

5 files changed

+206
-21
lines changed

llvm/include/llvm/CodeGen/CodeGenPassBuilder.h

Lines changed: 113 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#include "llvm/CodeGen/ShadowStackGCLowering.h"
4545
#include "llvm/CodeGen/SjLjEHPrepare.h"
4646
#include "llvm/CodeGen/StackProtector.h"
47+
#include "llvm/CodeGen/TargetPassConfig.h"
4748
#include "llvm/CodeGen/UnreachableBlockElim.h"
4849
#include "llvm/CodeGen/WasmEHPrepare.h"
4950
#include "llvm/CodeGen/WinEHPrepare.h"
@@ -176,73 +177,80 @@ template <typename DerivedT> class CodeGenPassBuilder {
176177
// Function object to maintain state while adding codegen IR passes.
177178
class AddIRPass {
178179
public:
179-
AddIRPass(ModulePassManager &MPM) : MPM(MPM) {}
180+
AddIRPass(ModulePassManager &MPM, const DerivedT &PB) : MPM(MPM), PB(PB) {}
180181
~AddIRPass() {
181182
if (!FPM.isEmpty())
182183
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
183184
}
184185

185-
template <typename PassT> void operator()(PassT &&Pass) {
186+
template <typename PassT>
187+
void operator()(PassT &&Pass, StringRef Name = PassT::name()) {
186188
static_assert((is_detected<is_function_pass_t, PassT>::value ||
187189
is_detected<is_module_pass_t, PassT>::value) &&
188190
"Only module pass and function pass are supported.");
189191

192+
if (!PB.runBeforeAdding(Name))
193+
return;
194+
190195
// Add Function Pass
191196
if constexpr (is_detected<is_function_pass_t, PassT>::value) {
192197
FPM.addPass(std::forward<PassT>(Pass));
198+
199+
for (auto &C : PB.AfterCallbacks)
200+
C(Name);
193201
} else {
194202
// Add Module Pass
195203
if (!FPM.isEmpty()) {
196204
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
197205
FPM = FunctionPassManager();
198206
}
207+
199208
MPM.addPass(std::forward<PassT>(Pass));
209+
210+
for (auto &C : PB.AfterCallbacks)
211+
C(Name);
200212
}
201213
}
202214

203215
private:
204216
ModulePassManager &MPM;
205217
FunctionPassManager FPM;
218+
const DerivedT &PB;
206219
};
207220

208221
// Function object to maintain state while adding codegen machine passes.
209222
class AddMachinePass {
210223
public:
211-
AddMachinePass(MachineFunctionPassManager &PM) : PM(PM) {}
224+
AddMachinePass(MachineFunctionPassManager &PM, const DerivedT &PB)
225+
: PM(PM), PB(PB) {}
212226

213227
template <typename PassT> void operator()(PassT &&Pass) {
214228
static_assert(
215229
is_detected<has_key_t, PassT>::value,
216230
"Machine function pass must define a static member variable `Key`.");
217-
for (auto &C : BeforeCallbacks)
218-
if (!C(&PassT::Key))
219-
return;
231+
232+
if (!PB.runBeforeAdding(PassT::name()))
233+
return;
234+
220235
PM.addPass(std::forward<PassT>(Pass));
221-
for (auto &C : AfterCallbacks)
222-
C(&PassT::Key);
236+
237+
for (auto &C : PB.AfterCallbacks)
238+
C(PassT::name());
223239
}
224240

225241
template <typename PassT> void insertPass(MachinePassKey *ID, PassT Pass) {
226-
AfterCallbacks.emplace_back(
242+
PB.AfterCallbacks.emplace_back(
227243
[this, ID, Pass = std::move(Pass)](MachinePassKey *PassID) {
228244
if (PassID == ID)
229245
this->PM.addPass(std::move(Pass));
230246
});
231247
}
232248

233-
void disablePass(MachinePassKey *ID) {
234-
BeforeCallbacks.emplace_back(
235-
[ID](MachinePassKey *PassID) { return PassID != ID; });
236-
}
237-
238249
MachineFunctionPassManager releasePM() { return std::move(PM); }
239250

240251
private:
241252
MachineFunctionPassManager &PM;
242-
SmallVector<llvm::unique_function<bool(MachinePassKey *)>, 4>
243-
BeforeCallbacks;
244-
SmallVector<llvm::unique_function<void(MachinePassKey *)>, 4>
245-
AfterCallbacks;
253+
const DerivedT &PB;
246254
};
247255

248256
LLVMTargetMachine &TM;
@@ -473,20 +481,43 @@ template <typename DerivedT> class CodeGenPassBuilder {
473481
const DerivedT &derived() const {
474482
return static_cast<const DerivedT &>(*this);
475483
}
484+
485+
bool runBeforeAdding(StringRef Name) const {
486+
bool ShouldAdd = true;
487+
for (auto &C : BeforeCallbacks)
488+
ShouldAdd &= C(Name);
489+
return ShouldAdd;
490+
}
491+
492+
void setStartStopPasses(const TargetPassConfig::StartStopInfo &Info) const;
493+
494+
Error verifyStartStop(const TargetPassConfig::StartStopInfo &Info) const;
495+
496+
mutable SmallVector<llvm::unique_function<bool(StringRef)>, 4>
497+
BeforeCallbacks;
498+
mutable SmallVector<llvm::unique_function<void(StringRef)>, 4> AfterCallbacks;
499+
500+
/// Helper variable for `-start-before/-start-after/-stop-before/-stop-after`
501+
mutable bool Started = true;
502+
mutable bool Stopped = true;
476503
};
477504

478505
template <typename Derived>
479506
Error CodeGenPassBuilder<Derived>::buildPipeline(
480507
ModulePassManager &MPM, MachineFunctionPassManager &MFPM,
481508
raw_pwrite_stream &Out, raw_pwrite_stream *DwoOut,
482509
CodeGenFileType FileType) const {
483-
AddIRPass addIRPass(MPM);
510+
auto StartStopInfo = TargetPassConfig::getStartStopInfo(*PIC);
511+
if (!StartStopInfo)
512+
return StartStopInfo.takeError();
513+
setStartStopPasses(*StartStopInfo);
514+
AddIRPass addIRPass(MPM, derived());
484515
// `ProfileSummaryInfo` is always valid.
485516
addIRPass(RequireAnalysisPass<ProfileSummaryAnalysis, Module>());
486517
addIRPass(RequireAnalysisPass<CollectorMetadataAnalysis, Module>());
487518
addISelPasses(addIRPass);
488519

489-
AddMachinePass addPass(MFPM);
520+
AddMachinePass addPass(MFPM, derived());
490521
if (auto Err = addCoreISelPasses(addPass))
491522
return std::move(Err);
492523

@@ -499,6 +530,68 @@ Error CodeGenPassBuilder<Derived>::buildPipeline(
499530
});
500531

501532
addPass(FreeMachineFunctionPass());
533+
return verifyStartStop(*StartStopInfo);
534+
}
535+
536+
template <typename Derived>
537+
void CodeGenPassBuilder<Derived>::setStartStopPasses(
538+
const TargetPassConfig::StartStopInfo &Info) const {
539+
if (!Info.StartPass.empty()) {
540+
Started = false;
541+
BeforeCallbacks.emplace_back([this, &Info, AfterFlag = Info.StartAfter,
542+
Count = 0u](StringRef ClassName) mutable {
543+
if (Count == Info.StartInstanceNum) {
544+
if (AfterFlag) {
545+
AfterFlag = false;
546+
Started = true;
547+
}
548+
return Started;
549+
}
550+
551+
auto PassName = PIC->getPassNameForClassName(ClassName);
552+
if (Info.StartPass == PassName && ++Count == Info.StartInstanceNum)
553+
Started = !Info.StartAfter;
554+
555+
return Started;
556+
});
557+
}
558+
559+
if (!Info.StopPass.empty()) {
560+
Stopped = false;
561+
BeforeCallbacks.emplace_back([this, &Info, AfterFlag = Info.StopAfter,
562+
Count = 0u](StringRef ClassName) mutable {
563+
if (Count == Info.StopInstanceNum) {
564+
if (AfterFlag) {
565+
AfterFlag = false;
566+
Stopped = true;
567+
}
568+
return !Stopped;
569+
}
570+
571+
auto PassName = PIC->getPassNameForClassName(ClassName);
572+
if (Info.StopPass == PassName && ++Count == Info.StopInstanceNum)
573+
Stopped = !Info.StopAfter;
574+
return !Stopped;
575+
});
576+
}
577+
}
578+
579+
template <typename Derived>
580+
Error CodeGenPassBuilder<Derived>::verifyStartStop(
581+
const TargetPassConfig::StartStopInfo &Info) const {
582+
if (Started && Stopped)
583+
return Error::success();
584+
585+
if (!Started)
586+
return make_error<StringError>(
587+
"Can't find start pass \"" +
588+
PIC->getPassNameForClassName(Info.StartPass) + "\".",
589+
std::make_error_code(std::errc::invalid_argument));
590+
if (!Stopped)
591+
return make_error<StringError>(
592+
"Can't find stop pass \"" +
593+
PIC->getPassNameForClassName(Info.StopPass) + "\".",
594+
std::make_error_code(std::errc::invalid_argument));
502595
return Error::success();
503596
}
504597

llvm/include/llvm/CodeGen/TargetPassConfig.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "llvm/Pass.h"
1717
#include "llvm/Support/CodeGen.h"
18+
#include "llvm/Support/Error.h"
1819
#include <cassert>
1920
#include <string>
2021

@@ -176,6 +177,20 @@ class TargetPassConfig : public ImmutablePass {
176177
static std::string
177178
getLimitedCodeGenPipelineReason(const char *Separator = "/");
178179

180+
struct StartStopInfo {
181+
bool StartAfter;
182+
bool StopAfter;
183+
unsigned StartInstanceNum;
184+
unsigned StopInstanceNum;
185+
StringRef StartPass;
186+
StringRef StopPass;
187+
};
188+
189+
/// Returns pass name in `-stop-before` or `-stop-after`
190+
/// NOTE: New pass manager migration only
191+
static Expected<StartStopInfo>
192+
getStartStopInfo(PassInstrumentationCallbacks &PIC);
193+
179194
void setDisableVerify(bool Disable) { setOpt(DisableVerify, Disable); }
180195

181196
bool getEnableTailMerge() const { return EnableTailMerge; }

llvm/lib/CodeGen/TargetPassConfig.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,40 @@ void llvm::registerCodeGenCallback(PassInstrumentationCallbacks &PIC,
609609
registerPartialPipelineCallback(PIC, LLVMTM);
610610
}
611611

612+
Expected<TargetPassConfig::StartStopInfo>
613+
TargetPassConfig::getStartStopInfo(PassInstrumentationCallbacks &PIC) {
614+
auto [StartBefore, StartBeforeInstanceNum] =
615+
getPassNameAndInstanceNum(StartBeforeOpt);
616+
auto [StartAfter, StartAfterInstanceNum] =
617+
getPassNameAndInstanceNum(StartAfterOpt);
618+
auto [StopBefore, StopBeforeInstanceNum] =
619+
getPassNameAndInstanceNum(StopBeforeOpt);
620+
auto [StopAfter, StopAfterInstanceNum] =
621+
getPassNameAndInstanceNum(StopAfterOpt);
622+
623+
if (!StartBefore.empty() && !StartAfter.empty())
624+
return make_error<StringError>(
625+
Twine(StartBeforeOptName) + " and " + StartAfterOptName + " specified!",
626+
std::make_error_code(std::errc::invalid_argument));
627+
if (!StopBefore.empty() && !StopAfter.empty())
628+
return make_error<StringError>(
629+
Twine(StopBeforeOptName) + " and " + StopAfterOptName + " specified!",
630+
std::make_error_code(std::errc::invalid_argument));
631+
632+
StartStopInfo Result;
633+
Result.StartPass = StartBefore.empty() ? StartAfter : StartBefore;
634+
Result.StopPass = StopBefore.empty() ? StopAfter : StopBefore;
635+
Result.StartInstanceNum =
636+
StartBefore.empty() ? StartAfterInstanceNum : StartBeforeInstanceNum;
637+
Result.StopInstanceNum =
638+
StopBefore.empty() ? StopAfterInstanceNum : StopBeforeInstanceNum;
639+
Result.StartAfter = !StartAfter.empty();
640+
Result.StopAfter = !StopAfter.empty();
641+
Result.StartInstanceNum += Result.StartInstanceNum == 0;
642+
Result.StopInstanceNum += Result.StopInstanceNum == 0;
643+
return Result;
644+
}
645+
612646
// Out of line constructor provides default values for pass options and
613647
// registers all common codegen passes.
614648
TargetPassConfig::TargetPassConfig(LLVMTargetMachine &TM, PassManagerBase &pm)

llvm/lib/Passes/PassBuilder.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
#include "llvm/CodeGen/ShadowStackGCLowering.h"
9494
#include "llvm/CodeGen/SjLjEHPrepare.h"
9595
#include "llvm/CodeGen/StackProtector.h"
96+
#include "llvm/CodeGen/TargetPassConfig.h"
9697
#include "llvm/CodeGen/TypePromotion.h"
9798
#include "llvm/CodeGen/WasmEHPrepare.h"
9899
#include "llvm/CodeGen/WinEHPrepare.h"
@@ -316,7 +317,8 @@ namespace {
316317
/// We currently only use this for --print-before/after.
317318
bool shouldPopulateClassToPassNames() {
318319
return PrintPipelinePasses || !printBeforePasses().empty() ||
319-
!printAfterPasses().empty() || !isFilterPassesEmpty();
320+
!printAfterPasses().empty() || !isFilterPassesEmpty() ||
321+
TargetPassConfig::hasLimitedCodeGenPipeline();
320322
}
321323

322324
// A pass for testing -print-on-crash.

llvm/unittests/CodeGen/CodeGenPassBuilderTest.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,4 +138,45 @@ TEST_F(CodeGenPassBuilderTest, basic) {
138138
EXPECT_EQ(MIRPipeline, ExpectedMIRPipeline);
139139
}
140140

141+
// TODO: Move this to lit test when llc support new pm.
142+
TEST_F(CodeGenPassBuilderTest, start_stop) {
143+
static const char *argv[] = {
144+
"test",
145+
"-start-after=no-op-module",
146+
"-stop-before=no-op-function,2",
147+
};
148+
int argc = std::size(argv);
149+
cl::ParseCommandLineOptions(argc, argv);
150+
151+
LoopAnalysisManager LAM;
152+
FunctionAnalysisManager FAM;
153+
CGSCCAnalysisManager CGAM;
154+
ModuleAnalysisManager MAM;
155+
156+
PassInstrumentationCallbacks PIC;
157+
DummyCodeGenPassBuilder CGPB(*TM, getCGPassBuilderOption(), &PIC);
158+
PipelineTuningOptions PTO;
159+
PassBuilder PB(TM.get(), PTO, std::nullopt, &PIC);
160+
161+
PB.registerModuleAnalyses(MAM);
162+
PB.registerCGSCCAnalyses(CGAM);
163+
PB.registerFunctionAnalyses(FAM);
164+
PB.registerLoopAnalyses(LAM);
165+
PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
166+
167+
ModulePassManager MPM;
168+
MachineFunctionPassManager MFPM;
169+
170+
Error Err =
171+
CGPB.buildPipeline(MPM, MFPM, outs(), nullptr, CodeGenFileType::Null);
172+
EXPECT_FALSE(Err);
173+
std::string IRPipeline;
174+
raw_string_ostream IROS(IRPipeline);
175+
MPM.printPipeline(IROS, [&PIC](StringRef Name) {
176+
auto PassName = PIC.getPassNameForClassName(Name);
177+
return PassName.empty() ? Name : PassName;
178+
});
179+
EXPECT_EQ(IRPipeline, "function(no-op-function)");
180+
}
181+
141182
} // namespace

0 commit comments

Comments
 (0)