[mlir] check whether region and block visitors are interrupted
The visitor functions for `Region` and `Block` types did not always check the value returned by recursive calls. This caused the top-level visitor invocation to return `WalkResult::advance()` even if one or more recursive invocations returned `WalkResult::interrupt()`. This patch fixes the problem by check if any recursive call is interrupted, and if so, return `WalkResult::interrupt()`. Reviewed By: dcaballe Differential Revision: https://reviews.llvm.org/D129718
This commit is contained in:
@@ -114,7 +114,8 @@ WalkResult detail::walk(Operation *op,
|
||||
}
|
||||
for (auto &block : region) {
|
||||
for (auto &nestedOp : block)
|
||||
walk(&nestedOp, callback, order);
|
||||
if (walk(&nestedOp, callback, order).wasInterrupted())
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
if (order == WalkOrder::PostOrder) {
|
||||
if (callback(®ion).wasInterrupted())
|
||||
@@ -140,7 +141,8 @@ WalkResult detail::walk(Operation *op,
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
for (auto &nestedOp : block)
|
||||
walk(&nestedOp, callback, order);
|
||||
if (walk(&nestedOp, callback, order).wasInterrupted())
|
||||
return WalkResult::interrupt();
|
||||
if (order == WalkOrder::PostOrder) {
|
||||
if (callback(&block).wasInterrupted())
|
||||
return WalkResult::interrupt();
|
||||
|
||||
9
mlir/test/IR/generic-block-visitors-interrupt.mlir
Normal file
9
mlir/test/IR/generic-block-visitors-interrupt.mlir
Normal file
@@ -0,0 +1,9 @@
|
||||
// RUN: mlir-opt -test-generic-ir-block-visitors-interrupt -allow-unregistered-dialect -split-input-file %s | FileCheck %s
|
||||
|
||||
func.func @main(%arg0: f32) -> f32 {
|
||||
%v1 = "foo"() {interrupt = true} : () -> f32
|
||||
%v2 = arith.addf %v1, %arg0 : f32
|
||||
return %v2 : f32
|
||||
}
|
||||
|
||||
// CHECK: step 0 walk was interrupted
|
||||
9
mlir/test/IR/generic-region-visitors-interrupt.mlir
Normal file
9
mlir/test/IR/generic-region-visitors-interrupt.mlir
Normal file
@@ -0,0 +1,9 @@
|
||||
// RUN: mlir-opt -test-generic-ir-region-visitors-interrupt -allow-unregistered-dialect -split-input-file %s | FileCheck %s
|
||||
|
||||
func.func @main(%arg0: f32) -> f32 {
|
||||
%v1 = "foo"() {interrupt = true} : () -> f32
|
||||
%v2 = arith.addf %v1, %arg0 : f32
|
||||
return %v2 : f32
|
||||
}
|
||||
|
||||
// CHECK: step 0 walk was interrupted
|
||||
@@ -113,6 +113,73 @@ struct TestGenericIRVisitorInterruptPass
|
||||
}
|
||||
};
|
||||
|
||||
struct TestGenericIRBlockVisitorInterruptPass
|
||||
: public PassWrapper<TestGenericIRBlockVisitorInterruptPass,
|
||||
OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
||||
TestGenericIRBlockVisitorInterruptPass)
|
||||
|
||||
StringRef getArgument() const final {
|
||||
return "test-generic-ir-block-visitors-interrupt";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test generic IR visitors with interrupts, starting with Blocks.";
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
int stepNo = 0;
|
||||
|
||||
auto walker = [&](Block *block) {
|
||||
for (Operation &op : *block)
|
||||
for (OpResult result : op.getResults())
|
||||
if (Operation *definingOp = result.getDefiningOp())
|
||||
if (definingOp->getAttrOfType<BoolAttr>("interrupt"))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
llvm::outs() << "step " << stepNo++ << "\n";
|
||||
return WalkResult::advance();
|
||||
};
|
||||
|
||||
auto result = getOperation()->walk(walker);
|
||||
if (result.wasInterrupted())
|
||||
llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
|
||||
}
|
||||
};
|
||||
|
||||
struct TestGenericIRRegionVisitorInterruptPass
|
||||
: public PassWrapper<TestGenericIRRegionVisitorInterruptPass,
|
||||
OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
||||
TestGenericIRRegionVisitorInterruptPass)
|
||||
|
||||
StringRef getArgument() const final {
|
||||
return "test-generic-ir-region-visitors-interrupt";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test generic IR visitors with interrupts, starting with Regions.";
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
int stepNo = 0;
|
||||
|
||||
auto walker = [&](Region *region) {
|
||||
for (Block &block : *region)
|
||||
for (Operation &op : block)
|
||||
for (OpResult result : op.getResults())
|
||||
if (Operation *definingOp = result.getDefiningOp())
|
||||
if (definingOp->getAttrOfType<BoolAttr>("interrupt"))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
llvm::outs() << "step " << stepNo++ << "\n";
|
||||
return WalkResult::advance();
|
||||
};
|
||||
|
||||
auto result = getOperation()->walk(walker);
|
||||
if (result.wasInterrupted())
|
||||
llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
@@ -120,6 +187,8 @@ namespace test {
|
||||
void registerTestGenericIRVisitorsPass() {
|
||||
PassRegistration<TestGenericIRVisitorPass>();
|
||||
PassRegistration<TestGenericIRVisitorInterruptPass>();
|
||||
PassRegistration<TestGenericIRBlockVisitorInterruptPass>();
|
||||
PassRegistration<TestGenericIRRegionVisitorInterruptPass>();
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
|
||||
Reference in New Issue
Block a user