From d40bab359c408b0084cd3c115213205050401a9e Mon Sep 17 00:00:00 2001 From: donald chen Date: Wed, 2 Apr 2025 11:56:13 +0800 Subject: [PATCH] [mlir][liveness] fix bugs in liveness analysis (#133416) This patch fixes the following bugs: - In SparseBackwardAnalysis, the setToExitState function should propagate changes if it modifies the lattice. Previously, this issue was masked because multi-block scenarios were not tested, and the traversal order of backward data flow analysis starts from the end of the program. - The method in liveness analysis for determining whether the non-forwarded operand in branch/region branch operations is live is incorrect, which may cause originally live variables to be marked as not live. --- .../mlir/Analysis/DataFlow/SparseAnalysis.h | 6 +- .../Analysis/DataFlow/LivenessAnalysis.cpp | 93 ++++++++++++++----- .../DataFlow/test-liveness-analysis.mlir | 37 +++++++- 3 files changed, 107 insertions(+), 29 deletions(-) diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h index b9cb549a0e43..1b2c67917610 100644 --- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h @@ -413,10 +413,12 @@ protected: // Visit operands on call instructions that are not forwarded. virtual void visitCallOperand(OpOperand &operand) = 0; - /// Set the given lattice element(s) at control flow exit point(s). + /// Set the given lattice element(s) at control flow exit point(s) and + /// propagate the update if it chaned. virtual void setToExitState(AbstractSparseLattice *lattice) = 0; - /// Set the given lattice element(s) at control flow exit point(s). + /// Set the given lattice element(s) at control flow exit point(s) and + /// propagate the update if it chaned. void setAllToExitStates(ArrayRef lattices); /// Get the lattice element for a value. diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp index 9fb4d9df2530..c12149a1a024 100644 --- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp @@ -59,7 +59,9 @@ ChangeResult Liveness::meet(const AbstractSparseLattice &other) { /// (1.a) is an operand of an op with memory effects OR /// (1.b) is a non-forwarded branch operand and its branch op could take the /// control to a block that has an op with memory effects OR -/// (1.c) is a non-forwarded call operand. +/// (1.c) is a non-forwarded branch operand and its branch op could result +/// in different live result OR +/// (1.d) is a non-forwarded call operand. /// /// A value `A` is said to be "used to compute" value `B` iff `B` cannot be /// computed in the absence of `A`. Thus, in this implementation, we say that @@ -106,51 +108,88 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) { // the forwarded branch operands or the non-branch operands. Thus they need // to be handled separately. This is where we handle them. - // This marks values of type (1.b) liveness as "live". A non-forwarded + // This marks values of type (1.b/1.c) liveness as "live". A non-forwarded // branch operand will be live if a block where its op could take the control - // has an op with memory effects. + // has an op with memory effects or could result in different results. // Populating such blocks in `blocks`. + bool mayLive = false; SmallVector blocks; if (isa(op)) { - // When the op is a `RegionBranchOpInterface`, like an `scf.for` or an - // `scf.index_switch` op, its branch operand controls the flow into this - // op's regions. - for (Region ®ion : op->getRegions()) { - for (Block &block : region) - blocks.push_back(&block); + if (op->getNumResults() != 0) { + // This mark value of type 1.c liveness as may live, because the region + // branch operation has a return value, and the non-forwarded operand can + // determine the region to jump to, it can thereby control the result of + // the region branch operation. + // Therefore, if the result value is live, we conservatively consider the + // non-forwarded operand of the region branch operation with result may + // live and record all result. + for (Value result : op->getResults()) { + if (getLatticeElement(result)->isLive) { + mayLive = true; + break; + } + } + } else { + // When the op is a `RegionBranchOpInterface`, like an `scf.for` or an + // `scf.index_switch` op, its branch operand controls the flow into this + // op's regions. + for (Region ®ion : op->getRegions()) { + for (Block &block : region) + blocks.push_back(&block); + } } } else if (isa(op)) { - // When the op is a `BranchOpInterface`, like a `cf.cond_br` or a - // `cf.switch` op, its branch operand controls the flow into this op's - // successors. - blocks = op->getSuccessors(); + // We cannot track all successor blocks of the branch operation(More + // specifically, it's the successor's successor). Additionally, different + // blocks might also lead to the different block argument described in 1.c. + // Therefore, we conservatively consider the non-forwarded operand of the + // branch operation may live. + mayLive = true; } else { - // When the op is a `RegionBranchTerminatorOpInterface`, like an - // `scf.condition` op or return-like, like an `scf.yield` op, its branch - // operand controls the flow into this op's parent's (which is a - // `RegionBranchOpInterface`'s) regions. Operation *parentOp = op->getParentOp(); assert(isa(parentOp) && "expected parent op to implement `RegionBranchOpInterface`"); - for (Region ®ion : parentOp->getRegions()) { - for (Block &block : region) - blocks.push_back(&block); + if (parentOp->getNumResults() != 0) { + // This mark value of type 1.c liveness as may live, because the region + // branch operation has a return value, and the non-forwarded operand can + // determine the region to jump to, it can thereby control the result of + // the region branch operation. + // Therefore, if the result value is live, we conservatively consider the + // non-forwarded operand of the region branch operation with result may + // live and record all result. + for (Value result : parentOp->getResults()) { + if (getLatticeElement(result)->isLive) { + mayLive = true; + break; + } + } + } else { + // When the op is a `RegionBranchTerminatorOpInterface`, like an + // `scf.condition` op or return-like, like an `scf.yield` op, its branch + // operand controls the flow into this op's parent's (which is a + // `RegionBranchOpInterface`'s) regions. + for (Region ®ion : parentOp->getRegions()) { + for (Block &block : region) + blocks.push_back(&block); + } } } - bool foundMemoryEffectingOp = false; for (Block *block : blocks) { - if (foundMemoryEffectingOp) + if (mayLive) break; for (Operation &nestedOp : *block) { if (!isMemoryEffectFree(&nestedOp)) { - Liveness *operandLiveness = getLatticeElement(operand.get()); - propagateIfChanged(operandLiveness, operandLiveness->markLive()); - foundMemoryEffectingOp = true; + mayLive = true; break; } } } + if (mayLive) { + Liveness *operandLiveness = getLatticeElement(operand.get()); + propagateIfChanged(operandLiveness, operandLiveness->markLive()); + } + // Now that we have checked for memory-effecting ops in the blocks of concern, // we will simply visit the op with this non-forwarded operand to potentially // mark it "live" due to type (1.a/3) liveness. @@ -191,8 +230,12 @@ void LivenessAnalysis::visitCallOperand(OpOperand &operand) { } void LivenessAnalysis::setToExitState(Liveness *lattice) { + if (lattice->isLive) { + return; + } // This marks values of type (2) liveness as "live". (void)lattice->markLive(); + propagateIfChanged(lattice, ChangeResult::Change); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir index b6aed1c0b054..a89a0f4084e9 100644 --- a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir +++ b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir @@ -59,16 +59,49 @@ func.func @test_3_BranchOpInterface_type_1.b(%arg0: i32, %arg1: memref, %ar // ----- +// Positive test: Type(1.c) "is a non-forwarded branch operand and its branch +// op could result in different result" +// CHECK-LABEL: test_tag: cond_br: +// CHECK-NEXT: operand #0: live +// CHECK-NEXT: operand #1: live +// CHECK-NEXT: operand #2: live +func.func @test_branch_result_in_different_result_1.c(%arg0 : tensor, %arg1 : tensor, %arg2 : i1) -> tensor { + cf.cond_br %arg2, ^bb1(%arg0 : tensor), ^bb2(%arg1 : tensor) {tag = "cond_br"} +^bb1(%0 : tensor): + cf.br ^bb3(%0 : tensor) +^bb2(%1 : tensor): + cf.br ^bb3(%1 : tensor) +^bb3(%2 : tensor): + return %2 : tensor +} + +// ----- + +// Positive test: Type(1.c) "is a non-forwarded branch operand and its branch +// op could result in different result" +// CHECK-LABEL: test_tag: region_branch: +// CHECK-NEXT: operand #0: live +func.func @test_region_branch_result_in_different_result_1.c(%arg0 : tensor, %arg1 : tensor, %arg2 : i1) -> tensor { + %0 = scf.if %arg2 -> tensor { + scf.yield %arg0 : tensor + } else { + scf.yield %arg1 : tensor + } {tag="region_branch"} + return %0 : tensor +} + +// ----- + func.func private @private(%arg0 : i32, %arg1 : i32) { func.return } -// Positive test: Type (1.c) "is a non-forwarded call operand" +// Positive test: Type (1.d) "is a non-forwarded call operand" // CHECK-LABEL: test_tag: call // CHECK-LABEL: operand #0: not live // CHECK-LABEL: operand #1: not live // CHECK-LABEL: operand #2: live -func.func @test_4_type_1.c(%arg0: i32, %arg1: i32, %device: i32, %m0: memref) { +func.func @test_4_type_1.d(%arg0: i32, %arg1: i32, %device: i32, %m0: memref) { test.call_on_device @private(%arg0, %arg1), %device {tag = "call"} : (i32, i32, i32) -> () return }