[mlir:OpConversion] Remove the remaing usages of the deprecated matchAndRewrite methods
This commits updates the remaining usages of the ArrayRef<Value> based matchAndRewrite/rewrite methods in favor of the new OpAdaptor overload. Differential Revision: https://reviews.llvm.org/D110360
This commit is contained in:
@@ -178,7 +178,7 @@ public:
|
||||
using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::BitCastOp bitCastOp, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Only 1-D vectors can be lowered to LLVM.
|
||||
VectorType resultTy = bitCastOp.getType();
|
||||
@@ -186,7 +186,7 @@ public:
|
||||
return failure();
|
||||
Type newResultTy = typeConverter->convertType(resultTy);
|
||||
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
|
||||
operands[0]);
|
||||
adaptor.getOperands()[0]);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -199,9 +199,8 @@ public:
|
||||
using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto adaptor = vector::MatmulOpAdaptor(operands);
|
||||
rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
|
||||
matmulOp, typeConverter->convertType(matmulOp.res().getType()),
|
||||
adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
|
||||
@@ -218,9 +217,8 @@ public:
|
||||
using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto adaptor = vector::FlatTransposeOpAdaptor(operands);
|
||||
rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
|
||||
transOp, typeConverter->convertType(transOp.res().getType()),
|
||||
adaptor.matrix(), transOp.rows(), transOp.columns());
|
||||
@@ -270,7 +268,8 @@ public:
|
||||
using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(LoadOrStoreOp loadOrStoreOp, ArrayRef<Value> operands,
|
||||
matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
|
||||
typename LoadOrStoreOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Only 1-D vectors can be lowered to LLVM.
|
||||
VectorType vectorTy = loadOrStoreOp.getVectorType();
|
||||
@@ -278,7 +277,6 @@ public:
|
||||
return failure();
|
||||
|
||||
auto loc = loadOrStoreOp->getLoc();
|
||||
auto adaptor = LoadOrStoreOpAdaptor(operands);
|
||||
MemRefType memRefTy = loadOrStoreOp.getMemRefType();
|
||||
|
||||
// Resolve alignment.
|
||||
@@ -306,10 +304,9 @@ public:
|
||||
using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = gather->getLoc();
|
||||
auto adaptor = vector::GatherOpAdaptor(operands);
|
||||
MemRefType memRefType = gather.getMemRefType();
|
||||
|
||||
// Resolve alignment.
|
||||
@@ -341,10 +338,9 @@ public:
|
||||
using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = scatter->getLoc();
|
||||
auto adaptor = vector::ScatterOpAdaptor(operands);
|
||||
MemRefType memRefType = scatter.getMemRefType();
|
||||
|
||||
// Resolve alignment.
|
||||
@@ -376,10 +372,9 @@ public:
|
||||
using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = expand->getLoc();
|
||||
auto adaptor = vector::ExpandLoadOpAdaptor(operands);
|
||||
MemRefType memRefType = expand.getMemRefType();
|
||||
|
||||
// Resolve address.
|
||||
@@ -400,10 +395,9 @@ public:
|
||||
using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = compress->getLoc();
|
||||
auto adaptor = vector::CompressStoreOpAdaptor(operands);
|
||||
MemRefType memRefType = compress.getMemRefType();
|
||||
|
||||
// Resolve address.
|
||||
@@ -426,42 +420,43 @@ public:
|
||||
reassociateFPReductions(reassociateFPRed) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto kind = reductionOp.kind();
|
||||
Type eltType = reductionOp.dest().getType();
|
||||
Type llvmType = typeConverter->convertType(eltType);
|
||||
Value operand = adaptor.getOperands()[0];
|
||||
if (eltType.isIntOrIndex()) {
|
||||
// Integer reductions: add/mul/min/max/and/or/xor.
|
||||
if (kind == "add")
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
|
||||
reductionOp, llvmType, operands[0]);
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(reductionOp,
|
||||
llvmType, operand);
|
||||
else if (kind == "mul")
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
|
||||
reductionOp, llvmType, operands[0]);
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(reductionOp,
|
||||
llvmType, operand);
|
||||
else if (kind == "min" &&
|
||||
(eltType.isIndex() || eltType.isUnsignedInteger()))
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
|
||||
reductionOp, llvmType, operands[0]);
|
||||
reductionOp, llvmType, operand);
|
||||
else if (kind == "min")
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
|
||||
reductionOp, llvmType, operands[0]);
|
||||
reductionOp, llvmType, operand);
|
||||
else if (kind == "max" &&
|
||||
(eltType.isIndex() || eltType.isUnsignedInteger()))
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
|
||||
reductionOp, llvmType, operands[0]);
|
||||
reductionOp, llvmType, operand);
|
||||
else if (kind == "max")
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
|
||||
reductionOp, llvmType, operands[0]);
|
||||
reductionOp, llvmType, operand);
|
||||
else if (kind == "and")
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
|
||||
reductionOp, llvmType, operands[0]);
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(reductionOp,
|
||||
llvmType, operand);
|
||||
else if (kind == "or")
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
|
||||
reductionOp, llvmType, operands[0]);
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(reductionOp,
|
||||
llvmType, operand);
|
||||
else if (kind == "xor")
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
|
||||
reductionOp, llvmType, operands[0]);
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(reductionOp,
|
||||
llvmType, operand);
|
||||
else
|
||||
return failure();
|
||||
return success();
|
||||
@@ -473,29 +468,30 @@ public:
|
||||
// Floating-point reductions: add/mul/min/max
|
||||
if (kind == "add") {
|
||||
// Optional accumulator (or zero).
|
||||
Value acc = operands.size() > 1 ? operands[1]
|
||||
: rewriter.create<LLVM::ConstantOp>(
|
||||
reductionOp->getLoc(), llvmType,
|
||||
rewriter.getZeroAttr(eltType));
|
||||
Value acc = adaptor.getOperands().size() > 1
|
||||
? adaptor.getOperands()[1]
|
||||
: rewriter.create<LLVM::ConstantOp>(
|
||||
reductionOp->getLoc(), llvmType,
|
||||
rewriter.getZeroAttr(eltType));
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
|
||||
reductionOp, llvmType, acc, operands[0],
|
||||
reductionOp, llvmType, acc, operand,
|
||||
rewriter.getBoolAttr(reassociateFPReductions));
|
||||
} else if (kind == "mul") {
|
||||
// Optional accumulator (or one).
|
||||
Value acc = operands.size() > 1
|
||||
? operands[1]
|
||||
Value acc = adaptor.getOperands().size() > 1
|
||||
? adaptor.getOperands()[1]
|
||||
: rewriter.create<LLVM::ConstantOp>(
|
||||
reductionOp->getLoc(), llvmType,
|
||||
rewriter.getFloatAttr(eltType, 1.0));
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
|
||||
reductionOp, llvmType, acc, operands[0],
|
||||
reductionOp, llvmType, acc, operand,
|
||||
rewriter.getBoolAttr(reassociateFPReductions));
|
||||
} else if (kind == "min")
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
|
||||
reductionOp, llvmType, operands[0]);
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp,
|
||||
llvmType, operand);
|
||||
else if (kind == "max")
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
|
||||
reductionOp, llvmType, operands[0]);
|
||||
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp,
|
||||
llvmType, operand);
|
||||
else
|
||||
return failure();
|
||||
return success();
|
||||
@@ -511,10 +507,9 @@ public:
|
||||
using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = shuffleOp->getLoc();
|
||||
auto adaptor = vector::ShuffleOpAdaptor(operands);
|
||||
auto v1Type = shuffleOp.getV1VectorType();
|
||||
auto v2Type = shuffleOp.getV2VectorType();
|
||||
auto vectorType = shuffleOp.getVectorType();
|
||||
@@ -573,10 +568,8 @@ public:
|
||||
vector::ExtractElementOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::ExtractElementOp extractEltOp,
|
||||
ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto adaptor = vector::ExtractElementOpAdaptor(operands);
|
||||
auto vectorType = extractEltOp.getVectorType();
|
||||
auto llvmType = typeConverter->convertType(vectorType.getElementType());
|
||||
|
||||
@@ -596,10 +589,9 @@ public:
|
||||
using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = extractOp->getLoc();
|
||||
auto adaptor = vector::ExtractOpAdaptor(operands);
|
||||
auto vectorType = extractOp.getVectorType();
|
||||
auto resultType = extractOp.getResult().getType();
|
||||
auto llvmResultType = typeConverter->convertType(resultType);
|
||||
@@ -667,9 +659,8 @@ public:
|
||||
using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto adaptor = vector::FMAOpAdaptor(operands);
|
||||
VectorType vType = fmaOp.getVectorType();
|
||||
if (vType.getRank() != 1)
|
||||
return failure();
|
||||
@@ -685,9 +676,8 @@ public:
|
||||
using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto adaptor = vector::InsertElementOpAdaptor(operands);
|
||||
auto vectorType = insertEltOp.getDestVectorType();
|
||||
auto llvmType = typeConverter->convertType(vectorType);
|
||||
|
||||
@@ -708,10 +698,9 @@ public:
|
||||
using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = insertOp->getLoc();
|
||||
auto adaptor = vector::InsertOpAdaptor(operands);
|
||||
auto sourceType = insertOp.getSourceType();
|
||||
auto destVectorType = insertOp.getDestVectorType();
|
||||
auto llvmResultType = typeConverter->convertType(destVectorType);
|
||||
@@ -984,7 +973,7 @@ public:
|
||||
using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = castOp->getLoc();
|
||||
MemRefType sourceMemRefType =
|
||||
@@ -997,10 +986,10 @@ public:
|
||||
return failure();
|
||||
|
||||
auto llvmSourceDescriptorTy =
|
||||
operands[0].getType().dyn_cast<LLVM::LLVMStructType>();
|
||||
adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>();
|
||||
if (!llvmSourceDescriptorTy)
|
||||
return failure();
|
||||
MemRefDescriptor sourceMemRef(operands[0]);
|
||||
MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
|
||||
|
||||
auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
|
||||
.dyn_cast_or_null<LLVM::LLVMStructType>();
|
||||
@@ -1074,9 +1063,8 @@ public:
|
||||
// TODO: rely solely on libc in future? something else?
|
||||
//
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands,
|
||||
matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto adaptor = vector::PrintOpAdaptor(operands);
|
||||
Type printType = printOp.getPrintType();
|
||||
|
||||
if (typeConverter->convertType(printType) == nullptr)
|
||||
|
||||
Reference in New Issue
Block a user