diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index b01c8b02ec66..67df7a8af098 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -7327,9 +7327,11 @@ DenseMap LoopVectorizationPlanner::executePlan( OrigLoop->getHeader()->getContext()); VPlanTransforms::runPass(VPlanTransforms::replicateByVF, BestVPlan, BestVF); VPlanTransforms::runPass(VPlanTransforms::materializeBroadcasts, BestVPlan); - if (hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) + if (hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) { + std::optional VScale = CM.getVScaleForTuning(); VPlanTransforms::runPass(VPlanTransforms::addBranchWeightToMiddleTerminator, - BestVPlan, BestVF); + BestVPlan, BestVF, VScale); + } VPlanTransforms::optimizeForVFAndUF(BestVPlan, BestVF, BestUF, PSE); VPlanTransforms::simplifyRecipes(BestVPlan, *Legal->getWidestInductionType()); VPlanTransforms::narrowInterleaveGroups( diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index 3dfd625f83a6..8d4a73c74446 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -3330,8 +3330,8 @@ void VPlanTransforms::narrowInterleaveGroups(VPlan &Plan, ElementCount VF, /// Add branch weight metadata, if the \p Plan's middle block is terminated by a /// BranchOnCond recipe. -void VPlanTransforms::addBranchWeightToMiddleTerminator(VPlan &Plan, - ElementCount VF) { +void VPlanTransforms::addBranchWeightToMiddleTerminator( + VPlan &Plan, ElementCount VF, std::optional VScaleForTuning) { VPBasicBlock *MiddleVPBB = Plan.getMiddleBlock(); auto *MiddleTerm = dyn_cast_or_null(MiddleVPBB->getTerminator()); @@ -3343,6 +3343,8 @@ void VPlanTransforms::addBranchWeightToMiddleTerminator(VPlan &Plan, "must have a BranchOnCond"); // Assume that `TripCount % VectorStep ` is equally distributed. unsigned VectorStep = Plan.getUF() * VF.getKnownMinValue(); + if (VF.isScalable() && VScaleForTuning.has_value()) + VectorStep *= *VScaleForTuning; assert(VectorStep > 0 && "trip count should not be zero"); MDBuilder MDB(Plan.getScalarHeader()->getIRBasicBlock()->getContext()); MDNode *BranchWeights = diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h index 40885cd52a12..8d2eded45da2 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h @@ -238,7 +238,9 @@ struct VPlanTransforms { /// Add branch weight metadata, if the \p Plan's middle block is terminated by /// a BranchOnCond recipe. - static void addBranchWeightToMiddleTerminator(VPlan &Plan, ElementCount VF); + static void + addBranchWeightToMiddleTerminator(VPlan &Plan, ElementCount VF, + std::optional VScaleForTuning); }; } // namespace llvm diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/check-prof-info.ll b/llvm/test/Transforms/LoopVectorize/AArch64/check-prof-info.ll index 9435c544fc81..1f619898ea78 100644 --- a/llvm/test/Transforms/LoopVectorize/AArch64/check-prof-info.ll +++ b/llvm/test/Transforms/LoopVectorize/AArch64/check-prof-info.ll @@ -92,7 +92,7 @@ for.cond.cleanup: ; preds = %for.body ; CHECK-V1-IC1: [[LOOP1]] = distinct !{[[LOOP1]], [[META2:![0-9]+]], [[META3:![0-9]+]]} ; CHECK-V1-IC1: [[META2]] = !{!"llvm.loop.isvectorized", i32 1} ; CHECK-V1-IC1: [[META3]] = !{!"llvm.loop.unroll.runtime.disable"} -; CHECK-V1-IC1: [[PROF4]] = !{!"branch_weights", i32 1, i32 3} +; CHECK-V1-IC1: [[PROF4]] = !{!"branch_weights", i32 1, i32 7} ; CHECK-V1-IC1: [[PROF5]] = !{!"branch_weights", i32 0, i32 0} ; CHECK-V1-IC1: [[LOOP6]] = distinct !{[[LOOP6]], [[META3]], [[META2]]} ;.