[mlir][liveness] fix bugs in liveness analysis (#133416)

This patch fixes the following bugs:
- In SparseBackwardAnalysis, the setToExitState function should
propagate changes if it modifies the lattice. Previously, this issue was
masked because multi-block scenarios were not tested, and the traversal
order of backward data flow analysis starts from the end of the program.
- The method in liveness analysis for determining whether the
non-forwarded operand in branch/region branch operations is live is
incorrect, which may cause originally live variables to be marked as not
live.
This commit is contained in:
donald chen
2025-04-02 11:56:13 +08:00
committed by GitHub
parent 03a791f703
commit d40bab359c
3 changed files with 107 additions and 29 deletions

View File

@@ -413,10 +413,12 @@ protected:
// Visit operands on call instructions that are not forwarded.
virtual void visitCallOperand(OpOperand &operand) = 0;
/// Set the given lattice element(s) at control flow exit point(s).
/// Set the given lattice element(s) at control flow exit point(s) and
/// propagate the update if it chaned.
virtual void setToExitState(AbstractSparseLattice *lattice) = 0;
/// Set the given lattice element(s) at control flow exit point(s).
/// Set the given lattice element(s) at control flow exit point(s) and
/// propagate the update if it chaned.
void setAllToExitStates(ArrayRef<AbstractSparseLattice *> lattices);
/// Get the lattice element for a value.

View File

@@ -59,7 +59,9 @@ ChangeResult Liveness::meet(const AbstractSparseLattice &other) {
/// (1.a) is an operand of an op with memory effects OR
/// (1.b) is a non-forwarded branch operand and its branch op could take the
/// control to a block that has an op with memory effects OR
/// (1.c) is a non-forwarded call operand.
/// (1.c) is a non-forwarded branch operand and its branch op could result
/// in different live result OR
/// (1.d) is a non-forwarded call operand.
///
/// A value `A` is said to be "used to compute" value `B` iff `B` cannot be
/// computed in the absence of `A`. Thus, in this implementation, we say that
@@ -106,12 +108,28 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
// the forwarded branch operands or the non-branch operands. Thus they need
// to be handled separately. This is where we handle them.
// This marks values of type (1.b) liveness as "live". A non-forwarded
// This marks values of type (1.b/1.c) liveness as "live". A non-forwarded
// branch operand will be live if a block where its op could take the control
// has an op with memory effects.
// has an op with memory effects or could result in different results.
// Populating such blocks in `blocks`.
bool mayLive = false;
SmallVector<Block *, 4> blocks;
if (isa<RegionBranchOpInterface>(op)) {
if (op->getNumResults() != 0) {
// This mark value of type 1.c liveness as may live, because the region
// branch operation has a return value, and the non-forwarded operand can
// determine the region to jump to, it can thereby control the result of
// the region branch operation.
// Therefore, if the result value is live, we conservatively consider the
// non-forwarded operand of the region branch operation with result may
// live and record all result.
for (Value result : op->getResults()) {
if (getLatticeElement(result)->isLive) {
mayLive = true;
break;
}
}
} else {
// When the op is a `RegionBranchOpInterface`, like an `scf.for` or an
// `scf.index_switch` op, its branch operand controls the flow into this
// op's regions.
@@ -119,38 +137,59 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
for (Block &block : region)
blocks.push_back(&block);
}
}
} else if (isa<BranchOpInterface>(op)) {
// When the op is a `BranchOpInterface`, like a `cf.cond_br` or a
// `cf.switch` op, its branch operand controls the flow into this op's
// successors.
blocks = op->getSuccessors();
// We cannot track all successor blocks of the branch operation(More
// specifically, it's the successor's successor). Additionally, different
// blocks might also lead to the different block argument described in 1.c.
// Therefore, we conservatively consider the non-forwarded operand of the
// branch operation may live.
mayLive = true;
} else {
Operation *parentOp = op->getParentOp();
assert(isa<RegionBranchOpInterface>(parentOp) &&
"expected parent op to implement `RegionBranchOpInterface`");
if (parentOp->getNumResults() != 0) {
// This mark value of type 1.c liveness as may live, because the region
// branch operation has a return value, and the non-forwarded operand can
// determine the region to jump to, it can thereby control the result of
// the region branch operation.
// Therefore, if the result value is live, we conservatively consider the
// non-forwarded operand of the region branch operation with result may
// live and record all result.
for (Value result : parentOp->getResults()) {
if (getLatticeElement(result)->isLive) {
mayLive = true;
break;
}
}
} else {
// When the op is a `RegionBranchTerminatorOpInterface`, like an
// `scf.condition` op or return-like, like an `scf.yield` op, its branch
// operand controls the flow into this op's parent's (which is a
// `RegionBranchOpInterface`'s) regions.
Operation *parentOp = op->getParentOp();
assert(isa<RegionBranchOpInterface>(parentOp) &&
"expected parent op to implement `RegionBranchOpInterface`");
for (Region &region : parentOp->getRegions()) {
for (Block &block : region)
blocks.push_back(&block);
}
}
bool foundMemoryEffectingOp = false;
}
for (Block *block : blocks) {
if (foundMemoryEffectingOp)
if (mayLive)
break;
for (Operation &nestedOp : *block) {
if (!isMemoryEffectFree(&nestedOp)) {
Liveness *operandLiveness = getLatticeElement(operand.get());
propagateIfChanged(operandLiveness, operandLiveness->markLive());
foundMemoryEffectingOp = true;
mayLive = true;
break;
}
}
}
if (mayLive) {
Liveness *operandLiveness = getLatticeElement(operand.get());
propagateIfChanged(operandLiveness, operandLiveness->markLive());
}
// Now that we have checked for memory-effecting ops in the blocks of concern,
// we will simply visit the op with this non-forwarded operand to potentially
// mark it "live" due to type (1.a/3) liveness.
@@ -191,8 +230,12 @@ void LivenessAnalysis::visitCallOperand(OpOperand &operand) {
}
void LivenessAnalysis::setToExitState(Liveness *lattice) {
if (lattice->isLive) {
return;
}
// This marks values of type (2) liveness as "live".
(void)lattice->markLive();
propagateIfChanged(lattice, ChangeResult::Change);
}
//===----------------------------------------------------------------------===//

View File

@@ -59,16 +59,49 @@ func.func @test_3_BranchOpInterface_type_1.b(%arg0: i32, %arg1: memref<i32>, %ar
// -----
// Positive test: Type(1.c) "is a non-forwarded branch operand and its branch
// op could result in different result"
// CHECK-LABEL: test_tag: cond_br:
// CHECK-NEXT: operand #0: live
// CHECK-NEXT: operand #1: live
// CHECK-NEXT: operand #2: live
func.func @test_branch_result_in_different_result_1.c(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : i1) -> tensor<f32> {
cf.cond_br %arg2, ^bb1(%arg0 : tensor<f32>), ^bb2(%arg1 : tensor<f32>) {tag = "cond_br"}
^bb1(%0 : tensor<f32>):
cf.br ^bb3(%0 : tensor<f32>)
^bb2(%1 : tensor<f32>):
cf.br ^bb3(%1 : tensor<f32>)
^bb3(%2 : tensor<f32>):
return %2 : tensor<f32>
}
// -----
// Positive test: Type(1.c) "is a non-forwarded branch operand and its branch
// op could result in different result"
// CHECK-LABEL: test_tag: region_branch:
// CHECK-NEXT: operand #0: live
func.func @test_region_branch_result_in_different_result_1.c(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : i1) -> tensor<f32> {
%0 = scf.if %arg2 -> tensor<f32> {
scf.yield %arg0 : tensor<f32>
} else {
scf.yield %arg1 : tensor<f32>
} {tag="region_branch"}
return %0 : tensor<f32>
}
// -----
func.func private @private(%arg0 : i32, %arg1 : i32) {
func.return
}
// Positive test: Type (1.c) "is a non-forwarded call operand"
// Positive test: Type (1.d) "is a non-forwarded call operand"
// CHECK-LABEL: test_tag: call
// CHECK-LABEL: operand #0: not live
// CHECK-LABEL: operand #1: not live
// CHECK-LABEL: operand #2: live
func.func @test_4_type_1.c(%arg0: i32, %arg1: i32, %device: i32, %m0: memref<i32>) {
func.func @test_4_type_1.d(%arg0: i32, %arg1: i32, %device: i32, %m0: memref<i32>) {
test.call_on_device @private(%arg0, %arg1), %device {tag = "call"} : (i32, i32, i32) -> ()
return
}