Skip to content

Commit 359ba0b

Browse files
committed
[mlir][CFGToSCF] Add interface changes for downstream projects
This is a follow-up to https://reviews.llvm.org/D156889 Downstream projects may have more complicated ops than the control flow ops upstream and therefore need a more powerful interface to support the lifting process. Use cases include the propagation of (inherent) metadata that was previously on the control flow ops and now needs to be lifted to structured control flow ops. Since the lifting process is inherently non-local in respect to the function-body, we require stronger guarantees from the interface. This patch therefore makes two changes to the interface: * Passes the terminator that is being replaced to `createStructuredBranchRegionTerminatorOp` * Adds as precondition to `createCFGSwitchOp` that its predecessors are already correctly established Asserts have been added to verify these were it makes sense and to correctly state intent. I have not added tests purely because testing preconditions like these is not really feasible (and incredibly specific). Differential Revision: https://reviews.llvm.org/D157981
1 parent fb0c50b commit 359ba0b

File tree

4 files changed

+44
-31
lines changed

4 files changed

+44
-31
lines changed

mlir/include/mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,9 @@ class ControlFlowToSCFTransformation : public CFGToSCFInterface {
3232
MutableArrayRef<Region> regions) override;
3333

3434
/// Creates an `scf.yield` op returning the given results.
35-
LogicalResult
36-
createStructuredBranchRegionTerminatorOp(Location loc, OpBuilder &builder,
37-
Operation *branchRegionOp,
38-
ValueRange results) override;
35+
LogicalResult createStructuredBranchRegionTerminatorOp(
36+
Location loc, OpBuilder &builder, Operation *branchRegionOp,
37+
Operation *replacedControlFlowOp, ValueRange results) override;
3938

4039
/// Creates an `scf.while` op. The loop body is made the before-region of the
4140
/// while op and terminated with an `scf.condition` op. The after-region does

mlir/include/mlir/Transforms/CFGToSCF.h

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,14 @@ class CFGToSCFInterface {
4242

4343
/// Creates a return-like terminator for a branch region of the op returned
4444
/// by `createStructuredBranchRegionOp`. `branchRegionOp` is the operation
45-
/// returned by `createStructuredBranchRegionOp` while `results` are the
46-
/// values that should be returned by the branch region.
47-
virtual LogicalResult
48-
createStructuredBranchRegionTerminatorOp(Location loc, OpBuilder &builder,
49-
Operation *branchRegionOp,
50-
ValueRange results) = 0;
45+
/// returned by `createStructuredBranchRegionOp`.
46+
/// `replacedControlFlowOp` is the control flow op being replaced by the
47+
/// terminator or nullptr if the terminator is not replacing any existing
48+
/// control flow op. `results` are the values that should be returned by the
49+
/// branch region.
50+
virtual LogicalResult createStructuredBranchRegionTerminatorOp(
51+
Location loc, OpBuilder &builder, Operation *branchRegionOp,
52+
Operation *replacedControlFlowOp, ValueRange results) = 0;
5153

5254
/// Creates a structured control flow operation representing a do-while loop.
5355
/// The do-while loop is expected to have the exact same result types as the
@@ -77,8 +79,10 @@ class CFGToSCFInterface {
7779
/// `caseDestinations` or `defaultDest`. This is used by the transformation
7880
/// for intermediate transformations before lifting to structured control
7981
/// flow. The switch op branches based on `flag` which is guaranteed to be of
80-
/// the same type as values returned by `getCFGSwitchValue`. Note:
81-
/// `caseValues` and other related ranges may be empty to represent an
82+
/// the same type as values returned by `getCFGSwitchValue`. The insertion
83+
/// block of the builder is guaranteed to have its predecessors already set
84+
/// to create an equivalent CFG after this operation.
85+
/// Note: `caseValues` and other related ranges may be empty to represent an
8286
/// unconditional branch.
8387
virtual void createCFGSwitchOp(Location loc, OpBuilder &builder, Value flag,
8488
ArrayRef<unsigned> caseValues,

mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ ControlFlowToSCFTransformation::createStructuredBranchRegionOp(
7676
LogicalResult
7777
ControlFlowToSCFTransformation::createStructuredBranchRegionTerminatorOp(
7878
Location loc, OpBuilder &builder, Operation *branchRegionOp,
79-
ValueRange results) {
79+
Operation *replacedControlFlowOp, ValueRange results) {
8080
builder.create<scf::YieldOp>(loc, results);
8181
return success();
8282
}

mlir/lib/Transforms/Utils/CFGToSCF.cpp

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,8 @@ class EdgeMultiplexer {
306306
/// Creates a switch op using `builder` which dispatches to the original
307307
/// successors of the edges passed to `create` minus the ones in `excluded`.
308308
/// The builder's insertion point has to be in a block dominated by the
309-
/// multiplexer block.
309+
/// multiplexer block. All edges to the multiplexer block must have already
310+
/// been redirected using `redirectEdge`.
310311
void createSwitch(
311312
Location loc, OpBuilder &builder, CFGToSCFInterface &interface,
312313
const SmallPtrSetImpl<Block *> &excluded = SmallPtrSet<Block *, 1>{}) {
@@ -337,6 +338,8 @@ class EdgeMultiplexer {
337338
Block *defaultDest = caseDestinations.pop_back_val();
338339
ValueRange defaultArgs = caseArguments.pop_back_val();
339340

341+
assert(!builder.getInsertionBlock()->hasNoPredecessors() &&
342+
"Edges need to be redirected prior to creating switch.");
340343
interface.createCFGSwitchOp(loc, builder, realDiscriminator, caseValues,
341344
caseDestinations, caseArguments, defaultDest,
342345
defaultArgs);
@@ -507,12 +510,14 @@ createSingleEntryBlock(Location loc, ArrayRef<Edge> entryEdges,
507510
loc, llvm::map_to_vector(entryEdges, std::mem_fn(&Edge::getSuccessor)),
508511
getSwitchValue, getUndefValue);
509512

510-
auto builder = OpBuilder::atBlockBegin(result.getMultiplexerBlock());
511-
result.createSwitch(loc, builder, interface);
512-
513+
// Redirect the edges prior to creating the switch op.
514+
// We guarantee that predecessors are up to date.
513515
for (Edge edge : entryEdges)
514516
result.redirectEdge(edge);
515517

518+
auto builder = OpBuilder::atBlockBegin(result.getMultiplexerBlock());
519+
result.createSwitch(loc, builder, interface);
520+
516521
return result;
517522
}
518523

@@ -565,6 +570,17 @@ static FailureOr<StructuredLoopProperties> createSingleExitingLatch(
565570
// Since this is a loop, all back edges point to the same loop header.
566571
Block *loopHeader = backEdges.front().getSuccessor();
567572

573+
// Redirect the edges prior to creating the switch op.
574+
// We guarantee that predecessors are up to date.
575+
576+
// Redirecting back edges with `shouldRepeat` as 1.
577+
for (Edge backEdge : backEdges)
578+
multiplexer.redirectEdge(backEdge, /*extraArgs=*/getSwitchValue(1));
579+
580+
// Redirecting exits edges with `shouldRepeat` as 0.
581+
for (Edge exitEdge : exitEdges)
582+
multiplexer.redirectEdge(exitEdge, /*extraArgs=*/getSwitchValue(0));
583+
568584
// Create the new only back edge to the loop header. Branch to the
569585
// exit block otherwise.
570586
Value shouldRepeat = latchBlock->getArguments().back();
@@ -603,14 +619,6 @@ static FailureOr<StructuredLoopProperties> createSingleExitingLatch(
603619
}
604620
}
605621

606-
// Redirecting back edges with `shouldRepeat` as 1.
607-
for (Edge backEdge : backEdges)
608-
multiplexer.redirectEdge(backEdge, /*extraArgs=*/getSwitchValue(1));
609-
610-
// Redirecting exits edges with `shouldRepeat` as 0.
611-
for (Edge exitEdge : exitEdges)
612-
multiplexer.redirectEdge(exitEdge, /*extraArgs=*/getSwitchValue(0));
613-
614622
return StructuredLoopProperties{latchBlock, /*condition=*/shouldRepeat,
615623
exitBlock};
616624
}
@@ -794,13 +802,14 @@ static FailureOr<SmallVector<Block *>> transformCyclesToSCFLoops(
794802
// First turn the cycle into a loop by creating a single entry block if
795803
// needed.
796804
if (edges.entryEdges.size() > 1) {
805+
SmallVector<Edge> edgesToEntryBlocks;
806+
llvm::append_range(edgesToEntryBlocks, edges.entryEdges);
807+
llvm::append_range(edgesToEntryBlocks, edges.backEdges);
808+
797809
EdgeMultiplexer multiplexer = createSingleEntryBlock(
798-
loopHeader->getTerminator()->getLoc(), edges.entryEdges,
810+
loopHeader->getTerminator()->getLoc(), edgesToEntryBlocks,
799811
getSwitchValue, getUndefValue, interface);
800812

801-
for (Edge edge : edges.backEdges)
802-
multiplexer.redirectEdge(edge);
803-
804813
loopHeader = multiplexer.getMultiplexerBlock();
805814
}
806815
cycleBlockSet.insert(loopHeader);
@@ -1140,7 +1149,8 @@ static FailureOr<SmallVector<Block *>> transformToStructuredCFBranches(
11401149
for (auto &&[block, valueRange] : createdEmptyBlocks) {
11411150
auto builder = OpBuilder::atBlockEnd(block);
11421151
LogicalResult result = interface.createStructuredBranchRegionTerminatorOp(
1143-
structuredCondOp->getLoc(), builder, structuredCondOp, valueRange);
1152+
structuredCondOp->getLoc(), builder, structuredCondOp, nullptr,
1153+
valueRange);
11441154
if (failed(result))
11451155
return failure();
11461156
}
@@ -1153,7 +1163,7 @@ static FailureOr<SmallVector<Block *>> transformToStructuredCFBranches(
11531163
assert(user->getNumSuccessors() == 1);
11541164
auto builder = OpBuilder::atBlockTerminator(user->getBlock());
11551165
LogicalResult result = interface.createStructuredBranchRegionTerminatorOp(
1156-
user->getLoc(), builder, structuredCondOp,
1166+
user->getLoc(), builder, structuredCondOp, user,
11571167
static_cast<OperandRange>(
11581168
getMutableSuccessorOperands(user->getBlock(), 0)));
11591169
if (failed(result))

0 commit comments

Comments
 (0)