Skip to content

Commit b840d29

Browse files
[mlir][IR] Send notifications for cloneRegionBefore (#66871)
Similar to `OpBuilder::clone`, operation/block insertion notifications should be sent when cloning the contents of a region. E.g., this is to ensure that the newly created operations are put on the worklist of the greedy pattern rewriter driver. Also move `cloneRegionBefore` from `RewriterBase` to `OpBuilder`. It only creates new IR, so it should be part of the builder API (like `clone(Operation &)`). The function does not have to be virtual. Now that notifications are properly sent, the override in the dialect conversion is no longer needed.
1 parent c07fcd4 commit b840d29

File tree

9 files changed

+107
-65
lines changed

9 files changed

+107
-65
lines changed

mlir/include/mlir/IR/Builders.h

+10
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,16 @@ class OpBuilder : public Builder {
583583
return cast<OpT>(cloneWithoutRegions(*op.getOperation()));
584584
}
585585

586+
/// Clone the blocks that belong to "region" before the given position in
587+
/// another region "parent". The two regions must be different. The caller is
588+
/// responsible for creating or updating the operation transferring flow of
589+
/// control to the region and passing it the correct block arguments.
590+
void cloneRegionBefore(Region &region, Region &parent,
591+
Region::iterator before, IRMapping &mapping);
592+
void cloneRegionBefore(Region &region, Region &parent,
593+
Region::iterator before);
594+
void cloneRegionBefore(Region &region, Block *before);
595+
586596
protected:
587597
/// The optional listener for events of this builder.
588598
Listener *listener;

mlir/include/mlir/IR/PatternMatch.h

-10
Original file line numberDiff line numberDiff line change
@@ -500,16 +500,6 @@ class RewriterBase : public OpBuilder {
500500
Region::iterator before);
501501
void inlineRegionBefore(Region &region, Block *before);
502502

503-
/// Clone the blocks that belong to "region" before the given position in
504-
/// another region "parent". The two regions must be different. The caller is
505-
/// responsible for creating or updating the operation transferring flow of
506-
/// control to the region and passing it the correct block arguments.
507-
virtual void cloneRegionBefore(Region &region, Region &parent,
508-
Region::iterator before, IRMapping &mapping);
509-
void cloneRegionBefore(Region &region, Region &parent,
510-
Region::iterator before);
511-
void cloneRegionBefore(Region &region, Block *before);
512-
513503
/// This method replaces the uses of the results of `op` with the values in
514504
/// `newValues` when the provided `functor` returns true for a specific use.
515505
/// The number of values in `newValues` is required to match the number of

mlir/include/mlir/Transforms/DialectConversion.h

-8
Original file line numberDiff line numberDiff line change
@@ -724,14 +724,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
724724
ValueRange argValues = std::nullopt) override;
725725
using PatternRewriter::inlineBlockBefore;
726726

727-
/// PatternRewriter hook for cloning blocks of one region into another. The
728-
/// given region to clone *must* not have been modified as part of conversion
729-
/// yet, i.e. it must be within an operation that is either in the process of
730-
/// conversion, or has not yet been converted.
731-
void cloneRegionBefore(Region &region, Region &parent,
732-
Region::iterator before, IRMapping &mapping) override;
733-
using PatternRewriter::cloneRegionBefore;
734-
735727
/// PatternRewriter hook for inserting a new operation.
736728
void notifyOperationInserted(Operation *op, InsertPoint previous) override;
737729

mlir/lib/IR/Builders.cpp

+42-10
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,16 @@ LogicalResult OpBuilder::tryFold(Operation *op,
522522
return success();
523523
}
524524

525+
/// Helper function that sends block insertion notifications for every block
526+
/// that is directly nested in the given op.
527+
static void notifyBlockInsertions(Operation *op,
528+
OpBuilder::Listener *listener) {
529+
for (Region &r : op->getRegions())
530+
for (Block &b : r.getBlocks())
531+
listener->notifyBlockInserted(&b, /*previous=*/nullptr,
532+
/*previousIt=*/{});
533+
}
534+
525535
Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
526536
Operation *newOp = op.clone(mapper);
527537
newOp = insert(newOp);
@@ -530,20 +540,12 @@ Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
530540
// itself. But if `newOp` has any regions, we need to notify the listener
531541
// about any ops that got inserted inside those regions as part of cloning.
532542
if (listener) {
533-
// Helper function that sends block insertion notifications for every block
534-
// within the given op.
535-
auto notifyBlockInsertions = [&](Operation *op) {
536-
for (Region &r : op->getRegions())
537-
for (Block &b : r.getBlocks())
538-
listener->notifyBlockInserted(&b, /*previous=*/nullptr,
539-
/*previousIt=*/{});
540-
};
541543
// The `insert` call above notifies about op insertion, but not about block
542544
// insertion.
543-
notifyBlockInsertions(newOp);
545+
notifyBlockInsertions(newOp, listener);
544546
auto walkFn = [&](Operation *walkedOp) {
545547
listener->notifyOperationInserted(walkedOp, /*previous=*/{});
546-
notifyBlockInsertions(walkedOp);
548+
notifyBlockInsertions(walkedOp, listener);
547549
};
548550
for (Region &region : newOp->getRegions())
549551
region.walk<WalkOrder::PreOrder>(walkFn);
@@ -556,3 +558,33 @@ Operation *OpBuilder::clone(Operation &op) {
556558
IRMapping mapper;
557559
return clone(op, mapper);
558560
}
561+
562+
void OpBuilder::cloneRegionBefore(Region &region, Region &parent,
563+
Region::iterator before, IRMapping &mapping) {
564+
region.cloneInto(&parent, before, mapping);
565+
566+
// Fast path: If no listener is attached, there is no more work to do.
567+
if (!listener)
568+
return;
569+
570+
// Notify about op/block insertion.
571+
for (auto it = mapping.lookup(&region.front())->getIterator(); it != before;
572+
++it) {
573+
listener->notifyBlockInserted(&*it, /*previous=*/nullptr,
574+
/*previousIt=*/{});
575+
it->walk<WalkOrder::PreOrder>([&](Operation *walkedOp) {
576+
listener->notifyOperationInserted(walkedOp, /*previous=*/{});
577+
notifyBlockInsertions(walkedOp, listener);
578+
});
579+
}
580+
}
581+
582+
void OpBuilder::cloneRegionBefore(Region &region, Region &parent,
583+
Region::iterator before) {
584+
IRMapping mapping;
585+
cloneRegionBefore(region, parent, before, mapping);
586+
}
587+
588+
void OpBuilder::cloneRegionBefore(Region &region, Block *before) {
589+
cloneRegionBefore(region, *before->getParent(), before->getIterator());
590+
}

mlir/lib/IR/PatternMatch.cpp

-18
Original file line numberDiff line numberDiff line change
@@ -384,24 +384,6 @@ void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
384384
inlineRegionBefore(region, *before->getParent(), before->getIterator());
385385
}
386386

387-
/// Clone the blocks that belong to "region" before the given position in
388-
/// another region "parent". The two regions must be different. The caller is
389-
/// responsible for creating or updating the operation transferring flow of
390-
/// control to the region and passing it the correct block arguments.
391-
void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
392-
Region::iterator before,
393-
IRMapping &mapping) {
394-
region.cloneInto(&parent, before, mapping);
395-
}
396-
void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
397-
Region::iterator before) {
398-
IRMapping mapping;
399-
cloneRegionBefore(region, parent, before, mapping);
400-
}
401-
void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
402-
cloneRegionBefore(region, *before->getParent(), before->getIterator());
403-
}
404-
405387
void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
406388
moveBlockBefore(block, anotherBlock->getParent(),
407389
anotherBlock->getIterator());

mlir/lib/Transforms/Utils/DialectConversion.cpp

-17
Original file line numberDiff line numberDiff line change
@@ -1573,23 +1573,6 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
15731573
eraseBlock(source);
15741574
}
15751575

1576-
void ConversionPatternRewriter::cloneRegionBefore(Region &region,
1577-
Region &parent,
1578-
Region::iterator before,
1579-
IRMapping &mapping) {
1580-
if (region.empty())
1581-
return;
1582-
1583-
PatternRewriter::cloneRegionBefore(region, parent, before, mapping);
1584-
1585-
for (Block &b : ForwardDominanceIterator<>::makeIterable(region)) {
1586-
Block *cloned = mapping.lookup(&b);
1587-
impl->notifyInsertedBlock(cloned, /*previous=*/nullptr, /*previousIt=*/{});
1588-
cloned->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
1589-
[&](Operation *op) { notifyOperationInserted(op, /*previous=*/{}); });
1590-
}
1591-
}
1592-
15931576
void ConversionPatternRewriter::notifyOperationInserted(Operation *op,
15941577
InsertPoint previous) {
15951578
assert(!previous.isSet() && "expected newly created op");

mlir/test/Transforms/test-legalizer-full.mlir

+3-1
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,11 @@ builtin.module {
110110
// expected-error@+1 {{failed to legalize operation 'test.region'}}
111111
"test.region"() ({
112112
^bb1(%i0: i64):
113-
cf.br ^bb2(%i0 : i64)
113+
cf.br ^bb3(%i0 : i64)
114114
^bb2(%i1: i64):
115115
"test.invalid"(%i1) : (i64) -> ()
116+
^bb3(%i2: i64):
117+
cf.br ^bb2(%i2 : i64)
116118
}) {legalizer.should_clone, legalizer.erase_old_blocks} : () -> ()
117119

118120
"test.return"() : () -> ()

mlir/test/Transforms/test-strict-pattern-driver.mlir

+31
Original file line numberDiff line numberDiff line change
@@ -323,3 +323,34 @@ func.func @clone_op() {
323323
}) : () -> ()
324324
return
325325
}
326+
327+
328+
// -----
329+
330+
// CHECK-AN: notifyBlockInserted into func.func: was unlinked
331+
// CHECK-AN: notifyOperationInserted: test.op_1, was unlinked
332+
// CHECK-AN: notifyBlockInserted into func.func: was unlinked
333+
// CHECK-AN: notifyOperationInserted: test.op_2, was unlinked
334+
// CHECK-AN: notifyBlockInserted into test.op_2: was unlinked
335+
// CHECK-AN: notifyOperationInserted: test.op_3, was unlinked
336+
// CHECK-AN: notifyOperationInserted: test.op_4, was unlinked
337+
// CHECK-AN-LABEL: func @test_clone_region_before(
338+
// CHECK-AN: "test.op_1"() : () -> ()
339+
// CHECK-AN: ^{{.*}}:
340+
// CHECK-AN: "test.op_2"() ({
341+
// CHECK-AN: "test.op_3"() : () -> ()
342+
// CHECK-AN: }) : () -> ()
343+
// CHECK-AN: "test.op_4"() : () -> ()
344+
// CHECK-AN: ^{{.*}}:
345+
// CHECK-AN: "test.clone_region_before"() ({
346+
func.func @test_clone_region_before() {
347+
"test.clone_region_before"() ({
348+
"test.op_1"() : () -> ()
349+
^bb0:
350+
"test.op_2"() ({
351+
"test.op_3"() : () -> ()
352+
}) : () -> ()
353+
"test.op_4"() : () -> ()
354+
}) : () -> ()
355+
return
356+
}

mlir/test/lib/Dialect/Test/TestPatterns.cpp

+21-1
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,24 @@ struct CloneOp : public RewritePattern {
267267
}
268268
};
269269

270+
/// This pattern clones regions of "test.clone_region_before" ops before the
271+
/// parent block.
272+
struct CloneRegionBeforeOp : public RewritePattern {
273+
CloneRegionBeforeOp(MLIRContext *context)
274+
: RewritePattern("test.clone_region_before", /*benefit=*/1, context) {}
275+
276+
LogicalResult matchAndRewrite(Operation *op,
277+
PatternRewriter &rewriter) const override {
278+
// Do not clone already cloned ops to avoid going into an infinite loop.
279+
if (op->hasAttr("was_cloned"))
280+
return failure();
281+
for (Region &r : op->getRegions())
282+
rewriter.cloneRegionBefore(r, op->getBlock());
283+
op->setAttr("was_cloned", rewriter.getUnitAttr());
284+
return success();
285+
}
286+
};
287+
270288
struct TestPatternDriver
271289
: public PassWrapper<TestPatternDriver, OperationPass<>> {
272290
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -358,6 +376,7 @@ struct TestStrictPatternDriver
358376
// clang-format off
359377
ChangeBlockOp,
360378
CloneOp,
379+
CloneRegionBeforeOp,
361380
EraseOp,
362381
ImplicitChangeOp,
363382
InlineBlocksIntoParent,
@@ -374,7 +393,8 @@ struct TestStrictPatternDriver
374393
opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
375394
opName == "test.move_before_parent_op" ||
376395
opName == "test.inline_blocks_into_parent" ||
377-
opName == "test.split_block_here" || opName == "test.clone_me") {
396+
opName == "test.split_block_here" || opName == "test.clone_me" ||
397+
opName == "test.clone_region_before") {
378398
ops.push_back(op);
379399
}
380400
});

0 commit comments

Comments
 (0)