[LoopVectorize] improve IR fast-math-flags propagation in reductions
This is another step (see D95452) towards correcting fast-math-flags bugs in vector reductions. There are multiple bugs visible in the test diffs, and this is still not working as it should. We still use function attributes (rather than FMF) to drive part of the logic, but we are not checking for the correct FP function attributes. Note that FMF may not be propagated optimally on selects (example in https://llvm.org/PR35607 ). That's why I'm proposing to union the FMF of a fcmp+select pair and avoid regressions on existing vectorizer tests. Differential Revision: https://reviews.llvm.org/D95690
This commit is contained in:
@@ -239,6 +239,9 @@ public:
|
||||
void operator&=(const FastMathFlags &OtherFlags) {
|
||||
Flags &= OtherFlags.Flags;
|
||||
}
|
||||
void operator|=(const FastMathFlags &OtherFlags) {
|
||||
Flags |= OtherFlags.Flags;
|
||||
}
|
||||
};
|
||||
|
||||
/// Utility class for floating point operations which can have
|
||||
|
||||
@@ -356,6 +356,7 @@ bool canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,
|
||||
OptimizationRemarkEmitter *ORE = nullptr);
|
||||
|
||||
/// Returns a Min/Max operation corresponding to MinMaxRecurrenceKind.
|
||||
/// The Builder's fast-math-flags must be set to propagate the expected values.
|
||||
Value *createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
|
||||
Value *Right);
|
||||
|
||||
|
||||
@@ -302,8 +302,18 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind,
|
||||
if (!ReduxDesc.isRecurrence())
|
||||
return false;
|
||||
// FIXME: FMF is allowed on phi, but propagation is not handled correctly.
|
||||
if (isa<FPMathOperator>(ReduxDesc.getPatternInst()) && !IsAPhi)
|
||||
FMF &= ReduxDesc.getPatternInst()->getFastMathFlags();
|
||||
if (isa<FPMathOperator>(ReduxDesc.getPatternInst()) && !IsAPhi) {
|
||||
FastMathFlags CurFMF = ReduxDesc.getPatternInst()->getFastMathFlags();
|
||||
if (auto *Sel = dyn_cast<SelectInst>(ReduxDesc.getPatternInst())) {
|
||||
// Accept FMF on either fcmp or select of a min/max idiom.
|
||||
// TODO: This is a hack to work-around the fact that FMF may not be
|
||||
// assigned/propagated correctly. If that problem is fixed or we
|
||||
// standardize on fmin/fmax via intrinsics, this can be removed.
|
||||
assert(isa<FCmpInst>(Sel->getCondition()) && "Expected fcmp min/max");
|
||||
CurFMF |= cast<FCmpInst>(Sel->getCondition())->getFastMathFlags();
|
||||
}
|
||||
FMF &= CurFMF;
|
||||
}
|
||||
// Update this reduction kind if we matched a new instruction.
|
||||
// TODO: Can we eliminate the need for a 2nd InstDesc by keeping 'Kind'
|
||||
// state accurate while processing the worklist?
|
||||
|
||||
@@ -944,12 +944,6 @@ Value *llvm::createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
|
||||
break;
|
||||
}
|
||||
|
||||
// We only match FP sequences that are 'fast', so we can unconditionally
|
||||
// set it on any generated instructions.
|
||||
IRBuilderBase::FastMathFlagGuard FMFG(Builder);
|
||||
FastMathFlags FMF;
|
||||
FMF.setFast();
|
||||
Builder.setFastMathFlags(FMF);
|
||||
Value *Cmp = Builder.CreateCmp(Pred, Left, Right, "rdx.minmax.cmp");
|
||||
Value *Select = Builder.CreateSelect(Cmp, Left, Right, "rdx.minmax.select");
|
||||
return Select;
|
||||
|
||||
@@ -403,12 +403,6 @@ static Value *addFastMathFlag(Value *V) {
|
||||
return V;
|
||||
}
|
||||
|
||||
static Value *addFastMathFlag(Value *V, FastMathFlags FMF) {
|
||||
if (isa<FPMathOperator>(V))
|
||||
cast<Instruction>(V)->setFastMathFlags(FMF);
|
||||
return V;
|
||||
}
|
||||
|
||||
/// A helper function that returns an integer or floating-point constant with
|
||||
/// value C.
|
||||
static Constant *getSignedIntOrFpConstant(Type *Ty, int64_t C) {
|
||||
@@ -4301,16 +4295,19 @@ void InnerLoopVectorizer::fixReduction(PHINode *Phi) {
|
||||
// terminate on this line. This is the easiest way to ensure we don't
|
||||
// accidentally cause an extra step back into the loop while debugging.
|
||||
setDebugLocFromInst(Builder, LoopMiddleBlock->getTerminator());
|
||||
for (unsigned Part = 1; Part < UF; ++Part) {
|
||||
Value *RdxPart = VectorLoopValueMap.getVectorValue(LoopExitInst, Part);
|
||||
if (Op != Instruction::ICmp && Op != Instruction::FCmp)
|
||||
// Floating point operations had to be 'fast' to enable the reduction.
|
||||
ReducedPartRdx = addFastMathFlag(
|
||||
Builder.CreateBinOp((Instruction::BinaryOps)Op, RdxPart,
|
||||
ReducedPartRdx, "bin.rdx"),
|
||||
RdxDesc.getFastMathFlags());
|
||||
else
|
||||
ReducedPartRdx = createMinMaxOp(Builder, RK, ReducedPartRdx, RdxPart);
|
||||
{
|
||||
// Floating-point operations should have some FMF to enable the reduction.
|
||||
IRBuilderBase::FastMathFlagGuard FMFG(Builder);
|
||||
Builder.setFastMathFlags(RdxDesc.getFastMathFlags());
|
||||
for (unsigned Part = 1; Part < UF; ++Part) {
|
||||
Value *RdxPart = VectorLoopValueMap.getVectorValue(LoopExitInst, Part);
|
||||
if (Op != Instruction::ICmp && Op != Instruction::FCmp) {
|
||||
ReducedPartRdx = Builder.CreateBinOp(
|
||||
(Instruction::BinaryOps)Op, RdxPart, ReducedPartRdx, "bin.rdx");
|
||||
} else {
|
||||
ReducedPartRdx = createMinMaxOp(Builder, RK, ReducedPartRdx, RdxPart);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create the reduction after the loop. Note that inloop reductions create the
|
||||
|
||||
@@ -262,7 +262,8 @@ loop.exit:
|
||||
ret float %sum.lcssa
|
||||
}
|
||||
|
||||
; FIXME: Some fcmp are 'nnan ninf', some are 'fast', but the reduction is sequential?
|
||||
; New instructions should have the same FMF as the original code.
|
||||
; Note that the select inherits FMF from its fcmp condition.
|
||||
|
||||
define float @PR35538(float* nocapture readonly %a, i32 %N) #0 {
|
||||
; CHECK-LABEL: @PR35538(
|
||||
@@ -299,9 +300,9 @@ define float @PR35538(float* nocapture readonly %a, i32 %N) #0 {
|
||||
; CHECK-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
|
||||
; CHECK-NEXT: br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], [[LOOP8:!llvm.loop !.*]]
|
||||
; CHECK: middle.block:
|
||||
; CHECK-NEXT: [[RDX_MINMAX_CMP:%.*]] = fcmp fast ogt <4 x float> [[TMP10]], [[TMP11]]
|
||||
; CHECK-NEXT: [[RDX_MINMAX_SELECT:%.*]] = select fast <4 x i1> [[RDX_MINMAX_CMP]], <4 x float> [[TMP10]], <4 x float> [[TMP11]]
|
||||
; CHECK-NEXT: [[TMP13:%.*]] = call float @llvm.vector.reduce.fmax.v4f32(<4 x float> [[RDX_MINMAX_SELECT]])
|
||||
; CHECK-NEXT: [[RDX_MINMAX_CMP:%.*]] = fcmp nnan ninf ogt <4 x float> [[TMP10]], [[TMP11]]
|
||||
; CHECK-NEXT: [[RDX_MINMAX_SELECT:%.*]] = select nnan ninf <4 x i1> [[RDX_MINMAX_CMP]], <4 x float> [[TMP10]], <4 x float> [[TMP11]]
|
||||
; CHECK-NEXT: [[TMP13:%.*]] = call nnan ninf float @llvm.vector.reduce.fmax.v4f32(<4 x float> [[RDX_MINMAX_SELECT]])
|
||||
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
|
||||
; CHECK-NEXT: br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
|
||||
; CHECK: scalar.ph:
|
||||
@@ -349,6 +350,8 @@ for.body:
|
||||
br i1 %exitcond, label %for.cond.cleanup, label %for.body
|
||||
}
|
||||
|
||||
; Same as above, but this time the select already has matching FMF with its condition.
|
||||
|
||||
define float @PR35538_more_FMF(float* nocapture readonly %a, i32 %N) #0 {
|
||||
; CHECK-LABEL: @PR35538_more_FMF(
|
||||
; CHECK-NEXT: entry:
|
||||
@@ -384,8 +387,8 @@ define float @PR35538_more_FMF(float* nocapture readonly %a, i32 %N) #0 {
|
||||
; CHECK-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
|
||||
; CHECK-NEXT: br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], [[LOOP10:!llvm.loop !.*]]
|
||||
; CHECK: middle.block:
|
||||
; CHECK-NEXT: [[RDX_MINMAX_CMP:%.*]] = fcmp fast ogt <4 x float> [[TMP10]], [[TMP11]]
|
||||
; CHECK-NEXT: [[RDX_MINMAX_SELECT:%.*]] = select fast <4 x i1> [[RDX_MINMAX_CMP]], <4 x float> [[TMP10]], <4 x float> [[TMP11]]
|
||||
; CHECK-NEXT: [[RDX_MINMAX_CMP:%.*]] = fcmp nnan ninf ogt <4 x float> [[TMP10]], [[TMP11]]
|
||||
; CHECK-NEXT: [[RDX_MINMAX_SELECT:%.*]] = select nnan ninf <4 x i1> [[RDX_MINMAX_CMP]], <4 x float> [[TMP10]], <4 x float> [[TMP11]]
|
||||
; CHECK-NEXT: [[TMP13:%.*]] = call nnan ninf float @llvm.vector.reduce.fmax.v4f32(<4 x float> [[RDX_MINMAX_SELECT]])
|
||||
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
|
||||
; CHECK-NEXT: br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
|
||||
|
||||
@@ -69,11 +69,11 @@ define float @minloopattr(float* nocapture readonly %arg) #0 {
|
||||
; CHECK-NEXT: br i1 [[TMP6]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], [[LOOP0:!llvm.loop !.*]]
|
||||
; CHECK: middle.block:
|
||||
; CHECK-NEXT: [[RDX_SHUF:%.*]] = shufflevector <4 x float> [[TMP5]], <4 x float> poison, <4 x i32> <i32 2, i32 3, i32 undef, i32 undef>
|
||||
; CHECK-NEXT: [[RDX_MINMAX_CMP:%.*]] = fcmp fast olt <4 x float> [[TMP5]], [[RDX_SHUF]]
|
||||
; CHECK-NEXT: [[RDX_MINMAX_SELECT:%.*]] = select fast <4 x i1> [[RDX_MINMAX_CMP]], <4 x float> [[TMP5]], <4 x float> [[RDX_SHUF]]
|
||||
; CHECK-NEXT: [[RDX_MINMAX_CMP:%.*]] = fcmp olt <4 x float> [[TMP5]], [[RDX_SHUF]]
|
||||
; CHECK-NEXT: [[RDX_MINMAX_SELECT:%.*]] = select <4 x i1> [[RDX_MINMAX_CMP]], <4 x float> [[TMP5]], <4 x float> [[RDX_SHUF]]
|
||||
; CHECK-NEXT: [[RDX_SHUF1:%.*]] = shufflevector <4 x float> [[RDX_MINMAX_SELECT]], <4 x float> poison, <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
|
||||
; CHECK-NEXT: [[RDX_MINMAX_CMP2:%.*]] = fcmp fast olt <4 x float> [[RDX_MINMAX_SELECT]], [[RDX_SHUF1]]
|
||||
; CHECK-NEXT: [[RDX_MINMAX_SELECT3:%.*]] = select fast <4 x i1> [[RDX_MINMAX_CMP2]], <4 x float> [[RDX_MINMAX_SELECT]], <4 x float> [[RDX_SHUF1]]
|
||||
; CHECK-NEXT: [[RDX_MINMAX_CMP2:%.*]] = fcmp olt <4 x float> [[RDX_MINMAX_SELECT]], [[RDX_SHUF1]]
|
||||
; CHECK-NEXT: [[RDX_MINMAX_SELECT3:%.*]] = select <4 x i1> [[RDX_MINMAX_CMP2]], <4 x float> [[RDX_MINMAX_SELECT]], <4 x float> [[RDX_SHUF1]]
|
||||
; CHECK-NEXT: [[TMP7:%.*]] = extractelement <4 x float> [[RDX_MINMAX_SELECT3]], i32 0
|
||||
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 65536, 65536
|
||||
; CHECK-NEXT: br i1 [[CMP_N]], label [[OUT:%.*]], label [[SCALAR_PH]]
|
||||
|
||||
Reference in New Issue
Block a user