[mlir][vector][NFC] Use CombiningKindAttr instead of StringAttr

This makes the op consistent with other ops in vector dialect.

Differential Revision: https://reviews.llvm.org/D119343
This commit is contained in:
Matthias Springer
2022-02-10 19:12:46 +09:00
parent fd43d99c93
commit fe0bf7d469
25 changed files with 241 additions and 264 deletions

View File

@@ -380,31 +380,31 @@ public:
Value operand = adaptor.getOperands()[0];
if (eltType.isIntOrIndex()) {
// Integer reductions: add/mul/min/max/and/or/xor.
if (kind == "add")
if (kind == vector::CombiningKind::ADD)
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(reductionOp,
llvmType, operand);
else if (kind == "mul")
else if (kind == vector::CombiningKind::MUL)
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(reductionOp,
llvmType, operand);
else if (kind == "minui")
else if (kind == vector::CombiningKind::MINUI)
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
reductionOp, llvmType, operand);
else if (kind == "minsi")
else if (kind == vector::CombiningKind::MINSI)
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
reductionOp, llvmType, operand);
else if (kind == "maxui")
else if (kind == vector::CombiningKind::MAXUI)
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
reductionOp, llvmType, operand);
else if (kind == "maxsi")
else if (kind == vector::CombiningKind::MAXSI)
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
reductionOp, llvmType, operand);
else if (kind == "and")
else if (kind == vector::CombiningKind::AND)
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(reductionOp,
llvmType, operand);
else if (kind == "or")
else if (kind == vector::CombiningKind::OR)
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(reductionOp,
llvmType, operand);
else if (kind == "xor")
else if (kind == vector::CombiningKind::XOR)
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(reductionOp,
llvmType, operand);
else
@@ -416,7 +416,7 @@ public:
return failure();
// Floating-point reductions: add/mul/min/max
if (kind == "add") {
if (kind == vector::CombiningKind::ADD) {
// Optional accumulator (or zero).
Value acc = adaptor.getOperands().size() > 1
? adaptor.getOperands()[1]
@@ -426,7 +426,7 @@ public:
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
reductionOp, llvmType, acc, operand,
rewriter.getBoolAttr(reassociateFPReductions));
} else if (kind == "mul") {
} else if (kind == vector::CombiningKind::MUL) {
// Optional accumulator (or one).
Value acc = adaptor.getOperands().size() > 1
? adaptor.getOperands()[1]
@@ -436,12 +436,12 @@ public:
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
reductionOp, llvmType, acc, operand,
rewriter.getBoolAttr(reassociateFPReductions));
} else if (kind == "minf")
} else if (kind == vector::CombiningKind::MINF)
// FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
// NaNs/-0.0/+0.0 in the same way.
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp,
llvmType, operand);
else if (kind == "maxf")
else if (kind == vector::CombiningKind::MAXF)
// FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
// NaNs/-0.0/+0.0 in the same way.
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp,