//===- ConvertAVX512ToLLVM.cpp - Convert AVX512 to the LLVM dialect -------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/AVX512/AVX512Dialect.h" #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/EDSC/Intrinsics.h" #include "mlir/IR/Module.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" using namespace mlir; using namespace mlir::edsc; using namespace mlir::edsc::intrinsics; using namespace mlir::vector; using namespace mlir::avx512; template static Type getSrcVectorElementType(OpTy op) { return op.src().getType().template cast().getElementType(); } // TODO(ntv, zinenko): Code is currently copy-pasted and adapted from the code // 1-1 LLVM conversion. It would better if it were properly exposed in core and // reusable. /// Basic lowering implementation for one-to-one rewriting from AVX512 Ops to /// LLVM Dialect Ops. Convert the type of the result to an LLVM type, pass /// operands as is, preserve attributes. template static LogicalResult matchAndRewriteOneToOne(const ConvertToLLVMPattern &lowering, LLVMTypeConverter &typeConverter, Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) { unsigned numResults = op->getNumResults(); Type packedType; if (numResults != 0) { packedType = typeConverter.packFunctionResults(op->getResultTypes()); if (!packedType) return failure(); } auto newOp = rewriter.create(op->getLoc(), packedType, operands, op->getAttrs()); // If the operation produced 0 or 1 result, return them immediately. if (numResults == 0) return rewriter.eraseOp(op), success(); if (numResults == 1) return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)), success(); // Otherwise, it had been converted to an operation producing a structure. // Extract individual results from the structure and return them as list. SmallVector results; results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { auto type = typeConverter.convertType(op->getResult(i).getType()); results.push_back(rewriter.create( op->getLoc(), type, newOp.getOperation()->getResult(0), rewriter.getI64ArrayAttr(i))); } rewriter.replaceOp(op, results); return success(); } namespace { // TODO(ntv): Patterns are too verbose due to the fact that we have 1 op (e.g. // MaskRndScaleOp) and different possible target ops. It would be better to take // a Functor so that all these conversions become 1-liners. struct MaskRndScaleOpPS512Conversion : public ConvertToLLVMPattern { explicit MaskRndScaleOpPS512Conversion(MLIRContext *context, LLVMTypeConverter &typeConverter) : ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context, typeConverter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (!getSrcVectorElementType(cast(op)).isF32()) return failure(); return matchAndRewriteOneToOne( *this, this->typeConverter, op, operands, rewriter); } }; struct MaskRndScaleOpPD512Conversion : public ConvertToLLVMPattern { explicit MaskRndScaleOpPD512Conversion(MLIRContext *context, LLVMTypeConverter &typeConverter) : ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context, typeConverter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (!getSrcVectorElementType(cast(op)).isF64()) return failure(); return matchAndRewriteOneToOne( *this, this->typeConverter, op, operands, rewriter); } }; struct ScaleFOpPS512Conversion : public ConvertToLLVMPattern { explicit ScaleFOpPS512Conversion(MLIRContext *context, LLVMTypeConverter &typeConverter) : ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context, typeConverter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (!getSrcVectorElementType(cast(op)).isF32()) return failure(); return matchAndRewriteOneToOne( *this, this->typeConverter, op, operands, rewriter); } }; struct ScaleFOpPD512Conversion : public ConvertToLLVMPattern { explicit ScaleFOpPD512Conversion(MLIRContext *context, LLVMTypeConverter &typeConverter) : ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context, typeConverter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (!getSrcVectorElementType(cast(op)).isF64()) return failure(); return matchAndRewriteOneToOne( *this, this->typeConverter, op, operands, rewriter); } }; } // namespace /// Populate the given list with patterns that convert from AVX512 to LLVM. void mlir::populateAVX512ToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { MLIRContext *ctx = converter.getDialect()->getContext(); // clang-format off patterns.insert(ctx, converter); // clang-format on } namespace { struct ConvertAVX512ToLLVMPass : public ModulePass { /// Include the generated pass utilities. #define GEN_PASS_ConvertAVX512ToLLVM #include "mlir/Conversion/Passes.h.inc" void runOnModule() override; }; } // namespace void ConvertAVX512ToLLVMPass::runOnModule() { // Convert to the LLVM IR dialect. OwningRewritePatternList patterns; LLVMTypeConverter converter(&getContext()); populateAVX512ToLLVMConversionPatterns(converter, patterns); populateVectorToLLVMConversionPatterns(converter, patterns); populateStdToLLVMConversionPatterns(converter, patterns); ConversionTarget target(getContext()); target.addLegalDialect(); target.addLegalDialect(); target.addIllegalDialect(); target.addDynamicallyLegalOp( [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); if (failed( applyPartialConversion(getModule(), target, patterns, &converter))) { signalPassFailure(); } } std::unique_ptr> mlir::createConvertAVX512ToLLVMPass() { return std::make_unique(); }