[mlir][liveness] fix bugs in liveness analysis (#133416)
This patch fixes the following bugs: - In SparseBackwardAnalysis, the setToExitState function should propagate changes if it modifies the lattice. Previously, this issue was masked because multi-block scenarios were not tested, and the traversal order of backward data flow analysis starts from the end of the program. - The method in liveness analysis for determining whether the non-forwarded operand in branch/region branch operations is live is incorrect, which may cause originally live variables to be marked as not live.
This commit is contained in:
@@ -413,10 +413,12 @@ protected:
|
||||
// Visit operands on call instructions that are not forwarded.
|
||||
virtual void visitCallOperand(OpOperand &operand) = 0;
|
||||
|
||||
/// Set the given lattice element(s) at control flow exit point(s).
|
||||
/// Set the given lattice element(s) at control flow exit point(s) and
|
||||
/// propagate the update if it chaned.
|
||||
virtual void setToExitState(AbstractSparseLattice *lattice) = 0;
|
||||
|
||||
/// Set the given lattice element(s) at control flow exit point(s).
|
||||
/// Set the given lattice element(s) at control flow exit point(s) and
|
||||
/// propagate the update if it chaned.
|
||||
void setAllToExitStates(ArrayRef<AbstractSparseLattice *> lattices);
|
||||
|
||||
/// Get the lattice element for a value.
|
||||
|
||||
@@ -59,7 +59,9 @@ ChangeResult Liveness::meet(const AbstractSparseLattice &other) {
|
||||
/// (1.a) is an operand of an op with memory effects OR
|
||||
/// (1.b) is a non-forwarded branch operand and its branch op could take the
|
||||
/// control to a block that has an op with memory effects OR
|
||||
/// (1.c) is a non-forwarded call operand.
|
||||
/// (1.c) is a non-forwarded branch operand and its branch op could result
|
||||
/// in different live result OR
|
||||
/// (1.d) is a non-forwarded call operand.
|
||||
///
|
||||
/// A value `A` is said to be "used to compute" value `B` iff `B` cannot be
|
||||
/// computed in the absence of `A`. Thus, in this implementation, we say that
|
||||
@@ -106,12 +108,28 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
|
||||
// the forwarded branch operands or the non-branch operands. Thus they need
|
||||
// to be handled separately. This is where we handle them.
|
||||
|
||||
// This marks values of type (1.b) liveness as "live". A non-forwarded
|
||||
// This marks values of type (1.b/1.c) liveness as "live". A non-forwarded
|
||||
// branch operand will be live if a block where its op could take the control
|
||||
// has an op with memory effects.
|
||||
// has an op with memory effects or could result in different results.
|
||||
// Populating such blocks in `blocks`.
|
||||
bool mayLive = false;
|
||||
SmallVector<Block *, 4> blocks;
|
||||
if (isa<RegionBranchOpInterface>(op)) {
|
||||
if (op->getNumResults() != 0) {
|
||||
// This mark value of type 1.c liveness as may live, because the region
|
||||
// branch operation has a return value, and the non-forwarded operand can
|
||||
// determine the region to jump to, it can thereby control the result of
|
||||
// the region branch operation.
|
||||
// Therefore, if the result value is live, we conservatively consider the
|
||||
// non-forwarded operand of the region branch operation with result may
|
||||
// live and record all result.
|
||||
for (Value result : op->getResults()) {
|
||||
if (getLatticeElement(result)->isLive) {
|
||||
mayLive = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// When the op is a `RegionBranchOpInterface`, like an `scf.for` or an
|
||||
// `scf.index_switch` op, its branch operand controls the flow into this
|
||||
// op's regions.
|
||||
@@ -119,38 +137,59 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
|
||||
for (Block &block : region)
|
||||
blocks.push_back(&block);
|
||||
}
|
||||
}
|
||||
} else if (isa<BranchOpInterface>(op)) {
|
||||
// When the op is a `BranchOpInterface`, like a `cf.cond_br` or a
|
||||
// `cf.switch` op, its branch operand controls the flow into this op's
|
||||
// successors.
|
||||
blocks = op->getSuccessors();
|
||||
// We cannot track all successor blocks of the branch operation(More
|
||||
// specifically, it's the successor's successor). Additionally, different
|
||||
// blocks might also lead to the different block argument described in 1.c.
|
||||
// Therefore, we conservatively consider the non-forwarded operand of the
|
||||
// branch operation may live.
|
||||
mayLive = true;
|
||||
} else {
|
||||
Operation *parentOp = op->getParentOp();
|
||||
assert(isa<RegionBranchOpInterface>(parentOp) &&
|
||||
"expected parent op to implement `RegionBranchOpInterface`");
|
||||
if (parentOp->getNumResults() != 0) {
|
||||
// This mark value of type 1.c liveness as may live, because the region
|
||||
// branch operation has a return value, and the non-forwarded operand can
|
||||
// determine the region to jump to, it can thereby control the result of
|
||||
// the region branch operation.
|
||||
// Therefore, if the result value is live, we conservatively consider the
|
||||
// non-forwarded operand of the region branch operation with result may
|
||||
// live and record all result.
|
||||
for (Value result : parentOp->getResults()) {
|
||||
if (getLatticeElement(result)->isLive) {
|
||||
mayLive = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// When the op is a `RegionBranchTerminatorOpInterface`, like an
|
||||
// `scf.condition` op or return-like, like an `scf.yield` op, its branch
|
||||
// operand controls the flow into this op's parent's (which is a
|
||||
// `RegionBranchOpInterface`'s) regions.
|
||||
Operation *parentOp = op->getParentOp();
|
||||
assert(isa<RegionBranchOpInterface>(parentOp) &&
|
||||
"expected parent op to implement `RegionBranchOpInterface`");
|
||||
for (Region ®ion : parentOp->getRegions()) {
|
||||
for (Block &block : region)
|
||||
blocks.push_back(&block);
|
||||
}
|
||||
}
|
||||
bool foundMemoryEffectingOp = false;
|
||||
}
|
||||
for (Block *block : blocks) {
|
||||
if (foundMemoryEffectingOp)
|
||||
if (mayLive)
|
||||
break;
|
||||
for (Operation &nestedOp : *block) {
|
||||
if (!isMemoryEffectFree(&nestedOp)) {
|
||||
Liveness *operandLiveness = getLatticeElement(operand.get());
|
||||
propagateIfChanged(operandLiveness, operandLiveness->markLive());
|
||||
foundMemoryEffectingOp = true;
|
||||
mayLive = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (mayLive) {
|
||||
Liveness *operandLiveness = getLatticeElement(operand.get());
|
||||
propagateIfChanged(operandLiveness, operandLiveness->markLive());
|
||||
}
|
||||
|
||||
// Now that we have checked for memory-effecting ops in the blocks of concern,
|
||||
// we will simply visit the op with this non-forwarded operand to potentially
|
||||
// mark it "live" due to type (1.a/3) liveness.
|
||||
@@ -191,8 +230,12 @@ void LivenessAnalysis::visitCallOperand(OpOperand &operand) {
|
||||
}
|
||||
|
||||
void LivenessAnalysis::setToExitState(Liveness *lattice) {
|
||||
if (lattice->isLive) {
|
||||
return;
|
||||
}
|
||||
// This marks values of type (2) liveness as "live".
|
||||
(void)lattice->markLive();
|
||||
propagateIfChanged(lattice, ChangeResult::Change);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -59,16 +59,49 @@ func.func @test_3_BranchOpInterface_type_1.b(%arg0: i32, %arg1: memref<i32>, %ar
|
||||
|
||||
// -----
|
||||
|
||||
// Positive test: Type(1.c) "is a non-forwarded branch operand and its branch
|
||||
// op could result in different result"
|
||||
// CHECK-LABEL: test_tag: cond_br:
|
||||
// CHECK-NEXT: operand #0: live
|
||||
// CHECK-NEXT: operand #1: live
|
||||
// CHECK-NEXT: operand #2: live
|
||||
func.func @test_branch_result_in_different_result_1.c(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : i1) -> tensor<f32> {
|
||||
cf.cond_br %arg2, ^bb1(%arg0 : tensor<f32>), ^bb2(%arg1 : tensor<f32>) {tag = "cond_br"}
|
||||
^bb1(%0 : tensor<f32>):
|
||||
cf.br ^bb3(%0 : tensor<f32>)
|
||||
^bb2(%1 : tensor<f32>):
|
||||
cf.br ^bb3(%1 : tensor<f32>)
|
||||
^bb3(%2 : tensor<f32>):
|
||||
return %2 : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Positive test: Type(1.c) "is a non-forwarded branch operand and its branch
|
||||
// op could result in different result"
|
||||
// CHECK-LABEL: test_tag: region_branch:
|
||||
// CHECK-NEXT: operand #0: live
|
||||
func.func @test_region_branch_result_in_different_result_1.c(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : i1) -> tensor<f32> {
|
||||
%0 = scf.if %arg2 -> tensor<f32> {
|
||||
scf.yield %arg0 : tensor<f32>
|
||||
} else {
|
||||
scf.yield %arg1 : tensor<f32>
|
||||
} {tag="region_branch"}
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func private @private(%arg0 : i32, %arg1 : i32) {
|
||||
func.return
|
||||
}
|
||||
|
||||
// Positive test: Type (1.c) "is a non-forwarded call operand"
|
||||
// Positive test: Type (1.d) "is a non-forwarded call operand"
|
||||
// CHECK-LABEL: test_tag: call
|
||||
// CHECK-LABEL: operand #0: not live
|
||||
// CHECK-LABEL: operand #1: not live
|
||||
// CHECK-LABEL: operand #2: live
|
||||
func.func @test_4_type_1.c(%arg0: i32, %arg1: i32, %device: i32, %m0: memref<i32>) {
|
||||
func.func @test_4_type_1.d(%arg0: i32, %arg1: i32, %device: i32, %m0: memref<i32>) {
|
||||
test.call_on_device @private(%arg0, %arg1), %device {tag = "call"} : (i32, i32, i32) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user