This revision adds support for generating pass utilities (such as options and statistics) that can be inferred from the TableGen definition. This removes boilerplate from each pass and makes it easier to stop relying on the pass registry for certain information (e.g. the pass argument). Differential Revision: https://reviews.llvm.org/D76659
198 lines
7.9 KiB
C++
198 lines
7.9 KiB
C++
//===- ConvertAVX512ToLLVM.cpp - Convert AVX512 to the LLVM dialect -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
|
|
|
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
|
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
|
|
#include "mlir/Dialect/AVX512/AVX512Dialect.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
#include "mlir/Dialect/Vector/VectorOps.h"
|
|
#include "mlir/EDSC/Intrinsics.h"
|
|
#include "mlir/IR/Module.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::edsc;
|
|
using namespace mlir::edsc::intrinsics;
|
|
using namespace mlir::vector;
|
|
using namespace mlir::avx512;
|
|
|
|
template <typename OpTy>
|
|
static Type getSrcVectorElementType(OpTy op) {
|
|
return op.src().getType().template cast<VectorType>().getElementType();
|
|
}
|
|
|
|
// TODO(ntv, zinenko): Code is currently copy-pasted and adapted from the code
|
|
// 1-1 LLVM conversion. It would better if it were properly exposed in core and
|
|
// reusable.
|
|
/// Basic lowering implementation for one-to-one rewriting from AVX512 Ops to
|
|
/// LLVM Dialect Ops. Convert the type of the result to an LLVM type, pass
|
|
/// operands as is, preserve attributes.
|
|
template <typename SourceOp, typename TargetOp>
|
|
static LogicalResult
|
|
matchAndRewriteOneToOne(const ConvertToLLVMPattern &lowering,
|
|
LLVMTypeConverter &typeConverter, Operation *op,
|
|
ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) {
|
|
unsigned numResults = op->getNumResults();
|
|
|
|
Type packedType;
|
|
if (numResults != 0) {
|
|
packedType = typeConverter.packFunctionResults(op->getResultTypes());
|
|
if (!packedType)
|
|
return failure();
|
|
}
|
|
|
|
auto newOp = rewriter.create<TargetOp>(op->getLoc(), packedType, operands,
|
|
op->getAttrs());
|
|
|
|
// If the operation produced 0 or 1 result, return them immediately.
|
|
if (numResults == 0)
|
|
return rewriter.eraseOp(op), success();
|
|
if (numResults == 1)
|
|
return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)),
|
|
success();
|
|
|
|
// Otherwise, it had been converted to an operation producing a structure.
|
|
// Extract individual results from the structure and return them as list.
|
|
SmallVector<Value, 4> results;
|
|
results.reserve(numResults);
|
|
for (unsigned i = 0; i < numResults; ++i) {
|
|
auto type = typeConverter.convertType(op->getResult(i).getType());
|
|
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
|
|
op->getLoc(), type, newOp.getOperation()->getResult(0),
|
|
rewriter.getI64ArrayAttr(i)));
|
|
}
|
|
rewriter.replaceOp(op, results);
|
|
return success();
|
|
}
|
|
|
|
namespace {
|
|
// TODO(ntv): Patterns are too verbose due to the fact that we have 1 op (e.g.
|
|
// MaskRndScaleOp) and different possible target ops. It would be better to take
|
|
// a Functor so that all these conversions become 1-liners.
|
|
struct MaskRndScaleOpPS512Conversion : public ConvertToLLVMPattern {
|
|
explicit MaskRndScaleOpPS512Conversion(MLIRContext *context,
|
|
LLVMTypeConverter &typeConverter)
|
|
: ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context,
|
|
typeConverter) {}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (!getSrcVectorElementType(cast<MaskRndScaleOp>(op)).isF32())
|
|
return failure();
|
|
return matchAndRewriteOneToOne<MaskRndScaleOp,
|
|
LLVM::x86_avx512_mask_rndscale_ps_512>(
|
|
*this, this->typeConverter, op, operands, rewriter);
|
|
}
|
|
};
|
|
|
|
struct MaskRndScaleOpPD512Conversion : public ConvertToLLVMPattern {
|
|
explicit MaskRndScaleOpPD512Conversion(MLIRContext *context,
|
|
LLVMTypeConverter &typeConverter)
|
|
: ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context,
|
|
typeConverter) {}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (!getSrcVectorElementType(cast<MaskRndScaleOp>(op)).isF64())
|
|
return failure();
|
|
return matchAndRewriteOneToOne<MaskRndScaleOp,
|
|
LLVM::x86_avx512_mask_rndscale_pd_512>(
|
|
*this, this->typeConverter, op, operands, rewriter);
|
|
}
|
|
};
|
|
|
|
struct ScaleFOpPS512Conversion : public ConvertToLLVMPattern {
|
|
explicit ScaleFOpPS512Conversion(MLIRContext *context,
|
|
LLVMTypeConverter &typeConverter)
|
|
: ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context,
|
|
typeConverter) {}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (!getSrcVectorElementType(cast<MaskScaleFOp>(op)).isF32())
|
|
return failure();
|
|
return matchAndRewriteOneToOne<MaskScaleFOp,
|
|
LLVM::x86_avx512_mask_scalef_ps_512>(
|
|
*this, this->typeConverter, op, operands, rewriter);
|
|
}
|
|
};
|
|
|
|
struct ScaleFOpPD512Conversion : public ConvertToLLVMPattern {
|
|
explicit ScaleFOpPD512Conversion(MLIRContext *context,
|
|
LLVMTypeConverter &typeConverter)
|
|
: ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context,
|
|
typeConverter) {}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (!getSrcVectorElementType(cast<MaskScaleFOp>(op)).isF64())
|
|
return failure();
|
|
return matchAndRewriteOneToOne<MaskScaleFOp,
|
|
LLVM::x86_avx512_mask_scalef_pd_512>(
|
|
*this, this->typeConverter, op, operands, rewriter);
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// Populate the given list with patterns that convert from AVX512 to LLVM.
|
|
void mlir::populateAVX512ToLLVMConversionPatterns(
|
|
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
|
MLIRContext *ctx = converter.getDialect()->getContext();
|
|
// clang-format off
|
|
patterns.insert<MaskRndScaleOpPS512Conversion,
|
|
MaskRndScaleOpPD512Conversion,
|
|
ScaleFOpPS512Conversion,
|
|
ScaleFOpPD512Conversion>(ctx, converter);
|
|
// clang-format on
|
|
}
|
|
|
|
namespace {
/// Module pass that lowers AVX512 (plus vector and standard) operations to
/// the LLVM dialect. The body of the conversion is in runOnModule().
class ConvertAVX512ToLLVMPass : public ModulePass<ConvertAVX512ToLLVMPass> {
public:
  // Pull in the utilities generated from the pass's TableGen definition. The
  // macro must be defined immediately before the include that consumes it.
#define GEN_PASS_ConvertAVX512ToLLVM
#include "mlir/Conversion/Passes.h.inc"

  void runOnModule() override;
};
} // namespace
|
|
|
|
void ConvertAVX512ToLLVMPass::runOnModule() {
|
|
// Convert to the LLVM IR dialect.
|
|
OwningRewritePatternList patterns;
|
|
LLVMTypeConverter converter(&getContext());
|
|
populateAVX512ToLLVMConversionPatterns(converter, patterns);
|
|
populateVectorToLLVMConversionPatterns(converter, patterns);
|
|
populateStdToLLVMConversionPatterns(converter, patterns);
|
|
|
|
ConversionTarget target(getContext());
|
|
target.addLegalDialect<LLVM::LLVMDialect>();
|
|
target.addLegalDialect<LLVM::LLVMAVX512Dialect>();
|
|
target.addIllegalDialect<avx512::AVX512Dialect>();
|
|
target.addDynamicallyLegalOp<FuncOp>(
|
|
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
|
|
if (failed(
|
|
applyPartialConversion(getModule(), target, patterns, &converter))) {
|
|
signalPassFailure();
|
|
}
|
|
}
|
|
|
|
/// Creates a new instance of the AVX512-to-LLVM conversion pass, returned
/// through its module-pass base class.
std::unique_ptr<OpPassBase<ModuleOp>> mlir::createConvertAVX512ToLLVMPass() {
  std::unique_ptr<OpPassBase<ModuleOp>> pass =
      std::make_unique<ConvertAVX512ToLLVMPass>();
  return pass;
}
|