[mlir][Transforms] Dialect conversion: Add missing erasure notifications (#145030)
Add missing listener notifications when erasing nested blocks/operations. This commit also moves some of the functionality from `ConversionPatternRewriter` to `ConversionPatternRewriterImpl`. This is in preparation of the One-Shot Dialect Conversion refactoring: The implementations in `ConversionPatternRewriter` should be as simple as possible, so that a switch between "rollback allowed" and "rollback not allowed" can be inserted at that level. (In the latter case, `ConversionPatternRewriterImpl` can be bypassed to some degree, and `PatternRewriter::eraseBlock` etc. can be used.) Depends on #145018.
This commit is contained in:
committed by
GitHub
parent
4a4582dd78
commit
0921bfd81d
@@ -274,6 +274,26 @@ struct RewriterState {
|
||||
// IR rewrites
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
|
||||
|
||||
/// Notify the listener that the given block and its contents are being erased.
|
||||
static void notifyIRErased(RewriterBase::Listener *listener, Block &b) {
|
||||
for (Operation &op : b)
|
||||
notifyIRErased(listener, op);
|
||||
listener->notifyBlockErased(&b);
|
||||
}
|
||||
|
||||
/// Notify the listener that the given operation and its contents are being
|
||||
/// erased.
|
||||
static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
|
||||
for (Region &r : op.getRegions()) {
|
||||
for (Block &b : r) {
|
||||
notifyIRErased(listener, b);
|
||||
}
|
||||
}
|
||||
listener->notifyOperationErased(&op);
|
||||
}
|
||||
|
||||
/// An IR rewrite that can be committed (upon success) or rolled back (upon
|
||||
/// failure).
|
||||
///
|
||||
@@ -422,17 +442,20 @@ public:
|
||||
}
|
||||
|
||||
void commit(RewriterBase &rewriter) override {
|
||||
// Erase the block.
|
||||
assert(block && "expected block");
|
||||
assert(block->empty() && "expected empty block");
|
||||
|
||||
// Notify the listener that the block is about to be erased.
|
||||
// Notify the listener that the block and its contents are being erased.
|
||||
if (auto *listener =
|
||||
dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
|
||||
listener->notifyBlockErased(block);
|
||||
notifyIRErased(listener, *block);
|
||||
}
|
||||
|
||||
void cleanup(RewriterBase &rewriter) override {
|
||||
// Erase the contents of the block.
|
||||
for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
|
||||
rewriter.eraseOp(&op);
|
||||
assert(block->empty() && "expected empty block");
|
||||
|
||||
// Erase the block.
|
||||
block->dropAllDefinedValueUses();
|
||||
delete block;
|
||||
@@ -1147,12 +1170,9 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
|
||||
if (getConfig().unlegalizedOps)
|
||||
getConfig().unlegalizedOps->erase(op);
|
||||
|
||||
// Notify the listener that the operation (and its nested operations) was
|
||||
// erased.
|
||||
if (listener) {
|
||||
op->walk<WalkOrder::PostOrder>(
|
||||
[&](Operation *op) { listener->notifyOperationErased(op); });
|
||||
}
|
||||
// Notify the listener that the operation and its contents are being erased.
|
||||
if (listener)
|
||||
notifyIRErased(listener, *op);
|
||||
|
||||
// Do not erase the operation yet. It may still be referenced in `mapping`.
|
||||
// Just unlink it for now and erase it during cleanup.
|
||||
@@ -1605,6 +1625,8 @@ void ConversionPatternRewriterImpl::replaceOp(
|
||||
}
|
||||
|
||||
void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
|
||||
assert(!wasOpReplaced(block->getParentOp()) &&
|
||||
"attempting to erase a block within a replaced/erased op");
|
||||
appendRewrite<EraseBlockRewrite>(block);
|
||||
|
||||
// Unlink the block from its parent region. The block is kept in the rewrite
|
||||
@@ -1612,6 +1634,9 @@ void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
|
||||
// allows us to keep the operations in the block live and undo the removal by
|
||||
// re-inserting the block.
|
||||
block->getParent()->getBlocks().remove(block);
|
||||
|
||||
// Mark all nested ops as erased.
|
||||
block->walk([&](Operation *op) { replacedOps.insert(op); });
|
||||
}
|
||||
|
||||
void ConversionPatternRewriterImpl::notifyBlockInserted(
|
||||
@@ -1709,13 +1734,6 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
|
||||
}
|
||||
|
||||
void ConversionPatternRewriter::eraseBlock(Block *block) {
|
||||
assert(!impl->wasOpReplaced(block->getParentOp()) &&
|
||||
"attempting to erase a block within a replaced/erased op");
|
||||
|
||||
// Mark all ops for erasure.
|
||||
for (Operation &op : *block)
|
||||
eraseOp(&op);
|
||||
|
||||
impl->eraseBlock(block);
|
||||
}
|
||||
|
||||
|
||||
@@ -461,12 +461,26 @@ func.func @convert_detached_signature() {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: notifyOperationReplaced: test.erase_op
|
||||
// CHECK: notifyOperationErased: test.dummy_op_lvl_2
|
||||
// CHECK: notifyBlockErased
|
||||
// CHECK: notifyOperationErased: test.dummy_op_lvl_1
|
||||
// CHECK: notifyBlockErased
|
||||
// CHECK: notifyOperationErased: test.erase_op
|
||||
// CHECK: notifyOperationInserted: test.valid, was unlinked
|
||||
// CHECK: notifyOperationReplaced: test.drop_operands_and_replace_with_valid
|
||||
// CHECK: notifyOperationErased: test.drop_operands_and_replace_with_valid
|
||||
|
||||
// CHECK-LABEL: func @circular_mapping()
|
||||
// CHECK-NEXT: "test.valid"() : () -> ()
|
||||
func.func @circular_mapping() {
|
||||
// Regression test that used to crash due to circular
|
||||
// unrealized_conversion_cast ops.
|
||||
%0 = "test.erase_op"() : () -> (i64)
|
||||
// unrealized_conversion_cast ops.
|
||||
%0 = "test.erase_op"() ({
|
||||
"test.dummy_op_lvl_1"() ({
|
||||
"test.dummy_op_lvl_2"() : () -> ()
|
||||
}) : () -> ()
|
||||
}): () -> (i64)
|
||||
"test.drop_operands_and_replace_with_valid"(%0) : (i64) -> ()
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user