[mlir] check whether region and block visitors are interrupted

The visitor functions for `Region` and `Block` types did not always
check the value returned by recursive calls.  This caused the top-level
visitor invocation to return `WalkResult::advance()` even if one or more
recursive invocations returned `WalkResult::interrupt()`.  This patch
fixes the problem by check if any recursive call is interrupted, and if
so, return `WalkResult::interrupt()`.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D129718
This commit is contained in:
Ashay Rane
2022-07-13 18:20:56 -07:00
parent bb957a8d52
commit f2b94bd7ea
4 changed files with 91 additions and 2 deletions

View File

@@ -114,7 +114,8 @@ WalkResult detail::walk(Operation *op,
}
for (auto &block : region) {
for (auto &nestedOp : block)
walk(&nestedOp, callback, order);
if (walk(&nestedOp, callback, order).wasInterrupted())
return WalkResult::interrupt();
}
if (order == WalkOrder::PostOrder) {
if (callback(&region).wasInterrupted())
@@ -140,7 +141,8 @@ WalkResult detail::walk(Operation *op,
return WalkResult::interrupt();
}
for (auto &nestedOp : block)
walk(&nestedOp, callback, order);
if (walk(&nestedOp, callback, order).wasInterrupted())
return WalkResult::interrupt();
if (order == WalkOrder::PostOrder) {
if (callback(&block).wasInterrupted())
return WalkResult::interrupt();

View File

@@ -0,0 +1,9 @@
// RUN: mlir-opt -test-generic-ir-block-visitors-interrupt -allow-unregistered-dialect -split-input-file %s | FileCheck %s
func.func @main(%arg0: f32) -> f32 {
%v1 = "foo"() {interrupt = true} : () -> f32
%v2 = arith.addf %v1, %arg0 : f32
return %v2 : f32
}
// CHECK: step 0 walk was interrupted

View File

@@ -0,0 +1,9 @@
// RUN: mlir-opt -test-generic-ir-region-visitors-interrupt -allow-unregistered-dialect -split-input-file %s | FileCheck %s
func.func @main(%arg0: f32) -> f32 {
%v1 = "foo"() {interrupt = true} : () -> f32
%v2 = arith.addf %v1, %arg0 : f32
return %v2 : f32
}
// CHECK: step 0 walk was interrupted

View File

@@ -113,6 +113,73 @@ struct TestGenericIRVisitorInterruptPass
}
};
struct TestGenericIRBlockVisitorInterruptPass
: public PassWrapper<TestGenericIRBlockVisitorInterruptPass,
OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestGenericIRBlockVisitorInterruptPass)
StringRef getArgument() const final {
return "test-generic-ir-block-visitors-interrupt";
}
StringRef getDescription() const final {
return "Test generic IR visitors with interrupts, starting with Blocks.";
}
void runOnOperation() override {
int stepNo = 0;
auto walker = [&](Block *block) {
for (Operation &op : *block)
for (OpResult result : op.getResults())
if (Operation *definingOp = result.getDefiningOp())
if (definingOp->getAttrOfType<BoolAttr>("interrupt"))
return WalkResult::interrupt();
llvm::outs() << "step " << stepNo++ << "\n";
return WalkResult::advance();
};
auto result = getOperation()->walk(walker);
if (result.wasInterrupted())
llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
}
};
struct TestGenericIRRegionVisitorInterruptPass
: public PassWrapper<TestGenericIRRegionVisitorInterruptPass,
OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestGenericIRRegionVisitorInterruptPass)
StringRef getArgument() const final {
return "test-generic-ir-region-visitors-interrupt";
}
StringRef getDescription() const final {
return "Test generic IR visitors with interrupts, starting with Regions.";
}
void runOnOperation() override {
int stepNo = 0;
auto walker = [&](Region *region) {
for (Block &block : *region)
for (Operation &op : block)
for (OpResult result : op.getResults())
if (Operation *definingOp = result.getDefiningOp())
if (definingOp->getAttrOfType<BoolAttr>("interrupt"))
return WalkResult::interrupt();
llvm::outs() << "step " << stepNo++ << "\n";
return WalkResult::advance();
};
auto result = getOperation()->walk(walker);
if (result.wasInterrupted())
llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
}
};
} // namespace
namespace mlir {
@@ -120,6 +187,8 @@ namespace test {
void registerTestGenericIRVisitorsPass() {
PassRegistration<TestGenericIRVisitorPass>();
PassRegistration<TestGenericIRVisitorInterruptPass>();
PassRegistration<TestGenericIRBlockVisitorInterruptPass>();
PassRegistration<TestGenericIRRegionVisitorInterruptPass>();
}
} // namespace test