[mlir][cf] Preserve branch weights during cf.cond_br canonicalization. (#144822)
This commit is contained in:
@@ -153,17 +153,25 @@ def CondBranchOp
|
||||
let builders = [OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
|
||||
"ValueRange":$trueOperands,
|
||||
"Block *":$falseDest,
|
||||
"ValueRange":$falseOperands),
|
||||
"ValueRange":$falseOperands,
|
||||
CArg<"ArrayRef<int32_t>", "{}">:$branchWeights),
|
||||
[{
|
||||
build($_builder, $_state, condition, trueOperands, falseOperands, /*branch_weights=*/{}, trueDest,
|
||||
falseDest);
|
||||
DenseI32ArrayAttr weights;
|
||||
if (!branchWeights.empty())
|
||||
weights = $_builder.getDenseI32ArrayAttr(branchWeights);
|
||||
build($_builder, $_state, condition, trueOperands, falseOperands,
|
||||
weights, trueDest, falseDest);
|
||||
}]>,
|
||||
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
|
||||
"Block *":$falseDest,
|
||||
CArg<"ValueRange", "{}">:$falseOperands),
|
||||
CArg<"ValueRange", "{}">:$falseOperands,
|
||||
CArg<"ArrayRef<int32_t>", "{}">:$branchWeights),
|
||||
[{
|
||||
build($_builder, $_state, condition, trueDest, ValueRange(), falseDest,
|
||||
falseOperands);
|
||||
DenseI32ArrayAttr weights;
|
||||
if (!branchWeights.empty())
|
||||
weights = $_builder.getDenseI32ArrayAttr(branchWeights);
|
||||
build($_builder, $_state, condition, ValueRange(), falseOperands,
|
||||
weights, trueDest, falseDest);
|
||||
}]>];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
|
||||
@@ -265,9 +265,9 @@ struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
|
||||
return failure();
|
||||
|
||||
// Create a new branch with the collapsed successors.
|
||||
rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
|
||||
trueDest, trueDestOperands,
|
||||
falseDest, falseDestOperands);
|
||||
rewriter.replaceOpWithNewOp<CondBranchOp>(
|
||||
condbr, condbr.getCondition(), trueDest, trueDestOperands, falseDest,
|
||||
falseDestOperands, condbr.getWeights());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -102,6 +102,31 @@ func.func @cond_br_and_br_folding(%a : i32) {
|
||||
|
||||
/// Test that pass-through successors of CondBranchOp get folded.
|
||||
|
||||
// Test that the weights are preserved:
|
||||
// CHECK-LABEL: func.func @cond_br_passthrough_weights(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i1) -> i32 {
|
||||
func.func @cond_br_passthrough_weights(%arg0 : i32, %arg1 : i32, %cond : i1) -> i32 {
|
||||
// CHECK: cf.cond_br %[[ARG2]] weights([30, 70]), ^bb1, ^bb2
|
||||
// CHECK: ^bb1:
|
||||
// CHECK: return %[[ARG0]] : i32
|
||||
// CHECK: ^bb2:
|
||||
// CHECK: return %[[ARG1]] : i32
|
||||
// CHECK: }
|
||||
cf.cond_br %cond weights([30,70]), ^bb1, ^bb3
|
||||
|
||||
^bb1:
|
||||
cf.br ^bb2
|
||||
|
||||
^bb3:
|
||||
cf.br ^bb4
|
||||
|
||||
^bb2:
|
||||
return %arg0 : i32
|
||||
|
||||
^bb4:
|
||||
return %arg1 : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @cond_br_passthrough(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[COND:.*]]: i1
|
||||
func.func @cond_br_passthrough(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) {
|
||||
|
||||
Reference in New Issue
Block a user