At the moment, SME-to-LLVM lowerings rely entirely on `LLVMTypeConverter`. This patch introduces a dedicated `TypeConverter` that inherits from `LLVMTypeConverter` (it will also be used when lowering ArmSME Ops to LLVM). The new type converter merely disables lowerings for `VectorType` to prevent 2-D scalable vectors (common in the context of ArmSME), e.g. `vector<[16]x[16]xi8>`, from entering the LLVM type converter. LLVM does not support arrays of scalable vectors, hence the need for this specialisation. In the case of SME, such types are effectively eliminated when emitting LLVM IR intrinsics for SME. Differential Revision: https://reviews.llvm.org/D155365
129 lines
5.1 KiB
C++
129 lines
5.1 KiB
C++
//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
|
|
|
|
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
|
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
|
#include "mlir/Dialect/AMX/AMXDialect.h"
|
|
#include "mlir/Dialect/AMX/Transforms.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
|
|
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
|
|
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
|
|
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
|
|
#include "mlir/Dialect/ArmSVE/Transforms.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
|
|
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
|
|
#include "mlir/Dialect/X86Vector/Transforms.h"
|
|
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
namespace mlir {
|
|
#define GEN_PASS_DEF_CONVERTVECTORTOLLVMPASS
|
|
#include "mlir/Conversion/Passes.h.inc"
|
|
} // namespace mlir
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::vector;
|
|
|
|
namespace {
|
|
struct LowerVectorToLLVMPass
|
|
: public impl::ConvertVectorToLLVMPassBase<LowerVectorToLLVMPass> {
|
|
|
|
using Base::Base;
|
|
|
|
// Override explicitly to allow conditional dialect dependence.
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<LLVM::LLVMDialect>();
|
|
registry.insert<arith::ArithDialect>();
|
|
registry.insert<memref::MemRefDialect>();
|
|
if (armNeon)
|
|
registry.insert<arm_neon::ArmNeonDialect>();
|
|
if (armSVE)
|
|
registry.insert<arm_sve::ArmSVEDialect>();
|
|
if (armSME)
|
|
registry.insert<arm_sme::ArmSMEDialect>();
|
|
if (amx)
|
|
registry.insert<amx::AMXDialect>();
|
|
if (x86Vector)
|
|
registry.insert<x86vector::X86VectorDialect>();
|
|
}
|
|
void runOnOperation() override;
|
|
};
|
|
} // namespace
|
|
|
|
void LowerVectorToLLVMPass::runOnOperation() {
|
|
// Perform progressive lowering of operations on slices and
|
|
// all contraction operations. Also applies folding and DCE.
|
|
{
|
|
RewritePatternSet patterns(&getContext());
|
|
populateVectorToVectorCanonicalizationPatterns(patterns);
|
|
populateVectorBroadcastLoweringPatterns(patterns);
|
|
populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions());
|
|
populateVectorMaskOpLoweringPatterns(patterns);
|
|
populateVectorShapeCastLoweringPatterns(patterns);
|
|
populateVectorTransposeLoweringPatterns(patterns,
|
|
VectorTransformsOptions());
|
|
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
|
|
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
|
|
// Convert to the LLVM IR dialect.
|
|
LowerToLLVMOptions options(&getContext());
|
|
options.useOpaquePointers = useOpaquePointers;
|
|
LLVMTypeConverter converter(&getContext(), options);
|
|
RewritePatternSet patterns(&getContext());
|
|
populateVectorMaskMaterializationPatterns(patterns, force32BitVectorIndices);
|
|
populateVectorTransferLoweringPatterns(patterns);
|
|
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
|
|
populateVectorToLLVMConversionPatterns(
|
|
converter, patterns, reassociateFPReductions, force32BitVectorIndices);
|
|
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
|
|
|
|
// Architecture specific augmentations.
|
|
LLVMConversionTarget target(getContext());
|
|
target.addLegalDialect<arith::ArithDialect>();
|
|
target.addLegalDialect<memref::MemRefDialect>();
|
|
target.addLegalOp<UnrealizedConversionCastOp>();
|
|
arm_sme::ArmSMETypeConverter armSMEConverter(&getContext(), options);
|
|
|
|
if (armNeon) {
|
|
// TODO: we may or may not want to include in-dialect lowering to
|
|
// LLVM-compatible operations here. So far, all operations in the dialect
|
|
// can be translated to LLVM IR so there is no conversion necessary.
|
|
target.addLegalDialect<arm_neon::ArmNeonDialect>();
|
|
}
|
|
if (armSVE) {
|
|
configureArmSVELegalizeForExportTarget(target);
|
|
populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
|
|
}
|
|
if (armSME) {
|
|
configureArmSMELegalizeForExportTarget(target);
|
|
populateArmSMELegalizeForLLVMExportPatterns(armSMEConverter, patterns);
|
|
}
|
|
if (amx) {
|
|
configureAMXLegalizeForExportTarget(target);
|
|
populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
|
|
}
|
|
if (x86Vector) {
|
|
configureX86VectorLegalizeForExportTarget(target);
|
|
populateX86VectorLegalizeForLLVMExportPatterns(converter, patterns);
|
|
}
|
|
|
|
if (failed(
|
|
applyPartialConversion(getOperation(), target, std::move(patterns))))
|
|
signalPassFailure();
|
|
}
|