[mlir][VectorToLLVM] Fix bug in lowering of vector.reduce fmax/fmin
The lowering of fmax/fmin reduce was ignoring the optional accumulator. Differential Revision: https://reviews.llvm.org/D129597
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
@@ -393,6 +394,27 @@ static Value createIntegerReductionComparisonOpLowering(
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum
|
||||
/// with vector types.
|
||||
static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
|
||||
Value rhs, bool isMin) {
|
||||
auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
|
||||
Type i1Type = builder.getI1Type();
|
||||
if (auto vecType = lhs.getType().dyn_cast<VectorType>())
|
||||
i1Type = VectorType::get(vecType.getShape(), i1Type);
|
||||
Value cmp = builder.create<LLVM::FCmpOp>(
|
||||
loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
|
||||
lhs, rhs);
|
||||
Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
|
||||
Value isNan = builder.create<LLVM::FCmpOp>(
|
||||
loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
|
||||
Value nan = builder.create<LLVM::ConstantOp>(
|
||||
loc, lhs.getType(),
|
||||
builder.getFloatAttr(floatType,
|
||||
APFloat::getQNaN(floatType.getFloatSemantics())));
|
||||
return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
|
||||
}
|
||||
|
||||
/// Conversion pattern for all vector reductions.
|
||||
class VectorReductionOpConversion
|
||||
: public ConvertOpToLLVMPattern<vector::ReductionOp> {
|
||||
@@ -497,18 +519,25 @@ public:
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
|
||||
reductionOp, llvmType, acc, operand,
|
||||
rewriter.getBoolAttr(reassociateFPReductions));
|
||||
} else if (kind == vector::CombiningKind::MINF)
|
||||
} else if (kind == vector::CombiningKind::MINF) {
|
||||
// FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
|
||||
// NaNs/-0.0/+0.0 in the same way.
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp,
|
||||
llvmType, operand);
|
||||
else if (kind == vector::CombiningKind::MAXF)
|
||||
Value result =
|
||||
rewriter.create<LLVM::vector_reduce_fmin>(loc, llvmType, operand);
|
||||
if (acc)
|
||||
result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/true);
|
||||
rewriter.replaceOp(reductionOp, result);
|
||||
} else if (kind == vector::CombiningKind::MAXF) {
|
||||
// FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
|
||||
// NaNs/-0.0/+0.0 in the same way.
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp,
|
||||
llvmType, operand);
|
||||
else
|
||||
Value result =
|
||||
rewriter.create<LLVM::vector_reduce_fmax>(loc, llvmType, operand);
|
||||
if (acc)
|
||||
result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/false);
|
||||
rewriter.replaceOp(reductionOp, result);
|
||||
} else
|
||||
return failure();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user