Files
clang-p2996/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
Krzysztof Drewniak df852599f3 [mlir] Split up VectorToLLVM pass
Currently, the VectorToLLVM patterns are built into a library along
with the corresponding pass, which also pulls in all the
platform-specific vector dialects (like AMXDialect) to apply all the
vector to LLVM conversions.

This causes dependency bloat when writing libraries. For example, the
GPU-to-LLVM passes use the vector-to-LLVM patterns but do not need the
X86Vector dialect to be present at all.

This commit partitions the library into VectorToLLVM and
VectorToLLVMPass, where the latter pulls in all the other vector
transformations.

Reviewed By: nicolasvasilache, mehdi_amini

Differential Revision: https://reviews.llvm.org/D158287
2023-09-13 16:09:56 +00:00

129 lines
5.1 KiB
C++

//===- ConvertVectorToLLVMPass.cpp - Conversion from Vector to LLVM -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOLLVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
using namespace mlir::vector;
namespace {
/// Pass that lowers the vector dialect to the LLVM dialect. Optionally, it
/// also lowers the architecture-specific vector dialects selected through
/// the pass options (armNeon, armSVE, armSME, amx, x86Vector).
struct LowerVectorToLLVMPass
    : public impl::ConvertVectorToLLVMPassBase<LowerVectorToLLVMPass> {
  using Base::Base;

  /// Registers the dialects this pass depends on. This is overridden by hand
  /// so that each architecture-specific dialect is only pulled in when its
  /// corresponding option is enabled.
  void getDependentDialects(DialectRegistry &registry) const override {
    // Dialects that are always required, regardless of options.
    registry.insert<LLVM::LLVMDialect, arith::ArithDialect,
                    memref::MemRefDialect>();

    // Dialects gated on their pass option.
    if (armNeon) {
      registry.insert<arm_neon::ArmNeonDialect>();
    }
    if (armSVE) {
      registry.insert<arm_sve::ArmSVEDialect>();
    }
    if (armSME) {
      registry.insert<arm_sme::ArmSMEDialect>();
    }
    if (amx) {
      registry.insert<amx::AMXDialect>();
    }
    if (x86Vector) {
      registry.insert<x86vector::X86VectorDialect>();
    }
  }

  void runOnOperation() override;
};
} // namespace
/// Lowers vector dialect operations to the LLVM dialect in two phases.
///
/// Phase 1 is a greedy rewrite. It progressively lowers higher-level vector
/// ops (broadcast, contract, mask, shape_cast, transpose and rank-1 transfer
/// ops) into simpler vector ops, folding and removing dead ops along the way.
///
/// Phase 2 is a partial dialect conversion to LLVM. It uses the generic
/// vector-to-LLVM patterns plus the architecture-specific patterns enabled by
/// the pass options (armNeon, armSVE, armSME, amx, x86Vector). The pass
/// signals failure only if this conversion fails.
void LowerVectorToLLVMPass::runOnOperation() {
  // Perform progressive lowering of operations on slices and
  // all contraction operations. Also applies folding and DCE.
  {
    RewritePatternSet patterns(&getContext());
    populateVectorToVectorCanonicalizationPatterns(patterns);
    populateVectorBroadcastLoweringPatterns(patterns);
    populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions());
    populateVectorMaskOpLoweringPatterns(patterns);
    populateVectorShapeCastLoweringPatterns(patterns);
    populateVectorTransposeLoweringPatterns(patterns,
                                            VectorTransformsOptions());
    // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
    populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
    // Best-effort pre-lowering: the greedy driver's result is deliberately
    // ignored, so a non-converging rewrite does not fail the pass.
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

  // Convert to the LLVM IR dialect.
  LowerToLLVMOptions options(&getContext());
  options.useOpaquePointers = useOpaquePointers;
  LLVMTypeConverter converter(&getContext(), options);
  RewritePatternSet patterns(&getContext());
  populateVectorMaskMaterializationPatterns(patterns, force32BitVectorIndices);
  populateVectorTransferLoweringPatterns(patterns);
  // Register the matrix conversion patterns once. Adding them a second time
  // would only put identical, redundant patterns into the set.
  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
  populateVectorToLLVMConversionPatterns(
      converter, patterns, reassociateFPReductions, force32BitVectorIndices);

  // Generic legality: arith/memref ops and unrealized casts may remain after
  // this pass.
  LLVMConversionTarget target(getContext());
  target.addLegalDialect<arith::ArithDialect>();
  target.addLegalDialect<memref::MemRefDialect>();
  target.addLegalOp<UnrealizedConversionCastOp>();

  // Architecture specific augmentations.
  //
  // The ArmSME converter lives at function scope rather than inside the
  // `if (armSME)` branch. Patterns populated with a type converter refer to
  // it, so it must stay alive until applyPartialConversion below has run.
  arm_sme::ArmSMETypeConverter armSMEConverter(&getContext(), options);
  if (armNeon) {
    // TODO: we may or may not want to include in-dialect lowering to
    // LLVM-compatible operations here. So far, all operations in the dialect
    // can be translated to LLVM IR so there is no conversion necessary.
    target.addLegalDialect<arm_neon::ArmNeonDialect>();
  }
  if (armSVE) {
    configureArmSVELegalizeForExportTarget(target);
    populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
  }
  if (armSME) {
    configureArmSMELegalizeForExportTarget(target);
    populateArmSMELegalizeForLLVMExportPatterns(armSMEConverter, patterns);
  }
  if (amx) {
    configureAMXLegalizeForExportTarget(target);
    populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
  }
  if (x86Vector) {
    configureX86VectorLegalizeForExportTarget(target);
    populateX86VectorLegalizeForLLVMExportPatterns(converter, patterns);
  }

  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}