[mlir][Transforms] Add 1:N support to replaceUsesOfBlockArgument (#145171)
This commit adds 1:N support to `ConversionPatternRewriter::replaceUsesOfBlockArgument`. This was one of the few remaining dialect conversion APIs that does not support 1:N conversions yet. This commit also reuses `replaceUsesOfBlockArgument` in the implementation of `applySignatureConversion`. This is in preparation of the One-Shot Dialect Conversion refactoring. The goal is to bring the `applySignatureConversion` implementation into a state where it works both with and without rollbacks. To that end, `applySignatureConversion` should not directly access the `mapping`.
This commit is contained in:
committed by
GitHub
parent
2545d6f723
commit
b1b8f67eab
@@ -763,8 +763,9 @@ public:
|
||||
Region *region, const TypeConverter &converter,
|
||||
TypeConverter::SignatureConversion *entryConversion = nullptr);
|
||||
|
||||
/// Replace all the uses of the block argument `from` with value `to`.
|
||||
void replaceUsesOfBlockArgument(BlockArgument from, Value to);
|
||||
/// Replace all the uses of the block argument `from` with `to`. This
|
||||
/// function supports both 1:1 and 1:N replacements.
|
||||
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);
|
||||
|
||||
/// Return the converted value of 'key' with a type defined by the type
|
||||
/// converter of the currently executing pattern. Return nullptr in the case
|
||||
|
||||
@@ -294,7 +294,7 @@ static void restoreByValRefArgumentType(
|
||||
Type resTy = typeConverter.convertType(
|
||||
cast<TypeAttr>(byValRefAttr->getValue()).getValue());
|
||||
|
||||
auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
|
||||
Value valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
|
||||
rewriter.replaceUsesOfBlockArgument(arg, valueArg);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -948,6 +948,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
|
||||
/// uses.
|
||||
void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
|
||||
|
||||
/// Replace the given block argument with the given values. The specified
|
||||
/// converter is used to build materializations (if necessary).
|
||||
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to,
|
||||
const TypeConverter *converter);
|
||||
|
||||
/// Erase the given block and its contents.
|
||||
void eraseBlock(Block *block);
|
||||
|
||||
@@ -1434,12 +1439,15 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
|
||||
if (!inputMap) {
|
||||
// This block argument was dropped and no replacement value was provided.
|
||||
// Materialize a replacement value "out of thin air".
|
||||
buildUnresolvedMaterialization(
|
||||
MaterializationKind::Source,
|
||||
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
|
||||
/*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(),
|
||||
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter);
|
||||
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
|
||||
Value mat =
|
||||
buildUnresolvedMaterialization(
|
||||
MaterializationKind::Source,
|
||||
OpBuilder::InsertPoint(newBlock, newBlock->begin()),
|
||||
origArg.getLoc(),
|
||||
/*valuesToMap=*/{}, /*inputs=*/ValueRange(),
|
||||
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter)
|
||||
.front();
|
||||
replaceUsesOfBlockArgument(origArg, mat, converter);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -1448,17 +1456,15 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
|
||||
assert(inputMap->size == 0 &&
|
||||
"invalid to provide a replacement value when the argument isn't "
|
||||
"dropped");
|
||||
mapping.map(origArg, inputMap->replacementValues);
|
||||
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
|
||||
replaceUsesOfBlockArgument(origArg, inputMap->replacementValues,
|
||||
converter);
|
||||
continue;
|
||||
}
|
||||
|
||||
// This is a 1->1+ mapping.
|
||||
auto replArgs =
|
||||
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
|
||||
ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
|
||||
mapping.map(origArg, std::move(replArgVals));
|
||||
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
|
||||
replaceUsesOfBlockArgument(origArg, replArgs, converter);
|
||||
}
|
||||
|
||||
appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
|
||||
@@ -1612,6 +1618,12 @@ void ConversionPatternRewriterImpl::replaceOp(
|
||||
op->walk([&](Operation *op) { replacedOps.insert(op); });
|
||||
}
|
||||
|
||||
void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
|
||||
BlockArgument from, ValueRange to, const TypeConverter *converter) {
|
||||
appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter);
|
||||
mapping.map(from, to);
|
||||
}
|
||||
|
||||
void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
|
||||
assert(!wasOpReplaced(block->getParentOp()) &&
|
||||
"attempting to erase a block within a replaced/erased op");
|
||||
@@ -1744,7 +1756,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
|
||||
}
|
||||
|
||||
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
|
||||
Value to) {
|
||||
ValueRange to) {
|
||||
LLVM_DEBUG({
|
||||
impl->logger.startLine() << "** Replace Argument : '" << from << "'";
|
||||
if (Operation *parentOp = from.getOwner()->getParentOp()) {
|
||||
@@ -1754,9 +1766,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
|
||||
impl->logger.getOStream() << " (unlinked block)\n";
|
||||
}
|
||||
});
|
||||
impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
|
||||
impl->currentTypeConverter);
|
||||
impl->mapping.map(from, to);
|
||||
impl->replaceUsesOfBlockArgument(from, to, impl->currentTypeConverter);
|
||||
}
|
||||
|
||||
Value ConversionPatternRewriter::getRemappedValue(Value key) {
|
||||
|
||||
@@ -300,18 +300,35 @@ func.func @create_illegal_block() {
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @undo_block_arg_replace
|
||||
// expected-remark@+1{{applyPartialConversion failed}}
|
||||
module {
|
||||
func.func @undo_block_arg_replace() {
|
||||
// expected-remark@+1 {{op 'test.undo_block_arg_replace' is not legalizable}}
|
||||
"test.undo_block_arg_replace"() ({
|
||||
^bb0(%arg0: i32):
|
||||
// CHECK: ^bb0(%[[ARG:.*]]: i32):
|
||||
// CHECK-NEXT: "test.return"(%[[ARG]]) : (i32)
|
||||
// expected-error@+1{{failed to legalize operation 'test.block_arg_replace' that was explicitly marked illegal}}
|
||||
"test.block_arg_replace"() ({
|
||||
^bb0(%arg0: i32, %arg1: i16):
|
||||
// CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
|
||||
// CHECK-NEXT: "test.return"(%[[ARG0]]) : (i32)
|
||||
|
||||
"test.return"(%arg0) : (i32) -> ()
|
||||
}) : () -> ()
|
||||
// expected-remark@+1 {{op 'func.return' is not legalizable}}
|
||||
}) {trigger_rollback} : () -> ()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @replace_block_arg_1_to_n
|
||||
func.func @replace_block_arg_1_to_n() {
|
||||
// CHECK: "test.block_arg_replace"
|
||||
"test.block_arg_replace"() ({
|
||||
^bb0(%arg0: i32, %arg1: i16):
|
||||
// CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
|
||||
// CHECK: %[[cast:.*]] = "test.cast"(%[[ARG1]], %[[ARG1]]) : (i16, i16) -> i32
|
||||
// CHECK-NEXT: "test.return"(%[[cast]]) : (i32)
|
||||
"test.return"(%arg0) : (i32) -> ()
|
||||
}) : () -> ()
|
||||
"test.return"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
@@ -891,20 +891,25 @@ struct TestCreateIllegalBlock : public RewritePattern {
|
||||
}
|
||||
};
|
||||
|
||||
/// A simple pattern that tests the undo mechanism when replacing the uses of a
|
||||
/// block argument.
|
||||
struct TestUndoBlockArgReplace : public ConversionPattern {
|
||||
TestUndoBlockArgReplace(MLIRContext *ctx)
|
||||
: ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
|
||||
/// A simple pattern that tests the "replaceUsesOfBlockArgument" API.
|
||||
struct TestBlockArgReplace : public ConversionPattern {
|
||||
TestBlockArgReplace(MLIRContext *ctx, const TypeConverter &converter)
|
||||
: ConversionPattern(converter, "test.block_arg_replace", /*benefit=*/1,
|
||||
ctx) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto illegalOp =
|
||||
rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
|
||||
// Replace the first block argument with 2x the second block argument.
|
||||
Value repl = op->getRegion(0).getArgument(1);
|
||||
rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
|
||||
illegalOp->getResult(0));
|
||||
rewriter.modifyOpInPlace(op, [] {});
|
||||
{repl, repl});
|
||||
rewriter.modifyOpInPlace(op, [&] {
|
||||
// If the "trigger_rollback" attribute is set, keep the op illegal, so
|
||||
// that a rollback is triggered.
|
||||
if (!op->hasAttr("trigger_rollback"))
|
||||
op->setAttr("is_legal", rewriter.getUnitAttr());
|
||||
});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -1375,20 +1380,19 @@ struct TestLegalizePatternDriver
|
||||
TestTypeConverter converter;
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
populateWithGenerated(patterns);
|
||||
patterns
|
||||
.add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
|
||||
TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
|
||||
TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
|
||||
TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
|
||||
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
|
||||
TestNonRootReplacement, TestBoundedRecursiveRewrite,
|
||||
TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
|
||||
TestCreateUnregisteredOp, TestUndoMoveOpBefore,
|
||||
TestUndoPropertiesModification, TestEraseOp,
|
||||
TestRepetitive1ToNConsumer>(&getContext());
|
||||
patterns.add<
|
||||
TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
|
||||
TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
|
||||
TestUndoBlockErase, TestSplitReturnType, TestChangeProducerTypeI32ToF32,
|
||||
TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
|
||||
TestUpdateConsumerType, TestNonRootReplacement,
|
||||
TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
|
||||
TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
|
||||
TestUndoPropertiesModification, TestEraseOp,
|
||||
TestRepetitive1ToNConsumer>(&getContext());
|
||||
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
|
||||
TestPassthroughInvalidOp, TestMultiple1ToNReplacement>(
|
||||
&getContext(), converter);
|
||||
TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
|
||||
TestBlockArgReplace>(&getContext(), converter);
|
||||
patterns.add<TestConvertBlockArgs>(converter, &getContext());
|
||||
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
|
||||
converter);
|
||||
@@ -1413,6 +1417,9 @@ struct TestLegalizePatternDriver
|
||||
});
|
||||
target.addDynamicallyLegalOp<func::CallOp>(
|
||||
[&](func::CallOp op) { return converter.isLegal(op); });
|
||||
target.addDynamicallyLegalOp(
|
||||
OperationName("test.block_arg_replace", &getContext()),
|
||||
[](Operation *op) { return op->hasAttr("is_legal"); });
|
||||
|
||||
// TestCreateUnregisteredOp creates `arith.constant` operation,
|
||||
// which was not added to target intentionally to test
|
||||
|
||||
Reference in New Issue
Block a user