A separate AVX512 lowering pass does not compose well with the regular vector lowering pass. As a result, it risks code duplication and inconsistent lowering. This change removes the separate AVX512 lowering pass and makes it an "option" in the regular vector lowering pass (i.e., the vector dialect "augmented" with the AVX512 dialect). Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D92614
74 lines
2.8 KiB
C++
74 lines
2.8 KiB
C++
//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
|
|
|
|
#include "../PassDetail.h"
|
|
|
|
#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
|
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
|
#include "mlir/Dialect/AVX512/AVX512Dialect.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/Vector/VectorOps.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::vector;
|
|
|
|
namespace {
|
|
struct LowerVectorToLLVMPass
|
|
: public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
|
|
LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
|
|
this->reassociateFPReductions = options.reassociateFPReductions;
|
|
this->enableIndexOptimizations = options.enableIndexOptimizations;
|
|
this->enableAVX512 = options.enableAVX512;
|
|
}
|
|
void runOnOperation() override;
|
|
};
|
|
} // namespace
|
|
|
|
void LowerVectorToLLVMPass::runOnOperation() {
|
|
// Perform progressive lowering of operations on slices and
|
|
// all contraction operations. Also applies folding and DCE.
|
|
{
|
|
OwningRewritePatternList patterns;
|
|
populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
|
|
populateVectorSlicesLoweringPatterns(patterns, &getContext());
|
|
populateVectorContractLoweringPatterns(patterns, &getContext());
|
|
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
|
|
// Convert to the LLVM IR dialect.
|
|
LLVMTypeConverter converter(&getContext());
|
|
OwningRewritePatternList patterns;
|
|
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
|
|
populateVectorToLLVMConversionPatterns(
|
|
converter, patterns, reassociateFPReductions, enableIndexOptimizations);
|
|
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
|
|
populateStdToLLVMConversionPatterns(converter, patterns);
|
|
|
|
// Architecture specific augmentations.
|
|
LLVMConversionTarget target(getContext());
|
|
if (enableAVX512) {
|
|
target.addLegalDialect<LLVM::LLVMAVX512Dialect>();
|
|
target.addIllegalDialect<avx512::AVX512Dialect>();
|
|
populateAVX512ToLLVMConversionPatterns(converter, patterns);
|
|
}
|
|
|
|
if (failed(
|
|
applyPartialConversion(getOperation(), target, std::move(patterns))))
|
|
signalPassFailure();
|
|
}
|
|
|
|
/// Creates a module pass that lowers vector dialect operations to the LLVM IR
/// dialect, configured by `options` (FP-reduction reassociation, index
/// optimizations, and optional AVX512 lowering).
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<LowerVectorToLLVMPass>(options);
  return pass;
}
|