[mlir:OpConversion] Remove the remaing usages of the deprecated matchAndRewrite methods

This commits updates the remaining usages of the ArrayRef<Value> based
matchAndRewrite/rewrite methods in favor of the new OpAdaptor
overload.

Differential Revision: https://reviews.llvm.org/D110360
This commit is contained in:
River Riddle
2021-09-24 17:51:20 +00:00
parent b54c724be0
commit ef976337f5
21 changed files with 233 additions and 296 deletions

View File

@@ -178,7 +178,7 @@ public:
using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(vector::BitCastOp bitCastOp, ArrayRef<Value> operands,
matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Only 1-D vectors can be lowered to LLVM.
VectorType resultTy = bitCastOp.getType();
@@ -186,7 +186,7 @@ public:
return failure();
Type newResultTy = typeConverter->convertType(resultTy);
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
operands[0]);
adaptor.getOperands()[0]);
return success();
}
};
@@ -199,9 +199,8 @@ public:
using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands,
matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::MatmulOpAdaptor(operands);
rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
matmulOp, typeConverter->convertType(matmulOp.res().getType()),
adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
@@ -218,9 +217,8 @@ public:
using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands,
matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::FlatTransposeOpAdaptor(operands);
rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
transOp, typeConverter->convertType(transOp.res().getType()),
adaptor.matrix(), transOp.rows(), transOp.columns());
@@ -270,7 +268,8 @@ public:
using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(LoadOrStoreOp loadOrStoreOp, ArrayRef<Value> operands,
matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
typename LoadOrStoreOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Only 1-D vectors can be lowered to LLVM.
VectorType vectorTy = loadOrStoreOp.getVectorType();
@@ -278,7 +277,6 @@ public:
return failure();
auto loc = loadOrStoreOp->getLoc();
auto adaptor = LoadOrStoreOpAdaptor(operands);
MemRefType memRefTy = loadOrStoreOp.getMemRefType();
// Resolve alignment.
@@ -306,10 +304,9 @@ public:
using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands,
matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = gather->getLoc();
auto adaptor = vector::GatherOpAdaptor(operands);
MemRefType memRefType = gather.getMemRefType();
// Resolve alignment.
@@ -341,10 +338,9 @@ public:
using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands,
matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = scatter->getLoc();
auto adaptor = vector::ScatterOpAdaptor(operands);
MemRefType memRefType = scatter.getMemRefType();
// Resolve alignment.
@@ -376,10 +372,9 @@ public:
using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands,
matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = expand->getLoc();
auto adaptor = vector::ExpandLoadOpAdaptor(operands);
MemRefType memRefType = expand.getMemRefType();
// Resolve address.
@@ -400,10 +395,9 @@ public:
using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands,
matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = compress->getLoc();
auto adaptor = vector::CompressStoreOpAdaptor(operands);
MemRefType memRefType = compress.getMemRefType();
// Resolve address.
@@ -426,42 +420,43 @@ public:
reassociateFPReductions(reassociateFPRed) {}
LogicalResult
matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands,
matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto kind = reductionOp.kind();
Type eltType = reductionOp.dest().getType();
Type llvmType = typeConverter->convertType(eltType);
Value operand = adaptor.getOperands()[0];
if (eltType.isIntOrIndex()) {
// Integer reductions: add/mul/min/max/and/or/xor.
if (kind == "add")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
reductionOp, llvmType, operands[0]);
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(reductionOp,
llvmType, operand);
else if (kind == "mul")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
reductionOp, llvmType, operands[0]);
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(reductionOp,
llvmType, operand);
else if (kind == "min" &&
(eltType.isIndex() || eltType.isUnsignedInteger()))
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
reductionOp, llvmType, operands[0]);
reductionOp, llvmType, operand);
else if (kind == "min")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
reductionOp, llvmType, operands[0]);
reductionOp, llvmType, operand);
else if (kind == "max" &&
(eltType.isIndex() || eltType.isUnsignedInteger()))
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
reductionOp, llvmType, operands[0]);
reductionOp, llvmType, operand);
else if (kind == "max")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
reductionOp, llvmType, operands[0]);
reductionOp, llvmType, operand);
else if (kind == "and")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
reductionOp, llvmType, operands[0]);
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(reductionOp,
llvmType, operand);
else if (kind == "or")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
reductionOp, llvmType, operands[0]);
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(reductionOp,
llvmType, operand);
else if (kind == "xor")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
reductionOp, llvmType, operands[0]);
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(reductionOp,
llvmType, operand);
else
return failure();
return success();
@@ -473,29 +468,30 @@ public:
// Floating-point reductions: add/mul/min/max
if (kind == "add") {
// Optional accumulator (or zero).
Value acc = operands.size() > 1 ? operands[1]
: rewriter.create<LLVM::ConstantOp>(
reductionOp->getLoc(), llvmType,
rewriter.getZeroAttr(eltType));
Value acc = adaptor.getOperands().size() > 1
? adaptor.getOperands()[1]
: rewriter.create<LLVM::ConstantOp>(
reductionOp->getLoc(), llvmType,
rewriter.getZeroAttr(eltType));
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
reductionOp, llvmType, acc, operands[0],
reductionOp, llvmType, acc, operand,
rewriter.getBoolAttr(reassociateFPReductions));
} else if (kind == "mul") {
// Optional accumulator (or one).
Value acc = operands.size() > 1
? operands[1]
Value acc = adaptor.getOperands().size() > 1
? adaptor.getOperands()[1]
: rewriter.create<LLVM::ConstantOp>(
reductionOp->getLoc(), llvmType,
rewriter.getFloatAttr(eltType, 1.0));
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
reductionOp, llvmType, acc, operands[0],
reductionOp, llvmType, acc, operand,
rewriter.getBoolAttr(reassociateFPReductions));
} else if (kind == "min")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
reductionOp, llvmType, operands[0]);
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp,
llvmType, operand);
else if (kind == "max")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
reductionOp, llvmType, operands[0]);
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp,
llvmType, operand);
else
return failure();
return success();
@@ -511,10 +507,9 @@ public:
using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands,
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = shuffleOp->getLoc();
auto adaptor = vector::ShuffleOpAdaptor(operands);
auto v1Type = shuffleOp.getV1VectorType();
auto v2Type = shuffleOp.getV2VectorType();
auto vectorType = shuffleOp.getVectorType();
@@ -573,10 +568,8 @@ public:
vector::ExtractElementOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(vector::ExtractElementOp extractEltOp,
ArrayRef<Value> operands,
matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::ExtractElementOpAdaptor(operands);
auto vectorType = extractEltOp.getVectorType();
auto llvmType = typeConverter->convertType(vectorType.getElementType());
@@ -596,10 +589,9 @@ public:
using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = extractOp->getLoc();
auto adaptor = vector::ExtractOpAdaptor(operands);
auto vectorType = extractOp.getVectorType();
auto resultType = extractOp.getResult().getType();
auto llvmResultType = typeConverter->convertType(resultType);
@@ -667,9 +659,8 @@ public:
using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::FMAOpAdaptor(operands);
VectorType vType = fmaOp.getVectorType();
if (vType.getRank() != 1)
return failure();
@@ -685,9 +676,8 @@ public:
using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands,
matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::InsertElementOpAdaptor(operands);
auto vectorType = insertEltOp.getDestVectorType();
auto llvmType = typeConverter->convertType(vectorType);
@@ -708,10 +698,9 @@ public:
using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = insertOp->getLoc();
auto adaptor = vector::InsertOpAdaptor(operands);
auto sourceType = insertOp.getSourceType();
auto destVectorType = insertOp.getDestVectorType();
auto llvmResultType = typeConverter->convertType(destVectorType);
@@ -984,7 +973,7 @@ public:
using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands,
matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = castOp->getLoc();
MemRefType sourceMemRefType =
@@ -997,10 +986,10 @@ public:
return failure();
auto llvmSourceDescriptorTy =
operands[0].getType().dyn_cast<LLVM::LLVMStructType>();
adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>();
if (!llvmSourceDescriptorTy)
return failure();
MemRefDescriptor sourceMemRef(operands[0]);
MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
.dyn_cast_or_null<LLVM::LLVMStructType>();
@@ -1074,9 +1063,8 @@ public:
// TODO: rely solely on libc in future? something else?
//
LogicalResult
matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands,
matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::PrintOpAdaptor(operands);
Type printType = printOp.getPrintType();
if (typeConverter->convertType(printType) == nullptr)