[mlir][NFC] Replace all usages of PatternMatchResult with LogicalResult

This also replaces usages of matchSuccess/matchFailure with success/failure respectively.

Differential Revision: https://reviews.llvm.org/D76313
This commit is contained in:
River Riddle
2020-03-17 20:07:55 -07:00
parent 2fae7878d5
commit 3145427dd7
52 changed files with 722 additions and 743 deletions

View File

@@ -133,13 +133,13 @@ public:
: ConvertToLLVMPattern(vector::BroadcastOp::getOperationName(), context,
typeConverter) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto broadcastOp = cast<vector::BroadcastOp>(op);
VectorType dstVectorType = broadcastOp.getVectorType();
if (typeConverter.convertType(dstVectorType) == nullptr)
return matchFailure();
return failure();
// Rewrite when the full vector type can be lowered (which
// implies all 'reduced' types can be lowered too).
auto adaptor = vector::BroadcastOpOperandAdaptor(operands);
@@ -149,7 +149,7 @@ public:
op, expandRanks(adaptor.source(), // source value to be expanded
op->getLoc(), // location of original broadcast
srcVectorType, dstVectorType, rewriter));
return matchSuccess();
return success();
}
private:
@@ -284,7 +284,7 @@ public:
: ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
typeConverter) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto matmulOp = cast<vector::MatmulOp>(op);
@@ -293,7 +293,7 @@ public:
op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
matmulOp.rhs_columns());
return matchSuccess();
return success();
}
};
@@ -304,7 +304,7 @@ public:
: ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
typeConverter) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto reductionOp = cast<vector::ReductionOp>(op);
@@ -335,8 +335,8 @@ public:
rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
op, llvmType, operands[0]);
else
return matchFailure();
return matchSuccess();
return failure();
return success();
} else if (eltType.isF32() || eltType.isF64()) {
// Floating-point reductions: add/mul/min/max
@@ -364,10 +364,10 @@ public:
rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
op, llvmType, operands[0]);
else
return matchFailure();
return matchSuccess();
return failure();
return success();
}
return matchFailure();
return failure();
}
};
@@ -378,7 +378,7 @@ public:
: ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
typeConverter) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
@@ -392,7 +392,7 @@ public:
// Bail if result type cannot be lowered.
if (!llvmType)
return matchFailure();
return failure();
// Get rank and dimension sizes.
int64_t rank = vectorType.getRank();
@@ -406,7 +406,7 @@ public:
Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
rewriter.replaceOp(op, shuffle);
return matchSuccess();
return success();
}
// For all other cases, insert the individual values individually.
@@ -425,7 +425,7 @@ public:
llvmType, rank, insPos++);
}
rewriter.replaceOp(op, insert);
return matchSuccess();
return success();
}
};
@@ -436,7 +436,7 @@ public:
: ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
context, typeConverter) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
@@ -446,11 +446,11 @@ public:
// Bail if result type cannot be lowered.
if (!llvmType)
return matchFailure();
return failure();
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
op, llvmType, adaptor.vector(), adaptor.position());
return matchSuccess();
return success();
}
};
@@ -461,7 +461,7 @@ public:
: ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
typeConverter) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
@@ -474,14 +474,14 @@ public:
// Bail if result type cannot be lowered.
if (!llvmResultType)
return matchFailure();
return failure();
// One-shot extraction of vector from array (only requires extractvalue).
if (resultType.isa<VectorType>()) {
Value extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, llvmResultType, adaptor.vector(), positionArrayAttr);
rewriter.replaceOp(op, extracted);
return matchSuccess();
return success();
}
// Potential extraction of 1-D vector from array.
@@ -505,7 +505,7 @@ public:
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
rewriter.replaceOp(op, extracted);
return matchSuccess();
return success();
}
};
@@ -530,17 +530,17 @@ public:
: ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
typeConverter) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::FMAOpOperandAdaptor(operands);
vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
VectorType vType = fmaOp.getVectorType();
if (vType.getRank() != 1)
return matchFailure();
return failure();
rewriter.replaceOpWithNewOp<LLVM::FMAOp>(op, adaptor.lhs(), adaptor.rhs(),
adaptor.acc());
return matchSuccess();
return success();
}
};
@@ -551,7 +551,7 @@ public:
: ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
context, typeConverter) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::InsertElementOpOperandAdaptor(operands);
@@ -561,11 +561,11 @@ public:
// Bail if result type cannot be lowered.
if (!llvmType)
return matchFailure();
return failure();
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
return matchSuccess();
return success();
}
};
@@ -576,7 +576,7 @@ public:
: ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
typeConverter) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
@@ -589,7 +589,7 @@ public:
// Bail if result type cannot be lowered.
if (!llvmResultType)
return matchFailure();
return failure();
// One-shot insertion of a vector into an array (only requires insertvalue).
if (sourceType.isa<VectorType>()) {
@@ -597,7 +597,7 @@ public:
loc, llvmResultType, adaptor.dest(), adaptor.source(),
positionArrayAttr);
rewriter.replaceOp(op, inserted);
return matchSuccess();
return success();
}
// Potential extraction of 1-D vector from array.
@@ -632,7 +632,7 @@ public:
}
rewriter.replaceOp(op, inserted);
return matchSuccess();
return success();
}
};
@@ -661,11 +661,11 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
using OpRewritePattern<FMAOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(FMAOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(FMAOp op,
PatternRewriter &rewriter) const override {
auto vType = op.getVectorType();
if (vType.getRank() < 2)
return matchFailure();
return failure();
auto loc = op.getLoc();
auto elemType = vType.getElementType();
@@ -680,7 +680,7 @@ public:
desc = rewriter.create<InsertOp>(loc, fma, desc, i);
}
rewriter.replaceOp(op, desc);
return matchSuccess();
return success();
}
};
@@ -704,19 +704,19 @@ class VectorInsertStridedSliceOpDifferentRankRewritePattern
public:
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(InsertStridedSliceOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(InsertStridedSliceOp op,
PatternRewriter &rewriter) const override {
auto srcType = op.getSourceVectorType();
auto dstType = op.getDestVectorType();
if (op.offsets().getValue().empty())
return matchFailure();
return failure();
auto loc = op.getLoc();
int64_t rankDiff = dstType.getRank() - srcType.getRank();
assert(rankDiff >= 0);
if (rankDiff == 0)
return matchFailure();
return failure();
int64_t rankRest = dstType.getRank() - rankDiff;
// Extract / insert the subvector of matching rank and InsertStridedSlice
@@ -735,7 +735,7 @@ public:
op, stridedSliceInnerOp.getResult(), op.dest(),
getI64SubArray(op.offsets(), /*dropFront=*/0,
/*dropFront=*/rankRest));
return matchSuccess();
return success();
}
};
@@ -753,22 +753,22 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
public:
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(InsertStridedSliceOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(InsertStridedSliceOp op,
PatternRewriter &rewriter) const override {
auto srcType = op.getSourceVectorType();
auto dstType = op.getDestVectorType();
if (op.offsets().getValue().empty())
return matchFailure();
return failure();
int64_t rankDiff = dstType.getRank() - srcType.getRank();
assert(rankDiff >= 0);
if (rankDiff != 0)
return matchFailure();
return failure();
if (srcType == dstType) {
rewriter.replaceOp(op, op.source());
return matchSuccess();
return success();
}
int64_t offset =
@@ -813,7 +813,7 @@ public:
}
rewriter.replaceOp(op, res);
return matchSuccess();
return success();
}
};
@@ -824,7 +824,7 @@ public:
: ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
typeConverter) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
@@ -837,18 +837,18 @@ public:
// Only static shape casts supported atm.
if (!sourceMemRefType.hasStaticShape() ||
!targetMemRefType.hasStaticShape())
return matchFailure();
return failure();
auto llvmSourceDescriptorTy =
operands[0].getType().dyn_cast<LLVM::LLVMType>();
if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
return matchFailure();
return failure();
MemRefDescriptor sourceMemRef(operands[0]);
auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
.dyn_cast_or_null<LLVM::LLVMType>();
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
return matchFailure();
return failure();
int64_t offset;
SmallVector<int64_t, 4> strides;
@@ -866,7 +866,7 @@ public:
}
// Only contiguous source tensors supported atm.
if (failed(successStrides) || !isContiguous)
return matchFailure();
return failure();
auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
@@ -901,7 +901,7 @@ public:
}
rewriter.replaceOp(op, {desc});
return matchSuccess();
return success();
}
};
@@ -924,7 +924,7 @@ public:
//
// TODO(ajcbik): rely solely on libc in future? something else?
//
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto printOp = cast<vector::PrintOp>(op);
@@ -932,7 +932,7 @@ public:
Type printType = printOp.getPrintType();
if (typeConverter.convertType(printType) == nullptr)
return matchFailure();
return failure();
// Make sure element type has runtime support (currently just Float/Double).
VectorType vectorType = printType.dyn_cast<VectorType>();
@@ -948,13 +948,13 @@ public:
else if (eltType.isF64())
printer = getPrintDouble(op);
else
return matchFailure();
return failure();
// Unroll vector into elementary print calls.
emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
emitCall(rewriter, op->getLoc(), getPrintNewline(op));
rewriter.eraseOp(op);
return matchSuccess();
return success();
}
private:
@@ -1047,8 +1047,8 @@ class VectorStridedSliceOpConversion : public OpRewritePattern<StridedSliceOp> {
public:
using OpRewritePattern<StridedSliceOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(StridedSliceOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(StridedSliceOp op,
PatternRewriter &rewriter) const override {
auto dstType = op.getResult().getType().cast<VectorType>();
assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
@@ -1089,7 +1089,7 @@ public:
res = insertOne(rewriter, loc, extracted, res, idx);
}
rewriter.replaceOp(op, {res});
return matchSuccess();
return success();
}
};