[mlir][Transforms][NFC] Dialect Conversion: Keep unresolvedMaterializations up to date (#144254)

`unresolvedMaterializations` is a mapping from
`UnrealizedConversionCastOp` to `UnresolvedMaterializationRewrite`. This
mapping is needed to find the correct type converter for an unresolved
materialization.

With this commit, `unresolvedMaterializations` is updated immediately
when an op is being erased. This also cleans up the code base a bit:
`SingleEraseRewriter` is now used only during the "cleanup" phase and no
longer needed as a field of `ConversionRewriterImpl`.

This commit is in preparation of the One-Shot Dialect Conversion
refactoring: `allowPatternRollback = false` will in the future trigger
immediate materialization of all IR changes.
This commit is contained in:
Matthias Springer
2025-06-18 14:42:09 +02:00
committed by GitHub
parent a1c2a71293
commit 66580f77b8

View File

@@ -848,7 +848,7 @@ namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
const ConversionConfig &config)
: context(ctx), eraseRewriter(ctx), config(config) {}
: context(ctx), config(config) {}
//===--------------------------------------------------------------------===//
// State Management
@@ -981,8 +981,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// no new IR is created between calls to `eraseOp`/`eraseBlock`.
struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener {
public:
SingleEraseRewriter(MLIRContext *context)
: RewriterBase(context, /*listener=*/this) {}
SingleEraseRewriter(
MLIRContext *context,
std::function<void(Operation *)> opErasedCallback = nullptr)
: RewriterBase(context, /*listener=*/this),
opErasedCallback(opErasedCallback) {}
/// Erase the given op (unless it was already erased).
void eraseOp(Operation *op) override {
@@ -1003,13 +1006,20 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
bool wasErased(void *ptr) const { return erased.contains(ptr); }
void notifyOperationErased(Operation *op) override { erased.insert(op); }
void notifyOperationErased(Operation *op) override {
erased.insert(op);
if (opErasedCallback)
opErasedCallback(op);
}
void notifyBlockErased(Block *block) override { erased.insert(block); }
private:
/// Pointers to all erased operations and blocks.
DenseSet<void *> erased;
/// A callback that is invoked when an operation is erased.
std::function<void(Operation *)> opErasedCallback;
};
//===--------------------------------------------------------------------===//
@@ -1019,11 +1029,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// MLIR context.
MLIRContext *context;
/// A rewriter that keeps track of ops/block that were already erased and
/// skips duplicate op/block erasures. This rewriter is used during the
/// "cleanup" phase.
SingleEraseRewriter eraseRewriter;
// Mapping between replaced values that differ in type. This happens when
// replacing a value with one of a different type.
ConversionValueMapping mapping;
@@ -1195,6 +1200,11 @@ void ConversionPatternRewriterImpl::applyRewrites() {
rewrites[i]->commit(rewriter);
// Clean up all rewrites.
SingleEraseRewriter eraseRewriter(
context, /*opErasedCallback=*/[&](Operation *op) {
if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
unresolvedMaterializations.erase(castOp);
});
for (auto &rewrite : rewrites)
rewrite->cleanup(eraseRewriter);
}
@@ -2714,11 +2724,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
SmallVector<UnrealizedConversionCastOp> allCastOps;
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
&materializations = rewriterImpl.unresolvedMaterializations;
for (auto it : materializations) {
if (rewriterImpl.eraseRewriter.wasErased(it.first))
continue;
for (auto it : materializations)
allCastOps.push_back(it.first);
}
// Reconcile all UnrealizedConversionCastOps that were inserted by the
// dialect conversion frameworks. (Not the one that were inserted by