[mlir][Vector] Add a vector.matrix_multiply op on 1-D vectors

Summary: This op mirrors the llvm.intr counterpart and allows lowering + type conversions in a progressive fashion.

Differential Revision: https://reviews.llvm.org/D75775
This commit is contained in:
Nicolas Vasilache
2020-03-09 13:29:13 -04:00
parent 47caa69120
commit 63b683a816
6 changed files with 111 additions and 3 deletions

View File

@@ -275,6 +275,28 @@ private:
}
};
/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
class VectorMatmulOpConversion : public ConvertToLLVMPattern {
public:
explicit VectorMatmulOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
typeConverter) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto matmulOp = cast<vector::MatmulOp>(op);
auto adaptor = vector::MatmulOpOperandAdaptor(operands);
rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
matmulOp.rhs_columns());
return matchSuccess();
}
};
class VectorReductionOpConversion : public ConvertToLLVMPattern {
public:
explicit VectorReductionOpConversion(MLIRContext *context,
@@ -1141,6 +1163,12 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorPrintOpConversion>(ctx, converter);
}
void mlir::populateVectorToLLVMMatrixConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
MLIRContext *ctx = converter.getDialect()->getContext();
patterns.insert<VectorMatmulOpConversion>(ctx, converter);
}
namespace {
struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> {
void runOnModule() override;
@@ -1160,6 +1188,7 @@ void LowerVectorToLLVMPass::runOnModule() {
// Convert to the LLVM IR dialect.
LLVMTypeConverter converter(&getContext());
OwningRewritePatternList patterns;
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
populateVectorToLLVMConversionPatterns(converter, patterns);
populateStdToLLVMConversionPatterns(converter, patterns);