From dfe09cc621ec11f36ec2e36f4fd01fce8ceec87f Mon Sep 17 00:00:00 2001 From: River Riddle Date: Wed, 16 Oct 2019 09:50:28 -0700 Subject: [PATCH] Add support for PatternRewriter::eraseOp. This hook is useful when an operation is known to be dead, and no replacement values make sense. PiperOrigin-RevId: 275052756 --- mlir/include/mlir/IR/PatternMatch.h | 3 ++ .../mlir/Transforms/DialectConversion.h | 5 ++++ .../LoopToStandard/ConvertLoopToStandard.cpp | 2 +- .../StandardToLLVM/ConvertStandardToLLVM.cpp | 4 +-- .../Linalg/Transforms/LowerToLLVMDialect.cpp | 2 +- .../Linalg/Transforms/LowerToLoops.cpp | 4 +-- mlir/lib/Dialect/StandardOps/Ops.cpp | 14 ++++----- mlir/lib/IR/PatternMatch.cpp | 8 +++++ mlir/lib/Transforms/DialectConversion.cpp | 29 ++++++++++++------- mlir/lib/Transforms/LowerAffine.cpp | 4 +-- mlir/lib/Transforms/LowerVectorTransfers.cpp | 2 +- mlir/test/lib/TestDialect/TestDialect.cpp | 2 +- mlir/test/lib/TestDialect/TestPatterns.cpp | 6 ++-- 13 files changed, 54 insertions(+), 31 deletions(-) diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 2fc10694c9c4..8e97b58a405d 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -356,6 +356,9 @@ public: valuesToRemoveIfDead); } + /// This method erases an operation that is known to have no uses. + virtual void eraseOp(Operation *op); + /// Split the operations starting at "before" (inclusive) out of the given /// block into a new block, and return it. virtual Block *splitBlock(Block *block, Block::iterator before) { diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index d94146dcb1b3..660874dc85a8 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -262,6 +262,11 @@ public: ArrayRef valuesToRemoveIfDead) override; using PatternRewriter::replaceOp; + /// PatternRewriter hook for erasing a dead operation. The uses of this + /// operation *must* be made dead by the end of the conversion process, + /// otherwise an assert will be issued. + void eraseOp(Operation *op) override; + /// PatternRewriter hook for splitting a block into two parts. Block *splitBlock(Block *block, Block::iterator before) override; diff --git a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp index d70a054e3830..556a49342338 100644 --- a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp +++ b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp @@ -156,7 +156,7 @@ struct TerminatorLowering : public OpRewritePattern { PatternMatchResult matchAndRewrite(TerminatorOp op, PatternRewriter &rewriter) const override { - rewriter.replaceOp(op, {}); + rewriter.eraseOp(op); return matchSuccess(); } }; diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 4b7dec7f3c0f..15f61ab9ce81 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -363,7 +363,7 @@ struct FuncOpConversion : public LLVMLegalizationPattern { } } - rewriter.replaceOp(op, llvm::None); + rewriter.eraseOp(op); return matchSuccess(); } }; @@ -474,7 +474,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern { // If the operation produced 0 or 1 result, return them immediately. if (numResults == 0) - return rewriter.replaceOp(op, llvm::None), this->matchSuccess(); + return rewriter.eraseOp(op), this->matchSuccess(); if (numResults == 1) return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)), this->matchSuccess(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp index 32512d0be9e5..a0955d50523c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -280,7 +280,7 @@ public: Value *base = extractvalue(voidPtrTy, adaptor.buffer(), rewriter.getI64ArrayAttr(kBasePtrPosInBuffer)); llvm_call(ArrayRef(), rewriter.getSymbolRefAttr(freeFunc), base); - rewriter.replaceOp(op, llvm::None); + rewriter.eraseOp(op); return matchSuccess(); } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp index 7854df8d332b..e6070a6a9e0d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp @@ -308,7 +308,7 @@ public: if (!invertedMap) { LinalgScopedEmitter::emitScalarImplementation({}, linalgOp, folder); - rewriter.replaceOp(op, {}); + rewriter.eraseOp(op); return matchSuccess(); } @@ -341,7 +341,7 @@ public: }); }); // clang-format on - rewriter.replaceOp(op, {}); + rewriter.eraseOp(op); return matchSuccess(); } diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 443aa64c5260..7177cfe7dff0 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -455,13 +455,11 @@ struct SimplifyDeadAlloc : public OpRewritePattern { PatternMatchResult matchAndRewrite(AllocOp alloc, PatternRewriter &rewriter) const override { - // Check if the alloc'ed value has any uses. - if (!alloc.use_empty()) - return matchFailure(); - - // If it doesn't, we can eliminate it. - alloc.erase(); - return matchSuccess(); + if (alloc.use_empty()) { + rewriter.eraseOp(alloc); + return matchSuccess(); + } + return matchFailure(); } }; } // end anonymous namespace. @@ -1296,7 +1294,7 @@ struct SimplifyDeadDealloc : public OpRewritePattern { return matchFailure(); // Erase the dealloc operation. - rewriter.replaceOp(dealloc, llvm::None); + rewriter.eraseOp(dealloc); return matchSuccess(); } }; diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 1b2e4ee2b1c9..1f9b9b060aa9 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -101,6 +101,14 @@ void PatternRewriter::replaceOp(Operation *op, ArrayRef newValues, // the notifyOperationRemoved hook in the process. } +/// This method erases an operation that is known to have no uses. The uses of +/// the given operation *must* be known to be dead. +void PatternRewriter::eraseOp(Operation *op) { + assert(op->use_empty() && "expected 'op' to have no uses"); + notifyOperationRemoved(op); + op->erase(); +} + /// op and newOp are known to have the same number of results, replace the /// uses of op with uses of newOp void PatternRewriter::replaceOpWithResultsOfAnotherOp( diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 0007feb4ccd3..4ab10676f6cf 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -583,9 +583,11 @@ void ConversionPatternRewriterImpl::discardRewrites() { void ConversionPatternRewriterImpl::applyRewrites() { // Apply all of the rewrites replacements requested during conversion. for (auto &repl : replacements) { - for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i) - repl.op->getResult(i)->replaceAllUsesWith( - mapping.lookupOrDefault(repl.newValues[i])); + for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i) { + if (auto *newValue = repl.newValues[i]) + repl.op->getResult(i)->replaceAllUsesWith( + mapping.lookupOrDefault(newValue)); + } // If this operation defines any regions, drop any pending argument // rewrites. @@ -637,12 +639,9 @@ void ConversionPatternRewriterImpl::replaceOp( assert(newValues.size() == op->getNumResults()); // Create mappings for each of the new result values. - for (unsigned i = 0, e = newValues.size(); i < e; ++i) { - assert((newValues[i] || op->getResult(i)->use_empty()) && - "result value has remaining uses that must be replaced"); - if (newValues[i]) - mapping.map(op->getResult(i), newValues[i]); - } + for (unsigned i = 0, e = newValues.size(); i < e; ++i) + if (auto *repl = newValues[i]) + mapping.map(op->getResult(i), repl); // Record the requested operation replacement. replacements.emplace_back(op, newValues); @@ -718,6 +717,16 @@ void ConversionPatternRewriter::replaceOp( impl->replaceOp(op, newValues, valuesToRemoveIfDead); } +/// PatternRewriter hook for erasing a dead operation. The uses of this +/// operation *must* be made dead by the end of the conversion process, +/// otherwise an assert will be issued. +void ConversionPatternRewriter::eraseOp(Operation *op) { + LLVM_DEBUG(llvm::dbgs() << "** Erasing operation : " << op->getName() + << "\n"); + SmallVector nullRepls(op->getNumResults(), nullptr); + impl->replaceOp(op, nullRepls, /*valuesToRemoveIfDead=*/llvm::None); +} + /// Apply a signature conversion to the entry block of the given region. void ConversionPatternRewriter::applySignatureConversion( Region *region, TypeConverter::SignatureConversion &conversion) { @@ -1397,7 +1406,7 @@ struct FuncOpSignatureConversion : public ConversionPattern { // Tell the rewriter to convert the region signature. rewriter.applySignatureConversion(&newFuncOp.getBody(), result); - rewriter.replaceOp(op, llvm::None); + rewriter.eraseOp(op); return matchSuccess(); } diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 72d52b9a4f56..b3e811b7123d 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -319,7 +319,7 @@ public: auto f = rewriter.create(loc, lowerBound, upperBound, step); f.region().getBlocks().clear(); rewriter.inlineRegionBefore(op.region(), f.region(), f.region().end()); - rewriter.replaceOp(op, {}); + rewriter.eraseOp(op); return matchSuccess(); } }; @@ -370,7 +370,7 @@ public: } // Ok, we're done! - rewriter.replaceOp(op, {}); + rewriter.eraseOp(op); return matchSuccess(); } }; diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index 09d3a4672987..5ffee64f93d4 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -355,7 +355,7 @@ VectorTransferRewriter::matchAndRewrite( }); (dealloc(tmp)); // vexing parse... - rewriter.replaceOp(op, llvm::None); + rewriter.eraseOp(op); return matchSuccess(); } diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp index c8db7967fe08..ee8325fd13ec 100644 --- a/mlir/test/lib/TestDialect/TestDialect.cpp +++ b/mlir/test/lib/TestDialect/TestDialect.cpp @@ -220,7 +220,7 @@ struct TestRemoveOpWithInnerOps PatternMatchResult matchAndRewrite(TestOpWithRegionPattern op, PatternRewriter &rewriter) const override { - rewriter.replaceOp(op, llvm::None); + rewriter.eraseOp(op); return matchSuccess(); } }; diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp index 2dde6a37675f..696a98761051 100644 --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -115,7 +115,7 @@ struct TestRegionRewriteBlockMovement : public ConversionPattern { parentRegion.end()); // Drop this operation. - rewriter.replaceOp(op, llvm::None); + rewriter.eraseOp(op); return matchSuccess(); } }; @@ -139,7 +139,7 @@ struct TestRegionRewriteUndo : public RewritePattern { rewriter.create(op->getLoc(), ArrayRef()); // Drop this operation. - rewriter.replaceOp(op, llvm::None); + rewriter.eraseOp(op); return matchSuccess(); } }; @@ -153,7 +153,7 @@ struct TestDropOp : public ConversionPattern { PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - rewriter.replaceOp(op, llvm::None); + rewriter.eraseOp(op); return matchSuccess(); } };