Skip to content

Commit 0d8df98

Browse files
committed
[mlir] Allow for using OpPassManager in pass options
This significantly simplifies the boilerplate necessary for passes to define nested pass pipelines. Differential Revision: https://reviews.llvm.org/D122880
1 parent 6edef13 commit 0d8df98

File tree

10 files changed

+219
-37
lines changed

10 files changed

+219
-37
lines changed

mlir/include/mlir/Pass/PassManager.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ class OpPassManager {
7373
return {begin(), end()};
7474
}
7575

76+
/// Returns true if the pass manager has no passes.
77+
bool empty() const { return begin() == end(); }
78+
7679
/// Nest a new operation pass manager for the given operation kind under this
7780
/// pass manager.
7881
OpPassManager &nest(StringAttr nestedName);
@@ -110,7 +113,7 @@ class OpPassManager {
110113
/// of pipelines.
111114
/// Note: The quality of the string representation depends entirely on the
112115
/// the correctness of per-pass overrides of Pass::printAsTextualPipeline.
113-
void printAsTextualPipeline(raw_ostream &os);
116+
void printAsTextualPipeline(raw_ostream &os) const;
114117

115118
/// Raw dump of the pass manager to llvm::errs().
116119
void dump();

mlir/include/mlir/Pass/PassOptions.h

Lines changed: 92 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
#include <memory>
2424

2525
namespace mlir {
26+
class OpPassManager;
27+
2628
namespace detail {
2729
namespace pass_options {
2830
/// Parse a string containing a list of comma-delimited elements, invoking the
@@ -158,7 +160,7 @@ class PassOptions : protected llvm::cl::SubCommand {
158160
public OptionBase {
159161
public:
160162
template <typename... Args>
161-
Option(PassOptions &parent, StringRef arg, Args &&... args)
163+
Option(PassOptions &parent, StringRef arg, Args &&...args)
162164
: llvm::cl::opt<DataType, /*ExternalStorage=*/false, OptionParser>(
163165
arg, llvm::cl::sub(parent), std::forward<Args>(args)...) {
164166
assert(!this->isPositional() && !this->isSink() &&
@@ -319,7 +321,8 @@ class PassOptions : protected llvm::cl::SubCommand {
319321
/// struct MyPipelineOptions : PassPipelineOptions<MyPassOptions> {
320322
/// ListOption<int> someListFlag{*this, "flag-name", llvm::cl::desc("...")};
321323
/// };
322-
template <typename T> class PassPipelineOptions : public detail::PassOptions {
324+
template <typename T>
325+
class PassPipelineOptions : public detail::PassOptions {
323326
public:
324327
/// Factory that parses the provided options and returns a unique_ptr to the
325328
/// struct.
@@ -335,7 +338,6 @@ template <typename T> class PassPipelineOptions : public detail::PassOptions {
335338
/// any options.
336339
struct EmptyPipelineOptions : public PassPipelineOptions<EmptyPipelineOptions> {
337340
};
338-
339341
} // namespace mlir
340342

341343
//===----------------------------------------------------------------------===//
@@ -407,8 +409,92 @@ class parser<SmallVector<T, N>>
407409
public:
408410
parser(Option &opt) : detail::VectorParserBase<SmallVector<T, N>, T>(opt) {}
409411
};
410-
} // end namespace cl
411-
} // end namespace llvm
412412

413-
#endif // MLIR_PASS_PASSOPTIONS_H_
413+
//===----------------------------------------------------------------------===//
414+
// OpPassManager: OptionValue
414415

416+
template <>
417+
struct OptionValue<mlir::OpPassManager> final : GenericOptionValue {
418+
using WrapperType = mlir::OpPassManager;
419+
420+
OptionValue();
421+
OptionValue(const mlir::OpPassManager &value);
422+
OptionValue<mlir::OpPassManager> &operator=(const mlir::OpPassManager &rhs);
423+
~OptionValue();
424+
425+
/// Returns if the current option has a value.
426+
bool hasValue() const { return value.get(); }
427+
428+
/// Returns the current value of the option.
429+
mlir::OpPassManager &getValue() const {
430+
assert(hasValue() && "invalid option value");
431+
return *value;
432+
}
433+
434+
/// Set the value of the option.
435+
void setValue(const mlir::OpPassManager &newValue);
436+
void setValue(StringRef pipelineStr);
437+
438+
/// Compare the option with the provided value.
439+
bool compare(const mlir::OpPassManager &rhs) const;
440+
bool compare(const GenericOptionValue &rhs) const override {
441+
const auto &rhsOV =
442+
static_cast<const OptionValue<mlir::OpPassManager> &>(rhs);
443+
if (!rhsOV.hasValue())
444+
return false;
445+
return compare(rhsOV.getValue());
446+
}
447+
448+
private:
449+
void anchor() override;
450+
451+
/// The underlying pass manager. We use a unique_ptr to avoid the need for the
452+
/// full type definition.
453+
std::unique_ptr<mlir::OpPassManager> value;
454+
};
455+
456+
//===----------------------------------------------------------------------===//
457+
// OpPassManager: Parser
458+
459+
extern template class basic_parser<mlir::OpPassManager>;
460+
461+
template <>
462+
class parser<mlir::OpPassManager> : public basic_parser<mlir::OpPassManager> {
463+
public:
464+
/// A utility struct used when parsing a pass manager that prevents the need
465+
/// for a default constructor on OpPassManager.
466+
struct ParsedPassManager {
467+
ParsedPassManager();
468+
ParsedPassManager(ParsedPassManager &&);
469+
~ParsedPassManager();
470+
operator const mlir::OpPassManager &() const {
471+
assert(value && "parsed value was invalid");
472+
return *value;
473+
}
474+
475+
std::unique_ptr<mlir::OpPassManager> value;
476+
};
477+
using parser_data_type = ParsedPassManager;
478+
using OptVal = OptionValue<mlir::OpPassManager>;
479+
480+
parser(Option &opt) : basic_parser(opt) {}
481+
482+
bool parse(Option &, StringRef, StringRef arg, ParsedPassManager &value);
483+
484+
/// Print an instance of the underling option value to the given stream.
485+
static void print(raw_ostream &os, const mlir::OpPassManager &value);
486+
487+
// Overload in subclass to provide a better default value.
488+
StringRef getValueName() const override { return "pass-manager"; }
489+
490+
void printOptionDiff(const Option &opt, mlir::OpPassManager &pm,
491+
const OptVal &defaultValue, size_t globalWidth) const;
492+
493+
// An out-of-line virtual method to provide a 'home' for this class.
494+
void anchor() override;
495+
};
496+
497+
} // namespace cl
498+
} // namespace llvm
499+
500+
#endif // MLIR_PASS_PASSOPTIONS_H_

mlir/include/mlir/Transforms/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def Inliner : Pass<"inline"> {
8383
let options = [
8484
Option<"defaultPipelineStr", "default-pipeline", "std::string",
8585
/*default=*/"", "The default optimizer pipeline used for callables">,
86-
ListOption<"opPipelineStrs", "op-pipelines", "std::string",
86+
ListOption<"opPipelineList", "op-pipelines", "OpPassManager",
8787
"Callable operation specific optimizer pipelines (in the form "
8888
"of `dialect.op(pipeline)`)">,
8989
Option<"maxInliningIterations", "max-iterations", "unsigned",

mlir/lib/Pass/Pass.cpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,14 @@ void Pass::copyOptionValuesFrom(const Pass *other) {
5454
void Pass::printAsTextualPipeline(raw_ostream &os) {
5555
// Special case for adaptors to use the 'op_name(sub_passes)' format.
5656
if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(this)) {
57-
llvm::interleaveComma(adaptor->getPassManagers(), os,
58-
[&](OpPassManager &pm) {
59-
os << pm.getOpName() << "(";
60-
pm.printAsTextualPipeline(os);
61-
os << ")";
62-
});
57+
llvm::interleave(
58+
adaptor->getPassManagers(),
59+
[&](OpPassManager &pm) {
60+
os << pm.getOpName() << "(";
61+
pm.printAsTextualPipeline(os);
62+
os << ")";
63+
},
64+
[&] { os << ","; });
6365
return;
6466
}
6567
// Otherwise, print the pass argument followed by its options. If the pass
@@ -295,14 +297,17 @@ OperationName OpPassManager::getOpName(MLIRContext &context) const {
295297
/// Prints out the given passes as the textual representation of a pipeline.
296298
static void printAsTextualPipeline(ArrayRef<std::unique_ptr<Pass>> passes,
297299
raw_ostream &os) {
298-
llvm::interleaveComma(passes, os, [&](const std::unique_ptr<Pass> &pass) {
299-
pass->printAsTextualPipeline(os);
300-
});
300+
llvm::interleave(
301+
passes,
302+
[&](const std::unique_ptr<Pass> &pass) {
303+
pass->printAsTextualPipeline(os);
304+
},
305+
[&] { os << ","; });
301306
}
302307

303308
/// Prints out the passes of the pass manager as the textual representation
304309
/// of pipelines.
305-
void OpPassManager::printAsTextualPipeline(raw_ostream &os) {
310+
void OpPassManager::printAsTextualPipeline(raw_ostream &os) const {
306311
::printAsTextualPipeline(impl->passes, os);
307312
}
308313

mlir/lib/Pass/PassRegistry.cpp

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,104 @@ size_t detail::PassOptions::getOptionWidth() const {
332332
return max;
333333
}
334334

335+
//===----------------------------------------------------------------------===//
336+
// MLIR Options
337+
//===----------------------------------------------------------------------===//
338+
339+
//===----------------------------------------------------------------------===//
340+
// OpPassManager: OptionValue
341+
342+
llvm::cl::OptionValue<OpPassManager>::OptionValue() = default;
343+
llvm::cl::OptionValue<OpPassManager>::OptionValue(
344+
const mlir::OpPassManager &value) {
345+
setValue(value);
346+
}
347+
llvm::cl::OptionValue<OpPassManager> &
348+
llvm::cl::OptionValue<OpPassManager>::operator=(
349+
const mlir::OpPassManager &rhs) {
350+
setValue(rhs);
351+
return *this;
352+
}
353+
354+
llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() = default;
355+
356+
void llvm::cl::OptionValue<OpPassManager>::setValue(
357+
const OpPassManager &newValue) {
358+
if (hasValue())
359+
*value = newValue;
360+
else
361+
value = std::make_unique<mlir::OpPassManager>(newValue);
362+
}
363+
void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {
364+
FailureOr<OpPassManager> pipeline = parsePassPipeline(pipelineStr);
365+
assert(succeeded(pipeline) && "invalid pass pipeline");
366+
setValue(*pipeline);
367+
}
368+
369+
bool llvm::cl::OptionValue<OpPassManager>::compare(
370+
const mlir::OpPassManager &rhs) const {
371+
std::string lhsStr, rhsStr;
372+
{
373+
raw_string_ostream lhsStream(lhsStr);
374+
value->printAsTextualPipeline(lhsStream);
375+
376+
raw_string_ostream rhsStream(rhsStr);
377+
rhs.printAsTextualPipeline(rhsStream);
378+
}
379+
380+
// Use the textual format for pipeline comparisons.
381+
return lhsStr == rhsStr;
382+
}
383+
384+
void llvm::cl::OptionValue<OpPassManager>::anchor() {}
385+
386+
//===----------------------------------------------------------------------===//
387+
// OpPassManager: Parser
388+
389+
namespace llvm {
390+
namespace cl {
391+
template class basic_parser<OpPassManager>;
392+
} // namespace cl
393+
} // namespace llvm
394+
395+
bool llvm::cl::parser<OpPassManager>::parse(Option &, StringRef, StringRef arg,
396+
ParsedPassManager &value) {
397+
FailureOr<OpPassManager> pipeline = parsePassPipeline(arg);
398+
if (failed(pipeline))
399+
return true;
400+
value.value = std::make_unique<OpPassManager>(std::move(*pipeline));
401+
return false;
402+
}
403+
404+
void llvm::cl::parser<OpPassManager>::print(raw_ostream &os,
405+
const OpPassManager &value) {
406+
value.printAsTextualPipeline(os);
407+
}
408+
409+
void llvm::cl::parser<OpPassManager>::printOptionDiff(
410+
const Option &opt, OpPassManager &pm, const OptVal &defaultValue,
411+
size_t globalWidth) const {
412+
printOptionName(opt, globalWidth);
413+
outs() << "= ";
414+
pm.printAsTextualPipeline(outs());
415+
416+
if (defaultValue.hasValue()) {
417+
outs().indent(2) << " (default: ";
418+
defaultValue.getValue().printAsTextualPipeline(outs());
419+
outs() << ")";
420+
}
421+
outs() << "\n";
422+
}
423+
424+
void llvm::cl::parser<OpPassManager>::anchor() {}
425+
426+
llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager() =
427+
default;
428+
llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager(
429+
ParsedPassManager &&) = default;
430+
llvm::cl::parser<OpPassManager>::ParsedPassManager::~ParsedPassManager() =
431+
default;
432+
335433
//===----------------------------------------------------------------------===//
336434
// TextualPassPipeline Parser
337435
//===----------------------------------------------------------------------===//

mlir/lib/Transforms/Inliner.cpp

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -585,14 +585,8 @@ InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
585585
return;
586586

587587
// Update the option for the op specific optimization pipelines.
588-
for (auto &it : opPipelines) {
589-
std::string pipeline;
590-
llvm::raw_string_ostream pipelineOS(pipeline);
591-
pipelineOS << it.getKey() << "(";
592-
it.second.printAsTextualPipeline(pipelineOS);
593-
pipelineOS << ")";
594-
opPipelineStrs.addValue(pipeline);
595-
}
588+
for (auto &it : opPipelines)
589+
opPipelineList.addValue(it.second);
596590
this->opPipelines.emplace_back(std::move(opPipelines));
597591
}
598592

@@ -751,15 +745,9 @@ LogicalResult InlinerPass::initializeOptions(StringRef options) {
751745

752746
// Initialize the op specific pass pipelines.
753747
llvm::StringMap<OpPassManager> pipelines;
754-
for (StringRef pipeline : opPipelineStrs) {
755-
// Skip empty pipelines.
756-
if (pipeline.empty())
757-
continue;
758-
FailureOr<OpPassManager> pm = parsePassPipeline(pipeline);
759-
if (failed(pm))
760-
return failure();
761-
pipelines.try_emplace(pm->getOpName(), std::move(*pm));
762-
}
748+
for (OpPassManager pipeline : opPipelineList)
749+
if (!pipeline.empty())
750+
pipelines.try_emplace(pipeline.getOpName(), pipeline);
763751
opPipelines.assign({std::move(pipelines)});
764752

765753
return success();

mlir/lib/Transforms/PassDetail.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define TRANSFORMS_PASSDETAIL_H_
1111

1212
#include "mlir/Pass/Pass.h"
13+
#include "mlir/Pass/PassManager.h"
1314
#include "mlir/Transforms/Passes.h"
1415

1516
namespace mlir {

mlir/test/Pass/crash-recovery.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ module @inner_mod1 {
2020
module @foo {}
2121
}
2222

23-
// REPRO: configuration: -pass-pipeline='builtin.module(test-module-pass, test-pass-crash)'
23+
// REPRO: configuration: -pass-pipeline='builtin.module(test-module-pass,test-pass-crash)'
2424

2525
// REPRO: module @inner_mod1
2626
// REPRO: module @foo {

mlir/test/Pass/pipeline-options-parsing.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@
1414

1515
// CHECK_1: test-options-pass{list=1,2,3,4,5 string=nested_pipeline{arg1=10 arg2=" {} " arg3=true} string-list=a,b,c,d}
1616
// CHECK_2: test-options-pass{list=1 string= string-list=a,b}
17-
// CHECK_3: builtin.module(func.func(test-options-pass{list=3 string= }), func.func(test-options-pass{list=1,2,3,4 string= }))
17+
// CHECK_3: builtin.module(func.func(test-options-pass{list=3 string= }),func.func(test-options-pass{list=1,2,3,4 string= }))

mlir/test/Transforms/inlining.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// RUN: mlir-opt %s --mlir-disable-threading -inline='default-pipeline=''' | FileCheck %s
33
// RUN: mlir-opt %s -inline='default-pipeline=''' -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s --check-prefix INLINE-LOC
44
// RUN: mlir-opt %s -inline | FileCheck %s --check-prefix INLINE_SIMPLIFY
5+
// RUN: mlir-opt %s -inline='op-pipelines=func.func(canonicalize,cse)' | FileCheck %s --check-prefix INLINE_SIMPLIFY
56

67
// Inline a function that takes an argument.
78
func @func_with_arg(%c : i32) -> i32 {

0 commit comments

Comments
 (0)