Skip to content

Commit f2b94bd

Browse files
committed
[mlir] check whether region and block visitors are interrupted
The visitor functions for `Region` and `Block` types did not always check the value returned by recursive calls. This caused the top-level visitor invocation to return `WalkResult::advance()` even if one or more recursive invocations returned `WalkResult::interrupt()`. This patch fixes the problem by check if any recursive call is interrupted, and if so, return `WalkResult::interrupt()`. Reviewed By: dcaballe Differential Revision: https://reviews.llvm.org/D129718
1 parent bb957a8 commit f2b94bd

File tree

4 files changed

+91
-2
lines changed

4 files changed

+91
-2
lines changed

Diff for: mlir/lib/IR/Visitors.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ WalkResult detail::walk(Operation *op,
114114
}
115115
for (auto &block : region) {
116116
for (auto &nestedOp : block)
117-
walk(&nestedOp, callback, order);
117+
if (walk(&nestedOp, callback, order).wasInterrupted())
118+
return WalkResult::interrupt();
118119
}
119120
if (order == WalkOrder::PostOrder) {
120121
if (callback(&region).wasInterrupted())
@@ -140,7 +141,8 @@ WalkResult detail::walk(Operation *op,
140141
return WalkResult::interrupt();
141142
}
142143
for (auto &nestedOp : block)
143-
walk(&nestedOp, callback, order);
144+
if (walk(&nestedOp, callback, order).wasInterrupted())
145+
return WalkResult::interrupt();
144146
if (order == WalkOrder::PostOrder) {
145147
if (callback(&block).wasInterrupted())
146148
return WalkResult::interrupt();

Diff for: mlir/test/IR/generic-block-visitors-interrupt.mlir

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// RUN: mlir-opt -test-generic-ir-block-visitors-interrupt -allow-unregistered-dialect -split-input-file %s | FileCheck %s
2+
3+
func.func @main(%arg0: f32) -> f32 {
4+
%v1 = "foo"() {interrupt = true} : () -> f32
5+
%v2 = arith.addf %v1, %arg0 : f32
6+
return %v2 : f32
7+
}
8+
9+
// CHECK: step 0 walk was interrupted

Diff for: mlir/test/IR/generic-region-visitors-interrupt.mlir

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// RUN: mlir-opt -test-generic-ir-region-visitors-interrupt -allow-unregistered-dialect -split-input-file %s | FileCheck %s
2+
3+
func.func @main(%arg0: f32) -> f32 {
4+
%v1 = "foo"() {interrupt = true} : () -> f32
5+
%v2 = arith.addf %v1, %arg0 : f32
6+
return %v2 : f32
7+
}
8+
9+
// CHECK: step 0 walk was interrupted

Diff for: mlir/test/lib/IR/TestVisitorsGeneric.cpp

+69
Original file line numberDiff line numberDiff line change
@@ -113,13 +113,82 @@ struct TestGenericIRVisitorInterruptPass
113113
}
114114
};
115115

116+
struct TestGenericIRBlockVisitorInterruptPass
117+
: public PassWrapper<TestGenericIRBlockVisitorInterruptPass,
118+
OperationPass<ModuleOp>> {
119+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
120+
TestGenericIRBlockVisitorInterruptPass)
121+
122+
StringRef getArgument() const final {
123+
return "test-generic-ir-block-visitors-interrupt";
124+
}
125+
StringRef getDescription() const final {
126+
return "Test generic IR visitors with interrupts, starting with Blocks.";
127+
}
128+
129+
void runOnOperation() override {
130+
int stepNo = 0;
131+
132+
auto walker = [&](Block *block) {
133+
for (Operation &op : *block)
134+
for (OpResult result : op.getResults())
135+
if (Operation *definingOp = result.getDefiningOp())
136+
if (definingOp->getAttrOfType<BoolAttr>("interrupt"))
137+
return WalkResult::interrupt();
138+
139+
llvm::outs() << "step " << stepNo++ << "\n";
140+
return WalkResult::advance();
141+
};
142+
143+
auto result = getOperation()->walk(walker);
144+
if (result.wasInterrupted())
145+
llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
146+
}
147+
};
148+
149+
struct TestGenericIRRegionVisitorInterruptPass
150+
: public PassWrapper<TestGenericIRRegionVisitorInterruptPass,
151+
OperationPass<ModuleOp>> {
152+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
153+
TestGenericIRRegionVisitorInterruptPass)
154+
155+
StringRef getArgument() const final {
156+
return "test-generic-ir-region-visitors-interrupt";
157+
}
158+
StringRef getDescription() const final {
159+
return "Test generic IR visitors with interrupts, starting with Regions.";
160+
}
161+
162+
void runOnOperation() override {
163+
int stepNo = 0;
164+
165+
auto walker = [&](Region *region) {
166+
for (Block &block : *region)
167+
for (Operation &op : block)
168+
for (OpResult result : op.getResults())
169+
if (Operation *definingOp = result.getDefiningOp())
170+
if (definingOp->getAttrOfType<BoolAttr>("interrupt"))
171+
return WalkResult::interrupt();
172+
173+
llvm::outs() << "step " << stepNo++ << "\n";
174+
return WalkResult::advance();
175+
};
176+
177+
auto result = getOperation()->walk(walker);
178+
if (result.wasInterrupted())
179+
llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
180+
}
181+
};
182+
116183
} // namespace
117184

118185
namespace mlir {
119186
namespace test {
120187
void registerTestGenericIRVisitorsPass() {
121188
PassRegistration<TestGenericIRVisitorPass>();
122189
PassRegistration<TestGenericIRVisitorInterruptPass>();
190+
PassRegistration<TestGenericIRBlockVisitorInterruptPass>();
191+
PassRegistration<TestGenericIRRegionVisitorInterruptPass>();
123192
}
124193

125194
} // namespace test

0 commit comments

Comments
 (0)