[mlir][Vector] Add a vector.matrix_multiply op on 1-D vectors
Summary: This op mirrors the llvm.intr counterpart and allows lowering + type conversions in a progressive fashion. Differential Revision: https://reviews.llvm.org/D75775
This commit is contained in:
@@ -275,6 +275,28 @@ private:
|
||||
}
|
||||
};
|
||||
|
||||
/// Conversion pattern for a vector.matrix_multiply.
|
||||
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
|
||||
class VectorMatmulOpConversion : public ConvertToLLVMPattern {
|
||||
public:
|
||||
explicit VectorMatmulOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto matmulOp = cast<vector::MatmulOp>(op);
|
||||
auto adaptor = vector::MatmulOpOperandAdaptor(operands);
|
||||
rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
|
||||
op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
|
||||
adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
|
||||
matmulOp.rhs_columns());
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
class VectorReductionOpConversion : public ConvertToLLVMPattern {
|
||||
public:
|
||||
explicit VectorReductionOpConversion(MLIRContext *context,
|
||||
@@ -1141,6 +1163,12 @@ void mlir::populateVectorToLLVMConversionPatterns(
|
||||
VectorPrintOpConversion>(ctx, converter);
|
||||
}
|
||||
|
||||
void mlir::populateVectorToLLVMMatrixConversionPatterns(
|
||||
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
||||
MLIRContext *ctx = converter.getDialect()->getContext();
|
||||
patterns.insert<VectorMatmulOpConversion>(ctx, converter);
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> {
|
||||
void runOnModule() override;
|
||||
@@ -1160,6 +1188,7 @@ void LowerVectorToLLVMPass::runOnModule() {
|
||||
// Convert to the LLVM IR dialect.
|
||||
LLVMTypeConverter converter(&getContext());
|
||||
OwningRewritePatternList patterns;
|
||||
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
|
||||
populateVectorToLLVMConversionPatterns(converter, patterns);
|
||||
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user