[mlir][Transforms] Add 1:N support to replaceUsesOfBlockArgument (#145171)

This commit adds 1:N support to
`ConversionPatternRewriter::replaceUsesOfBlockArgument`. This was one of
the few remaining dialect conversion APIs that does not support 1:N
conversions yet.

This commit also reuses `replaceUsesOfBlockArgument` in the
implementation of `applySignatureConversion`. This is in preparation of
the One-Shot Dialect Conversion refactoring. The goal is to bring the
`applySignatureConversion` implementation into a state where it works
both with and without rollbacks. To that end, `applySignatureConversion`
should not directly access the `mapping`.
This commit is contained in:
Matthias Springer
2025-06-23 12:07:00 +02:00
committed by GitHub
parent 2545d6f723
commit b1b8f67eab
5 changed files with 82 additions and 47 deletions

View File

@@ -763,8 +763,9 @@ public:
Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion = nullptr);
/// Replace all the uses of the block argument `from` with value `to`.
void replaceUsesOfBlockArgument(BlockArgument from, Value to);
/// Replace all the uses of the block argument `from` with `to`. This
/// function supports both 1:1 and 1:N replacements.
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);
/// Return the converted value of 'key' with a type defined by the type
/// converter of the currently executing pattern. Return nullptr in the case

View File

@@ -294,7 +294,7 @@ static void restoreByValRefArgumentType(
Type resTy = typeConverter.convertType(
cast<TypeAttr>(byValRefAttr->getValue()).getValue());
auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
Value valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
rewriter.replaceUsesOfBlockArgument(arg, valueArg);
}
}

View File

@@ -948,6 +948,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// uses.
void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
/// Replace the given block argument with the given values. The specified
/// converter is used to build materializations (if necessary).
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to,
const TypeConverter *converter);
/// Erase the given block and its contents.
void eraseBlock(Block *block);
@@ -1434,12 +1439,15 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
if (!inputMap) {
// This block argument was dropped and no replacement value was provided.
// Materialize a replacement value "out of thin air".
buildUnresolvedMaterialization(
MaterializationKind::Source,
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
/*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(),
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
Value mat =
buildUnresolvedMaterialization(
MaterializationKind::Source,
OpBuilder::InsertPoint(newBlock, newBlock->begin()),
origArg.getLoc(),
/*valuesToMap=*/{}, /*inputs=*/ValueRange(),
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter)
.front();
replaceUsesOfBlockArgument(origArg, mat, converter);
continue;
}
@@ -1448,17 +1456,15 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
assert(inputMap->size == 0 &&
"invalid to provide a replacement value when the argument isn't "
"dropped");
mapping.map(origArg, inputMap->replacementValues);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
replaceUsesOfBlockArgument(origArg, inputMap->replacementValues,
converter);
continue;
}
// This is a 1->1+ mapping.
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
mapping.map(origArg, std::move(replArgVals));
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
replaceUsesOfBlockArgument(origArg, replArgs, converter);
}
appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
@@ -1612,6 +1618,12 @@ void ConversionPatternRewriterImpl::replaceOp(
op->walk([&](Operation *op) { replacedOps.insert(op); });
}
void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
BlockArgument from, ValueRange to, const TypeConverter *converter) {
appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter);
mapping.map(from, to);
}
void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
assert(!wasOpReplaced(block->getParentOp()) &&
"attempting to erase a block within a replaced/erased op");
@@ -1744,7 +1756,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
}
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
Value to) {
ValueRange to) {
LLVM_DEBUG({
impl->logger.startLine() << "** Replace Argument : '" << from << "'";
if (Operation *parentOp = from.getOwner()->getParentOp()) {
@@ -1754,9 +1766,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
impl->logger.getOStream() << " (unlinked block)\n";
}
});
impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
impl->currentTypeConverter);
impl->mapping.map(from, to);
impl->replaceUsesOfBlockArgument(from, to, impl->currentTypeConverter);
}
Value ConversionPatternRewriter::getRemappedValue(Value key) {

View File

@@ -300,18 +300,35 @@ func.func @create_illegal_block() {
// -----
// CHECK-LABEL: @undo_block_arg_replace
// expected-remark@+1{{applyPartialConversion failed}}
module {
func.func @undo_block_arg_replace() {
// expected-remark@+1 {{op 'test.undo_block_arg_replace' is not legalizable}}
"test.undo_block_arg_replace"() ({
^bb0(%arg0: i32):
// CHECK: ^bb0(%[[ARG:.*]]: i32):
// CHECK-NEXT: "test.return"(%[[ARG]]) : (i32)
// expected-error@+1{{failed to legalize operation 'test.block_arg_replace' that was explicitly marked illegal}}
"test.block_arg_replace"() ({
^bb0(%arg0: i32, %arg1: i16):
// CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
// CHECK-NEXT: "test.return"(%[[ARG0]]) : (i32)
"test.return"(%arg0) : (i32) -> ()
}) : () -> ()
// expected-remark@+1 {{op 'func.return' is not legalizable}}
}) {trigger_rollback} : () -> ()
return
}
}
// -----
// CHECK-LABEL: @replace_block_arg_1_to_n
func.func @replace_block_arg_1_to_n() {
// CHECK: "test.block_arg_replace"
"test.block_arg_replace"() ({
^bb0(%arg0: i32, %arg1: i16):
// CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
// CHECK: %[[cast:.*]] = "test.cast"(%[[ARG1]], %[[ARG1]]) : (i16, i16) -> i32
// CHECK-NEXT: "test.return"(%[[cast]]) : (i32)
"test.return"(%arg0) : (i32) -> ()
}) : () -> ()
"test.return"() : () -> ()
}
// -----

View File

@@ -891,20 +891,25 @@ struct TestCreateIllegalBlock : public RewritePattern {
}
};
/// A simple pattern that tests the undo mechanism when replacing the uses of a
/// block argument.
struct TestUndoBlockArgReplace : public ConversionPattern {
TestUndoBlockArgReplace(MLIRContext *ctx)
: ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
/// A simple pattern that tests the "replaceUsesOfBlockArgument" API.
struct TestBlockArgReplace : public ConversionPattern {
TestBlockArgReplace(MLIRContext *ctx, const TypeConverter &converter)
: ConversionPattern(converter, "test.block_arg_replace", /*benefit=*/1,
ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto illegalOp =
rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
// Replace the first block argument with 2x the second block argument.
Value repl = op->getRegion(0).getArgument(1);
rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
illegalOp->getResult(0));
rewriter.modifyOpInPlace(op, [] {});
{repl, repl});
rewriter.modifyOpInPlace(op, [&] {
// If the "trigger_rollback" attribute is set, keep the op illegal, so
// that a rollback is triggered.
if (!op->hasAttr("trigger_rollback"))
op->setAttr("is_legal", rewriter.getUnitAttr());
});
return success();
}
};
@@ -1375,20 +1380,19 @@ struct TestLegalizePatternDriver
TestTypeConverter converter;
mlir::RewritePatternSet patterns(&getContext());
populateWithGenerated(patterns);
patterns
.add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
TestNonRootReplacement, TestBoundedRecursiveRewrite,
TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
TestCreateUnregisteredOp, TestUndoMoveOpBefore,
TestUndoPropertiesModification, TestEraseOp,
TestRepetitive1ToNConsumer>(&getContext());
patterns.add<
TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
TestUndoBlockErase, TestSplitReturnType, TestChangeProducerTypeI32ToF32,
TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
TestUpdateConsumerType, TestNonRootReplacement,
TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
TestUndoPropertiesModification, TestEraseOp,
TestRepetitive1ToNConsumer>(&getContext());
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
TestPassthroughInvalidOp, TestMultiple1ToNReplacement>(
&getContext(), converter);
TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
TestBlockArgReplace>(&getContext(), converter);
patterns.add<TestConvertBlockArgs>(converter, &getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
@@ -1413,6 +1417,9 @@ struct TestLegalizePatternDriver
});
target.addDynamicallyLegalOp<func::CallOp>(
[&](func::CallOp op) { return converter.isLegal(op); });
target.addDynamicallyLegalOp(
OperationName("test.block_arg_replace", &getContext()),
[](Operation *op) { return op->hasAttr("is_legal"); });
// TestCreateUnregisteredOp creates `arith.constant` operation,
// which was not added to target intentionally to test