//===-- CodeGen.cpp -- bridge to lower to LLVM ----------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ // //===----------------------------------------------------------------------===// #include "flang/Optimizer/CodeGen/CodeGen.h" #include "PassDetail.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIRType.h" #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Pass/Pass.h" #include "llvm/ADT/ArrayRef.h" #define DEBUG_TYPE "flang-codegen" // fir::LLVMTypeConverter for converting to LLVM IR dialect types. 
#include "TypeConverter.h"

namespace {
/// FIR conversion pattern template.
///
/// Common base for the FIR-to-LLVM patterns in this file. It is constructed
/// from a fir::LLVMTypeConverter and gives derived patterns access to that
/// converter (`lowerTy()`) plus a `convertType` shortcut.
template <typename FromOp>
class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
public:
  explicit FIROpConversion(fir::LLVMTypeConverter &lowering)
      : mlir::ConvertOpToLLVMPattern<FromOp>(lowering) {}

protected:
  /// Convert `ty` with the FIR-aware LLVM type converter.
  mlir::Type convertType(mlir::Type ty) const {
    return lowerTy().convertType(ty);
  }

  /// The converter this pattern was built with. The static_cast is sound
  /// because the only constructor accepts a fir::LLVMTypeConverter.
  fir::LLVMTypeConverter &lowerTy() const {
    return *static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter());
  }
};
} // namespace

namespace {
/// Lower fir::AddrOfOp to an LLVM dialect AddressOfOp that refers to the same
/// root symbol, with the result type converted to its LLVM equivalent.
struct AddrOfOpConversion : public FIROpConversion<fir::AddrOfOp> {
  using FIROpConversion::FIROpConversion;

  mlir::LogicalResult
  matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    auto ty = convertType(addr.getType());
    rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(
        addr, ty, addr.symbol().getRootReference().getValue());
    return mlir::success();
  }
};

/// Lower fir::HasValueOp to an LLVM dialect ReturnOp that forwards the same
/// (already converted) operands. GlobalOpConversion moves a fir::GlobalOp's
/// region into the LLVM global's initializer region; HasValueOp is presumably
/// the terminator of that region (TODO confirm), which becomes the return.
struct HasValueOpConversion : public FIROpConversion<fir::HasValueOp> {
  using FIROpConversion::FIROpConversion;

  mlir::LogicalResult
  matchAndRewrite(fir::HasValueOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<mlir::LLVM::ReturnOp>(op,
                                                      adaptor.getOperands());
    return mlir::success();
  }
};

/// Lower fir::GlobalOp to an LLVM dialect GlobalOp.
///
/// The attribute initializer (if any), constness and linkage are carried
/// over. The FIR global's region is moved into the LLVM global's initializer
/// region, where its ops are converted by the other patterns. As a peephole,
/// a fir::InsertOnRangeOp that fills the entire sequence with one constant is
/// replaced by a single dense constant.
struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> {
  using FIROpConversion::FIROpConversion;

  mlir::LogicalResult
  matchAndRewrite(fir::GlobalOp global, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    auto tyAttr = convertType(global.getType());
    // A converted box type is an LLVM pointer; the global stores its pointee.
    if (global.getType().isa<fir::BoxType>())
      tyAttr = tyAttr.cast<mlir::LLVM::LLVMPointerType>().getElementType();
    auto loc = global.getLoc();
    mlir::Attribute initAttr{};
    if (global.initVal())
      initAttr = global.initVal().getValue();
    auto linkage = convertLinkage(global.linkName());
    auto isConst = global.constant().hasValue();
    auto g = rewriter.create<mlir::LLVM::GlobalOp>(
        loc, tyAttr, isConst, linkage, global.sym_name(), initAttr);
    auto &gr = g.getInitializerRegion();
    rewriter.inlineRegionBefore(global.region(), gr, gr.end());
    if (!gr.empty()) {
      // Replace insert_on_range with a constant dense attribute if the
      // initialization is on the full range.
      auto insertOnRangeOps = gr.front().getOps<fir::InsertOnRangeOp>();
      for (auto insertOp : insertOnRangeOps) {
        if (isFullRange(insertOp.coor(), insertOp.getType())) {
          auto seqTyAttr = convertType(insertOp.getType());
          // The fill value may have no defining op (block argument), so use
          // the null-tolerant cast.
          auto *op = insertOp.val().getDefiningOp();
          auto constant = mlir::dyn_cast_or_null<mlir::arith::ConstantOp>(op);
          if (!constant) {
            auto convertOp = mlir::dyn_cast_or_null<fir::ConvertOp>(op);
            if (!convertOp)
              continue;
            // Look through one fir.convert; anything other than an
            // arith.constant behind it is not handled by this peephole, so
            // leave the insert_on_range untouched instead of asserting.
            constant = mlir::dyn_cast_or_null<mlir::arith::ConstantOp>(
                convertOp.value().getDefiningOp());
            if (!constant)
              continue;
          }
          // NOTE(review): on the fir.convert path the dense attribute keeps
          // the constant's pre-conversion element type; confirm it matches
          // the element type of the converted sequence type.
          mlir::Type vecType = mlir::VectorType::get(
              insertOp.getType().getShape(), constant.getType());
          auto denseAttr = mlir::DenseElementsAttr::get(
              vecType.cast<mlir::ShapedType>(), constant.value());
          rewriter.setInsertionPointAfter(insertOp);
          rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>(
              insertOp, seqTyAttr, denseAttr);
        }
      }
    }
    rewriter.eraseOp(global);
    return mlir::success();
  }

  /// Return true if `indexes`, a flat list of [lower, upper] bound pairs
  /// (one pair per dimension of `seqTy`), covers every element of `seqTy`,
  /// i.e. every pair is exactly [0, extent - 1].
  bool isFullRange(mlir::ArrayAttr indexes, fir::SequenceType seqTy) const {
    auto extents = seqTy.getShape();
    // Require exactly two bounds per dimension; this also guarantees that
    // `indexes[i + 1]` below stays in range for odd-length lists.
    if (indexes.size() != 2 * extents.size())
      return false;
    for (unsigned i = 0; i < indexes.size(); i += 2) {
      if (indexes[i].cast<mlir::IntegerAttr>().getInt() != 0)
        return false;
      if (indexes[i + 1].cast<mlir::IntegerAttr>().getInt() !=
          extents[i / 2] - 1)
        return false;
    }
    return true;
  }

  /// Map a FIR linkage name to the LLVM dialect linkage. An absent or
  /// unrecognized name yields external linkage.
  mlir::LLVM::Linkage
  convertLinkage(llvm::Optional<llvm::StringRef> optLinkage) const {
    if (optLinkage.hasValue()) {
      auto name = optLinkage.getValue();
      if (name == "internal")
        return mlir::LLVM::Linkage::Internal;
      if (name == "linkonce")
        return mlir::LLVM::Linkage::Linkonce;
      if (name == "common")
        return mlir::LLVM::Linkage::Common;
      if (name == "weak")
        return mlir::LLVM::Linkage::Weak;
    }
    return mlir::LLVM::Linkage::External;
  }
};

/// Lower fir::UndefOp to an LLVM dialect UndefOp of the converted type.
struct UndefOpConversion : public FIROpConversion<fir::UndefOp> {
  using FIROpConversion::FIROpConversion;

  mlir::LogicalResult
  matchAndRewrite(fir::UndefOp undef, OpAdaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<mlir::LLVM::UndefOp>(
        undef, convertType(undef.getType()));
    return mlir::success();
  }
};
} // namespace

namespace {
/// Convert FIR dialect to LLVM dialect.
///
/// This pass lowers FIR dialect operations to the LLVM IR dialect. Std and
/// Arithmetic dialect operations left in the module are lowered in the same
/// run by the upstream MLIR patterns added next to the FIR patterns.
///
/// This pass is not complete yet. We are upstreaming it in small patches.
class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
public:
  mlir::ModuleOp getModule() { return getOperation(); }

  void runOnOperation() override final {
    auto *context = getModule().getContext();
    fir::LLVMTypeConverter typeConverter{getModule()};
    // Failure diagnostics are reported at an unknown location.
    auto loc = mlir::UnknownLoc::get(context);
    mlir::OwningRewritePatternList pattern(context);
    pattern.insert<AddrOfOpConversion, HasValueOpConversion, GlobalOpConversion,
                   UndefOpConversion>(typeConverter);
    mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
    mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
                                                            pattern);
    // A full conversion requires every op to end up legal: only LLVM dialect
    // ops are legal, plus the enclosing builtin module.
    mlir::ConversionTarget target{*context};
    target.addLegalDialect<mlir::LLVM::LLVMDialect>();

    // required NOPs for applying a full conversion
    target.addLegalOp<mlir::ModuleOp>();

    // apply the patterns
    if (mlir::failed(mlir::applyFullConversion(getModule(), target,
                                               std::move(pattern)))) {
      mlir::emitError(loc, "error in converting to LLVM-IR dialect\n");
      signalPassFailure();
    }
  }
};
} // namespace

/// Create the FIR to LLVM IR dialect lowering pass.
std::unique_ptr<mlir::Pass> fir::createFIRToLLVMPass() {
  return std::make_unique<FIRToLLVMLowering>();
}