[mlir][NFC] Replace all usages of PatternMatchResult with LogicalResult
This also replaces usages of matchSuccess/matchFailure with success/failure respectively. Differential Revision: https://reviews.llvm.org/D76313
This commit is contained in:
@@ -133,13 +133,13 @@ public:
|
||||
: ConvertToLLVMPattern(vector::BroadcastOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto broadcastOp = cast<vector::BroadcastOp>(op);
|
||||
VectorType dstVectorType = broadcastOp.getVectorType();
|
||||
if (typeConverter.convertType(dstVectorType) == nullptr)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
// Rewrite when the full vector type can be lowered (which
|
||||
// implies all 'reduced' types can be lowered too).
|
||||
auto adaptor = vector::BroadcastOpOperandAdaptor(operands);
|
||||
@@ -149,7 +149,7 @@ public:
|
||||
op, expandRanks(adaptor.source(), // source value to be expanded
|
||||
op->getLoc(), // location of original broadcast
|
||||
srcVectorType, dstVectorType, rewriter));
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -284,7 +284,7 @@ public:
|
||||
: ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto matmulOp = cast<vector::MatmulOp>(op);
|
||||
@@ -293,7 +293,7 @@ public:
|
||||
op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
|
||||
adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
|
||||
matmulOp.rhs_columns());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -304,7 +304,7 @@ public:
|
||||
: ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto reductionOp = cast<vector::ReductionOp>(op);
|
||||
@@ -335,8 +335,8 @@ public:
|
||||
rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
|
||||
op, llvmType, operands[0]);
|
||||
else
|
||||
return matchFailure();
|
||||
return matchSuccess();
|
||||
return failure();
|
||||
return success();
|
||||
|
||||
} else if (eltType.isF32() || eltType.isF64()) {
|
||||
// Floating-point reductions: add/mul/min/max
|
||||
@@ -364,10 +364,10 @@ public:
|
||||
rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
|
||||
op, llvmType, operands[0]);
|
||||
else
|
||||
return matchFailure();
|
||||
return matchSuccess();
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -378,7 +378,7 @@ public:
|
||||
: ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
@@ -392,7 +392,7 @@ public:
|
||||
|
||||
// Bail if result type cannot be lowered.
|
||||
if (!llvmType)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Get rank and dimension sizes.
|
||||
int64_t rank = vectorType.getRank();
|
||||
@@ -406,7 +406,7 @@ public:
|
||||
Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
|
||||
loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
|
||||
rewriter.replaceOp(op, shuffle);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
// For all other cases, insert the individual values individually.
|
||||
@@ -425,7 +425,7 @@ public:
|
||||
llvmType, rank, insPos++);
|
||||
}
|
||||
rewriter.replaceOp(op, insert);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -436,7 +436,7 @@ public:
|
||||
: ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
|
||||
context, typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
|
||||
@@ -446,11 +446,11 @@ public:
|
||||
|
||||
// Bail if result type cannot be lowered.
|
||||
if (!llvmType)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
|
||||
op, llvmType, adaptor.vector(), adaptor.position());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -461,7 +461,7 @@ public:
|
||||
: ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
@@ -474,14 +474,14 @@ public:
|
||||
|
||||
// Bail if result type cannot be lowered.
|
||||
if (!llvmResultType)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// One-shot extraction of vector from array (only requires extractvalue).
|
||||
if (resultType.isa<VectorType>()) {
|
||||
Value extracted = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, llvmResultType, adaptor.vector(), positionArrayAttr);
|
||||
rewriter.replaceOp(op, extracted);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
// Potential extraction of 1-D vector from array.
|
||||
@@ -505,7 +505,7 @@ public:
|
||||
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
|
||||
rewriter.replaceOp(op, extracted);
|
||||
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -530,17 +530,17 @@ public:
|
||||
: ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto adaptor = vector::FMAOpOperandAdaptor(operands);
|
||||
vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
|
||||
VectorType vType = fmaOp.getVectorType();
|
||||
if (vType.getRank() != 1)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<LLVM::FMAOp>(op, adaptor.lhs(), adaptor.rhs(),
|
||||
adaptor.acc());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -551,7 +551,7 @@ public:
|
||||
: ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
|
||||
context, typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto adaptor = vector::InsertElementOpOperandAdaptor(operands);
|
||||
@@ -561,11 +561,11 @@ public:
|
||||
|
||||
// Bail if result type cannot be lowered.
|
||||
if (!llvmType)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
|
||||
op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -576,7 +576,7 @@ public:
|
||||
: ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
@@ -589,7 +589,7 @@ public:
|
||||
|
||||
// Bail if result type cannot be lowered.
|
||||
if (!llvmResultType)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// One-shot insertion of a vector into an array (only requires insertvalue).
|
||||
if (sourceType.isa<VectorType>()) {
|
||||
@@ -597,7 +597,7 @@ public:
|
||||
loc, llvmResultType, adaptor.dest(), adaptor.source(),
|
||||
positionArrayAttr);
|
||||
rewriter.replaceOp(op, inserted);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
// Potential extraction of 1-D vector from array.
|
||||
@@ -632,7 +632,7 @@ public:
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, inserted);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -661,11 +661,11 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
|
||||
public:
|
||||
using OpRewritePattern<FMAOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(FMAOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(FMAOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto vType = op.getVectorType();
|
||||
if (vType.getRank() < 2)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
auto loc = op.getLoc();
|
||||
auto elemType = vType.getElementType();
|
||||
@@ -680,7 +680,7 @@ public:
|
||||
desc = rewriter.create<InsertOp>(loc, fma, desc, i);
|
||||
}
|
||||
rewriter.replaceOp(op, desc);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -704,19 +704,19 @@ class VectorInsertStridedSliceOpDifferentRankRewritePattern
|
||||
public:
|
||||
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(InsertStridedSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(InsertStridedSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto srcType = op.getSourceVectorType();
|
||||
auto dstType = op.getDestVectorType();
|
||||
|
||||
if (op.offsets().getValue().empty())
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
auto loc = op.getLoc();
|
||||
int64_t rankDiff = dstType.getRank() - srcType.getRank();
|
||||
assert(rankDiff >= 0);
|
||||
if (rankDiff == 0)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
int64_t rankRest = dstType.getRank() - rankDiff;
|
||||
// Extract / insert the subvector of matching rank and InsertStridedSlice
|
||||
@@ -735,7 +735,7 @@ public:
|
||||
op, stridedSliceInnerOp.getResult(), op.dest(),
|
||||
getI64SubArray(op.offsets(), /*dropFront=*/0,
|
||||
/*dropFront=*/rankRest));
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -753,22 +753,22 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
|
||||
public:
|
||||
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(InsertStridedSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(InsertStridedSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto srcType = op.getSourceVectorType();
|
||||
auto dstType = op.getDestVectorType();
|
||||
|
||||
if (op.offsets().getValue().empty())
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
int64_t rankDiff = dstType.getRank() - srcType.getRank();
|
||||
assert(rankDiff >= 0);
|
||||
if (rankDiff != 0)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
if (srcType == dstType) {
|
||||
rewriter.replaceOp(op, op.source());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
int64_t offset =
|
||||
@@ -813,7 +813,7 @@ public:
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, res);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -824,7 +824,7 @@ public:
|
||||
: ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
@@ -837,18 +837,18 @@ public:
|
||||
// Only static shape casts supported atm.
|
||||
if (!sourceMemRefType.hasStaticShape() ||
|
||||
!targetMemRefType.hasStaticShape())
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
auto llvmSourceDescriptorTy =
|
||||
operands[0].getType().dyn_cast<LLVM::LLVMType>();
|
||||
if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
|
||||
return matchFailure();
|
||||
return failure();
|
||||
MemRefDescriptor sourceMemRef(operands[0]);
|
||||
|
||||
auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
|
||||
.dyn_cast_or_null<LLVM::LLVMType>();
|
||||
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
int64_t offset;
|
||||
SmallVector<int64_t, 4> strides;
|
||||
@@ -866,7 +866,7 @@ public:
|
||||
}
|
||||
// Only contiguous source tensors supported atm.
|
||||
if (failed(successStrides) || !isContiguous)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
|
||||
|
||||
@@ -901,7 +901,7 @@ public:
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, {desc});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -924,7 +924,7 @@ public:
|
||||
//
|
||||
// TODO(ajcbik): rely solely on libc in future? something else?
|
||||
//
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto printOp = cast<vector::PrintOp>(op);
|
||||
@@ -932,7 +932,7 @@ public:
|
||||
Type printType = printOp.getPrintType();
|
||||
|
||||
if (typeConverter.convertType(printType) == nullptr)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Make sure element type has runtime support (currently just Float/Double).
|
||||
VectorType vectorType = printType.dyn_cast<VectorType>();
|
||||
@@ -948,13 +948,13 @@ public:
|
||||
else if (eltType.isF64())
|
||||
printer = getPrintDouble(op);
|
||||
else
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Unroll vector into elementary print calls.
|
||||
emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
|
||||
emitCall(rewriter, op->getLoc(), getPrintNewline(op));
|
||||
rewriter.eraseOp(op);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -1047,8 +1047,8 @@ class VectorStridedSliceOpConversion : public OpRewritePattern<StridedSliceOp> {
|
||||
public:
|
||||
using OpRewritePattern<StridedSliceOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(StridedSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(StridedSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto dstType = op.getResult().getType().cast<VectorType>();
|
||||
|
||||
assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
|
||||
@@ -1089,7 +1089,7 @@ public:
|
||||
res = insertOne(rewriter, loc, extracted, res, idx);
|
||||
}
|
||||
rewriter.replaceOp(op, {res});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user