[mlir][Transforms][NFC] Dialect Conversion: Keep unresolvedMaterializations up to date (#144254)
`unresolvedMaterializations` is a mapping from `UnrealizedConversionCastOp` to `UnresolvedMaterializationRewrite`. This mapping is needed to find the correct type converter for an unresolved materialization. With this commit, `unresolvedMaterializations` is updated immediately when an op is being erased. This also cleans up the code base a bit: `SingleEraseRewriter` is now used only during the "cleanup" phase and no longer needed as a field of `ConversionRewriterImpl`. This commit is in preparation of the One-Shot Dialect Conversion refactoring: `allowPatternRollback = false` will in the future trigger immediate materialization of all IR changes.
This commit is contained in:
committed by
GitHub
parent
a1c2a71293
commit
66580f77b8
@@ -848,7 +848,7 @@ namespace detail {
|
||||
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
|
||||
explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
|
||||
const ConversionConfig &config)
|
||||
: context(ctx), eraseRewriter(ctx), config(config) {}
|
||||
: context(ctx), config(config) {}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// State Management
|
||||
@@ -981,8 +981,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
|
||||
/// no new IR is created between calls to `eraseOp`/`eraseBlock`.
|
||||
struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener {
|
||||
public:
|
||||
SingleEraseRewriter(MLIRContext *context)
|
||||
: RewriterBase(context, /*listener=*/this) {}
|
||||
SingleEraseRewriter(
|
||||
MLIRContext *context,
|
||||
std::function<void(Operation *)> opErasedCallback = nullptr)
|
||||
: RewriterBase(context, /*listener=*/this),
|
||||
opErasedCallback(opErasedCallback) {}
|
||||
|
||||
/// Erase the given op (unless it was already erased).
|
||||
void eraseOp(Operation *op) override {
|
||||
@@ -1003,13 +1006,20 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
|
||||
|
||||
bool wasErased(void *ptr) const { return erased.contains(ptr); }
|
||||
|
||||
void notifyOperationErased(Operation *op) override { erased.insert(op); }
|
||||
void notifyOperationErased(Operation *op) override {
|
||||
erased.insert(op);
|
||||
if (opErasedCallback)
|
||||
opErasedCallback(op);
|
||||
}
|
||||
|
||||
void notifyBlockErased(Block *block) override { erased.insert(block); }
|
||||
|
||||
private:
|
||||
/// Pointers to all erased operations and blocks.
|
||||
DenseSet<void *> erased;
|
||||
|
||||
/// A callback that is invoked when an operation is erased.
|
||||
std::function<void(Operation *)> opErasedCallback;
|
||||
};
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
@@ -1019,11 +1029,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
|
||||
/// MLIR context.
|
||||
MLIRContext *context;
|
||||
|
||||
/// A rewriter that keeps track of ops/block that were already erased and
|
||||
/// skips duplicate op/block erasures. This rewriter is used during the
|
||||
/// "cleanup" phase.
|
||||
SingleEraseRewriter eraseRewriter;
|
||||
|
||||
// Mapping between replaced values that differ in type. This happens when
|
||||
// replacing a value with one of a different type.
|
||||
ConversionValueMapping mapping;
|
||||
@@ -1195,6 +1200,11 @@ void ConversionPatternRewriterImpl::applyRewrites() {
|
||||
rewrites[i]->commit(rewriter);
|
||||
|
||||
// Clean up all rewrites.
|
||||
SingleEraseRewriter eraseRewriter(
|
||||
context, /*opErasedCallback=*/[&](Operation *op) {
|
||||
if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
|
||||
unresolvedMaterializations.erase(castOp);
|
||||
});
|
||||
for (auto &rewrite : rewrites)
|
||||
rewrite->cleanup(eraseRewriter);
|
||||
}
|
||||
@@ -2714,11 +2724,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
|
||||
SmallVector<UnrealizedConversionCastOp> allCastOps;
|
||||
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
|
||||
&materializations = rewriterImpl.unresolvedMaterializations;
|
||||
for (auto it : materializations) {
|
||||
if (rewriterImpl.eraseRewriter.wasErased(it.first))
|
||||
continue;
|
||||
for (auto it : materializations)
|
||||
allCastOps.push_back(it.first);
|
||||
}
|
||||
|
||||
// Reconcile all UnrealizedConversionCastOps that were inserted by the
|
||||
// dialect conversion frameworks. (Not the one that were inserted by
|
||||
|
||||
Reference in New Issue
Block a user