[SLP]Fix perfect diamond match with extractelements in scalars

Need to drop all previous estimations/vectorizations when a perfect
diamond match is found. This improves both cost estimation and code
emission.
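
For illustration (a sketch based on the reduce_and4 tests below; value
names are invented): when the gathered scalars are extractelements that
together cover whole input vectors, the gather node perfectly matches
already-vectorized entries, so a single shuffle of the source vectors can
replace rebuilding the vector lane by lane:

  ; %g reads exactly the lanes of %v1 and %v2, a perfect match for the
  ; existing vectors, so one shuffle replaces the per-lane inserts
  %e0 = extractelement <2 x i32> %v1, i64 0
  %e1 = extractelement <2 x i32> %v1, i64 1
  %e2 = extractelement <2 x i32> %v2, i64 0
  %e3 = extractelement <2 x i32> %v2, i64 1
  %g = shufflevector <2 x i32> %v1, <2 x i32> %v2, <4 x i32> <i32 0, i32 1, i32 2, i32 3>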
Also, need to adjust the getScalarizationOverhead cost for a non-poison
input vector. Currently, it does not allow correct estimation of this
case, so instead use a conservative element-by-element insertelement cost
for each unique scalar.
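
As a sketch of what the conservative model pays for (types and value
names are illustrative): inserting scalars into a non-poison vector must
keep the existing lanes alive, so each demanded element is costed as its
own insertelement rather than as one combined scalarization estimate:

  ; each insert contributes one getVectorInstrCost(InsertElement, ...) to
  ; the gather cost; %acc stands for the non-poison input vector
  %t0 = insertelement <4 x i32> %acc, i32 %a, i32 2
  %t1 = insertelement <4 x i32> %t0, i32 %b, i32 3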

Reviewers: RKSimon, hiraditya

Reviewed By: RKSimon

Pull Request: https://github.com/llvm/llvm-project/pull/132466
Author: Alexey Bataev
Date: 2025-03-24 09:29:18 -04:00 (committed via GitHub)
Parent: 03d8529d01
Commit: ad9909dd73
3 changed files with 69 additions and 62 deletions

@@ -5310,12 +5310,11 @@ getShuffleCost(const TargetTransformInfo &TTI, TTI::ShuffleKind Kind,
 /// This is similar to TargetTransformInfo::getScalarizationOverhead, but if
 /// ScalarTy is a FixedVectorType, a vector will be inserted or extracted
 /// instead of a scalar.
-static InstructionCost getScalarizationOverhead(const TargetTransformInfo &TTI,
-                                                Type *ScalarTy, VectorType *Ty,
-                                                const APInt &DemandedElts,
-                                                bool Insert, bool Extract,
-                                                TTI::TargetCostKind CostKind,
-                                                ArrayRef<Value *> VL = {}) {
+static InstructionCost
+getScalarizationOverhead(const TargetTransformInfo &TTI, Type *ScalarTy,
+                         VectorType *Ty, const APInt &DemandedElts, bool Insert,
+                         bool Extract, TTI::TargetCostKind CostKind,
+                         bool ForPoisonSrc = true, ArrayRef<Value *> VL = {}) {
   assert(!isa<ScalableVectorType>(Ty) &&
          "ScalableVectorType is not supported.");
   assert(getNumElements(ScalarTy) * DemandedElts.getBitWidth() ==
@@ -5339,8 +5338,26 @@ static InstructionCost getScalarizationOverhead(const TargetTransformInfo &TTI,
     }
     return Cost;
   }
-  return TTI.getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
-                                      CostKind, VL);
+  APInt NewDemandedElts = DemandedElts;
+  InstructionCost Cost = 0;
+  if (!ForPoisonSrc && Insert) {
+    // Handle insert into non-poison vector.
+    // TODO: Need to teach getScalarizationOverhead about insert elements into
+    // non-poison input vector to better handle such cases. Currently, it is
+    // very conservative and may "pessimize" the vectorization.
+    for (unsigned I : seq(DemandedElts.getBitWidth())) {
+      if (!DemandedElts[I])
+        continue;
+      Cost += TTI.getVectorInstrCost(Instruction::InsertElement, Ty, CostKind,
+                                     I, Constant::getNullValue(Ty),
+                                     VL.empty() ? nullptr : VL[I]);
+    }
+    NewDemandedElts.clearAllBits();
+  } else if (!NewDemandedElts.isZero()) {
+    Cost += TTI.getScalarizationOverhead(Ty, NewDemandedElts, Insert, Extract,
+                                         CostKind, VL);
+  }
+  return Cost;
 }
 
 /// Correctly creates insert_subvector, checking that the index is multiple of
@@ -11684,6 +11701,15 @@ public:
     // No need to delay the cost estimation during analysis.
     return std::nullopt;
   }
+  /// Reset the builder to handle perfect diamond match.
+  void resetForSameNode() {
+    IsFinalized = false;
+    CommonMask.clear();
+    InVectors.clear();
+    Cost = 0;
+    VectorizedVals.clear();
+    SameNodesEstimated = true;
+  }
   void add(const TreeEntry &E1, const TreeEntry &E2, ArrayRef<int> Mask) {
     if (&E1 == &E2) {
       assert(all_of(Mask,
@@ -14890,15 +14916,18 @@ InstructionCost BoUpSLP::getGatherCost(ArrayRef<Value *> VL, bool ForPoisonSrc,
     ShuffledElements.setBit(I);
     ShuffleMask[I] = Res.first->second;
   }
-  if (!DemandedElements.isZero())
-    Cost += getScalarizationOverhead(*TTI, ScalarTy, VecTy, DemandedElements,
-                                     /*Insert=*/true,
-                                     /*Extract=*/false, CostKind, VL);
-  if (ForPoisonSrc)
-    Cost = getScalarizationOverhead(*TTI, ScalarTy, VecTy,
-                                    /*DemandedElts*/ ~ShuffledElements,
-                                    /*Insert*/ true,
-                                    /*Extract*/ false, CostKind, VL);
+  if (ForPoisonSrc) {
+    Cost = getScalarizationOverhead(*TTI, ScalarTy, VecTy,
+                                    /*DemandedElts*/ ~ShuffledElements,
+                                    /*Insert*/ true,
+                                    /*Extract*/ false, CostKind,
+                                    /*ForPoisonSrc=*/true, VL);
+  } else if (!DemandedElements.isZero()) {
+    Cost += getScalarizationOverhead(*TTI, ScalarTy, VecTy, DemandedElements,
+                                     /*Insert=*/true,
+                                     /*Extract=*/false, CostKind,
+                                     /*ForPoisonSrc=*/false, VL);
+  }
   if (DuplicateNonConst)
     Cost += ::getShuffleCost(*TTI, TargetTransformInfo::SK_PermuteSingleSrc,
                              VecTy, ShuffleMask);
@@ -15556,6 +15585,12 @@ public:
         PoisonValue::get(PointerType::getUnqual(ScalarTy->getContext())),
         MaybeAlign());
   }
+  /// Reset the builder to handle perfect diamond match.
+  void resetForSameNode() {
+    IsFinalized = false;
+    CommonMask.clear();
+    InVectors.clear();
+  }
   /// Adds 2 input vectors (in form of tree entries) and the mask for their
   /// shuffling.
   void add(const TreeEntry &E1, const TreeEntry &E2, ArrayRef<int> Mask) {
@@ -16111,6 +16146,9 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Type *ScalarTy,
         Mask[I] = FrontTE->findLaneForValue(V);
       }
     }
+    // Reset the builder(s) to correctly handle perfect diamond matched
+    // nodes.
+    ShuffleBuilder.resetForSameNode();
     ShuffleBuilder.add(*FrontTE, Mask);
     // Full matched entry found, no need to insert subvectors.
     Res = ShuffleBuilder.finalize(E->getCommonMask(), {}, {});

@@ -10,18 +10,15 @@ define <4 x double> @test(ptr %ia, ptr %ib, ptr %ic, ptr %id, ptr %ie, ptr %x) {
 ; CHECK-NEXT:    [[I4275:%.*]] = load double, ptr [[ID]], align 8
 ; CHECK-NEXT:    [[I4277:%.*]] = load double, ptr [[IE]], align 8
 ; CHECK-NEXT:    [[I4326:%.*]] = load <4 x double>, ptr [[X]], align 8
-; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x double> [[I4326]], <4 x double> poison, <2 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <2 x double> poison, double [[I4238]], i32 0
-; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <2 x double> [[TMP2]], double [[I4252]], i32 1
-; CHECK-NEXT:    [[TMP4:%.*]] = fmul fast <2 x double> [[TMP1]], [[TMP3]]
-; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <2 x double> [[TMP1]], double [[I4275]], i32 1
-; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <2 x double> poison, double [[I4264]], i32 0
-; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <2 x double> [[TMP6]], double [[I4277]], i32 1
-; CHECK-NEXT:    [[TMP8:%.*]] = fmul fast <2 x double> [[TMP5]], [[TMP7]]
-; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <2 x double> [[TMP4]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <2 x double> [[TMP8]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
-; CHECK-NEXT:    [[I44281:%.*]] = shufflevector <4 x double> [[TMP9]], <4 x double> [[TMP10]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
-; CHECK-NEXT:    ret <4 x double> [[I44281]]
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x double> [[I4326]], <4 x double> poison, <2 x i32> <i32 0, i32 poison>
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <2 x double> [[TMP1]], double [[I4275]], i32 1
+; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <2 x double> [[TMP2]], <2 x double> poison, <4 x i32> <i32 0, i32 0, i32 0, i32 1>
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <4 x double> poison, double [[I4238]], i32 0
+; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <4 x double> [[TMP4]], double [[I4252]], i32 1
+; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <4 x double> [[TMP5]], double [[I4264]], i32 2
+; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <4 x double> [[TMP6]], double [[I4277]], i32 3
+; CHECK-NEXT:    [[TMP8:%.*]] = fmul fast <4 x double> [[TMP3]], [[TMP7]]
+; CHECK-NEXT:    ret <4 x double> [[TMP8]]
 ;
   %i4238 = load double, ptr %ia, align 8
   %i4252 = load double, ptr %ib, align 8

@@ -49,24 +49,10 @@ define i32 @reduce_and4(i32 %acc, <4 x i32> %v1, <4 x i32> %v2, <4 x i32> %v3, <
 ;
 ; AVX512-LABEL: @reduce_and4(
 ; AVX512-NEXT:  entry:
-; AVX512-NEXT:    [[VECEXT:%.*]] = extractelement <4 x i32> [[V1:%.*]], i64 0
-; AVX512-NEXT:    [[VECEXT1:%.*]] = extractelement <4 x i32> [[V1]], i64 1
-; AVX512-NEXT:    [[VECEXT2:%.*]] = extractelement <4 x i32> [[V1]], i64 2
-; AVX512-NEXT:    [[VECEXT4:%.*]] = extractelement <4 x i32> [[V1]], i64 3
-; AVX512-NEXT:    [[VECEXT7:%.*]] = extractelement <4 x i32> [[V2:%.*]], i64 0
-; AVX512-NEXT:    [[VECEXT8:%.*]] = extractelement <4 x i32> [[V2]], i64 1
-; AVX512-NEXT:    [[VECEXT10:%.*]] = extractelement <4 x i32> [[V2]], i64 2
-; AVX512-NEXT:    [[VECEXT12:%.*]] = extractelement <4 x i32> [[V2]], i64 3
-; AVX512-NEXT:    [[TMP0:%.*]] = shufflevector <4 x i32> [[V4:%.*]], <4 x i32> [[V3:%.*]], <16 x i32> <i32 1, i32 0, i32 2, i32 3, i32 5, i32 4, i32 6, i32 7, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; AVX512-NEXT:    [[TMP1:%.*]] = insertelement <16 x i32> [[TMP0]], i32 [[VECEXT8]], i32 8
-; AVX512-NEXT:    [[TMP2:%.*]] = insertelement <16 x i32> [[TMP1]], i32 [[VECEXT7]], i32 9
-; AVX512-NEXT:    [[TMP3:%.*]] = insertelement <16 x i32> [[TMP2]], i32 [[VECEXT10]], i32 10
-; AVX512-NEXT:    [[TMP4:%.*]] = insertelement <16 x i32> [[TMP3]], i32 [[VECEXT12]], i32 11
-; AVX512-NEXT:    [[TMP5:%.*]] = insertelement <16 x i32> [[TMP4]], i32 [[VECEXT1]], i32 12
-; AVX512-NEXT:    [[TMP6:%.*]] = insertelement <16 x i32> [[TMP5]], i32 [[VECEXT]], i32 13
-; AVX512-NEXT:    [[TMP7:%.*]] = insertelement <16 x i32> [[TMP6]], i32 [[VECEXT2]], i32 14
-; AVX512-NEXT:    [[TMP8:%.*]] = insertelement <16 x i32> [[TMP7]], i32 [[VECEXT4]], i32 15
-; AVX512-NEXT:    [[OP_RDX:%.*]] = call i32 @llvm.vector.reduce.and.v16i32(<16 x i32> [[TMP8]])
+; AVX512-NEXT:    [[TMP0:%.*]] = shufflevector <4 x i32> [[V4:%.*]], <4 x i32> [[V3:%.*]], <8 x i32> <i32 1, i32 0, i32 2, i32 3, i32 5, i32 4, i32 6, i32 7>
+; AVX512-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i32> [[V2:%.*]], <4 x i32> [[V1:%.*]], <8 x i32> <i32 1, i32 0, i32 2, i32 3, i32 5, i32 4, i32 6, i32 7>
+; AVX512-NEXT:    [[RDX_OP:%.*]] = and <8 x i32> [[TMP0]], [[TMP1]]
+; AVX512-NEXT:    [[OP_RDX:%.*]] = call i32 @llvm.vector.reduce.and.v8i32(<8 x i32> [[RDX_OP]])
 ; AVX512-NEXT:    [[OP_RDX1:%.*]] = and i32 [[OP_RDX]], [[ACC:%.*]]
 ; AVX512-NEXT:    ret i32 [[OP_RDX1]]
 ;
@@ -144,24 +130,10 @@ define i32 @reduce_and4_transpose(i32 %acc, <4 x i32> %v1, <4 x i32> %v2, <4 x i
 ; AVX2-NEXT:    ret i32 [[OP_RDX]]
 ;
 ; AVX512-LABEL: @reduce_and4_transpose(
-; AVX512-NEXT:    [[VECEXT:%.*]] = extractelement <4 x i32> [[V1:%.*]], i64 0
-; AVX512-NEXT:    [[VECEXT1:%.*]] = extractelement <4 x i32> [[V2:%.*]], i64 0
-; AVX512-NEXT:    [[VECEXT7:%.*]] = extractelement <4 x i32> [[V1]], i64 1
-; AVX512-NEXT:    [[VECEXT8:%.*]] = extractelement <4 x i32> [[V2]], i64 1
-; AVX512-NEXT:    [[VECEXT15:%.*]] = extractelement <4 x i32> [[V1]], i64 2
-; AVX512-NEXT:    [[VECEXT16:%.*]] = extractelement <4 x i32> [[V2]], i64 2
-; AVX512-NEXT:    [[VECEXT23:%.*]] = extractelement <4 x i32> [[V1]], i64 3
-; AVX512-NEXT:    [[VECEXT24:%.*]] = extractelement <4 x i32> [[V2]], i64 3
-; AVX512-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i32> [[V4:%.*]], <4 x i32> [[V3:%.*]], <16 x i32> <i32 3, i32 2, i32 1, i32 0, i32 7, i32 6, i32 5, i32 4, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; AVX512-NEXT:    [[TMP2:%.*]] = insertelement <16 x i32> [[TMP1]], i32 [[VECEXT24]], i32 8
-; AVX512-NEXT:    [[TMP3:%.*]] = insertelement <16 x i32> [[TMP2]], i32 [[VECEXT16]], i32 9
-; AVX512-NEXT:    [[TMP4:%.*]] = insertelement <16 x i32> [[TMP3]], i32 [[VECEXT8]], i32 10
-; AVX512-NEXT:    [[TMP5:%.*]] = insertelement <16 x i32> [[TMP4]], i32 [[VECEXT1]], i32 11
-; AVX512-NEXT:    [[TMP6:%.*]] = insertelement <16 x i32> [[TMP5]], i32 [[VECEXT23]], i32 12
-; AVX512-NEXT:    [[TMP7:%.*]] = insertelement <16 x i32> [[TMP6]], i32 [[VECEXT15]], i32 13
-; AVX512-NEXT:    [[TMP8:%.*]] = insertelement <16 x i32> [[TMP7]], i32 [[VECEXT7]], i32 14
-; AVX512-NEXT:    [[TMP9:%.*]] = insertelement <16 x i32> [[TMP8]], i32 [[VECEXT]], i32 15
-; AVX512-NEXT:    [[OP_RDX:%.*]] = call i32 @llvm.vector.reduce.and.v16i32(<16 x i32> [[TMP9]])
+; AVX512-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i32> [[V4:%.*]], <4 x i32> [[V3:%.*]], <8 x i32> <i32 3, i32 2, i32 1, i32 0, i32 7, i32 6, i32 5, i32 4>
+; AVX512-NEXT:    [[TMP2:%.*]] = shufflevector <4 x i32> [[V2:%.*]], <4 x i32> [[V1:%.*]], <8 x i32> <i32 3, i32 2, i32 1, i32 0, i32 7, i32 6, i32 5, i32 4>
+; AVX512-NEXT:    [[RDX_OP:%.*]] = and <8 x i32> [[TMP1]], [[TMP2]]
+; AVX512-NEXT:    [[OP_RDX:%.*]] = call i32 @llvm.vector.reduce.and.v8i32(<8 x i32> [[RDX_OP]])
 ; AVX512-NEXT:    [[OP_RDX1:%.*]] = and i32 [[OP_RDX]], [[ACC:%.*]]
 ; AVX512-NEXT:    ret i32 [[OP_RDX1]]
 ;