Files
clang-p2996/flang/lib/Optimizer/CodeGen/CodeGen.cpp
Valentin Clement 044d5b5dd1 [fir] Add base of the FIR to LLVM IR pass
This patch adds the base of the FIR to LLVM IR Dialect conversion pass.
It currently can convert the following operations:
 - fir.global
 - fir.has_value
 - fir.address_of
 - fir.undefined

This patch is part of the effort to upstream the fir-dev branch. To keep
patches small, it does not cover all FIR operations; several follow-up patches
will convert the remaining operations.

Reviewed By: schweitz

Differential Revision: https://reviews.llvm.org/D112845
2021-10-29 23:21:43 +02:00

206 lines
7.5 KiB
C++

//===-- CodeGen.cpp -- bridge to lower to LLVM ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/CodeGen/CodeGen.h"
#include "PassDetail.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/ArrayRef.h"
#define DEBUG_TYPE "flang-codegen"
// fir::LLVMTypeConverter for converting to LLVM IR dialect types.
#include "TypeConverter.h"
namespace {
/// Common base for the FIR operation conversion patterns in this file.
///
/// Wraps `mlir::ConvertOpToLLVMPattern` and exposes the type converter as a
/// `fir::LLVMTypeConverter` so derived patterns can convert FIR types.
template <typename FromOp>
class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
public:
  /// Build the pattern on top of the FIR-aware type converter \p lowering.
  explicit FIROpConversion(fir::LLVMTypeConverter &lowering)
      : mlir::ConvertOpToLLVMPattern<FromOp>(lowering) {}

protected:
  /// Return the pattern's type converter viewed as a `fir::LLVMTypeConverter`.
  /// The constructor only accepts that converter type, which is what the
  /// unchecked downcast relies on.
  fir::LLVMTypeConverter &lowerTy() const {
    auto *converter =
        static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter());
    return *converter;
  }

  /// Convert \p ty through the FIR type converter.
  mlir::Type convertType(mlir::Type ty) const {
    return lowerTy().convertType(ty);
  }
};
} // namespace
namespace {
/// Rewrite `fir::AddrOfOp` as `mlir::LLVM::AddressOfOp`.
struct AddrOfOpConversion : public FIROpConversion<fir::AddrOfOp> {
  using FIROpConversion::FIROpConversion;

  /// Replace \p addr with an LLVM address-of operation that has the converted
  /// result type and refers to the root of the original symbol reference.
  mlir::LogicalResult
  matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Type resultTy = convertType(addr.getType());
    auto symbolName = addr.symbol().getRootReference().getValue();
    rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(addr, resultTy,
                                                         symbolName);
    return mlir::success();
  }
};
/// Rewrite `fir::HasValueOp` as `mlir::LLVM::ReturnOp`.
struct HasValueOpConversion : public FIROpConversion<fir::HasValueOp> {
  using FIROpConversion::FIROpConversion;

  /// Replace \p op with an LLVM return that carries the adaptor's operands.
  mlir::LogicalResult
  matchAndRewrite(fir::HasValueOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    auto convertedOperands = adaptor.getOperands();
    rewriter.replaceOpWithNewOp<mlir::LLVM::ReturnOp>(op, convertedOperands);
    return mlir::success();
  }
};
/// Rewrite `fir::GlobalOp` as `mlir::LLVM::GlobalOp`.
///
/// The initializer region of the FIR global is moved into the new LLVM global.
/// Inside that region, a `fir::InsertOnRangeOp` that covers the full range of
/// its sequence type with a constant value is folded into a dense
/// `mlir::arith::ConstantOp`.
struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> {
  using FIROpConversion::FIROpConversion;

  /// Create the LLVM global, move the initializer region over, fold
  /// full-range constant `insert_on_range` operations, and erase \p global.
  mlir::LogicalResult
  matchAndRewrite(fir::GlobalOp global, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    auto tyAttr = convertType(global.getType());
    // For a box-typed global, use the element type of the converted pointer
    // type as the type of the LLVM global.
    if (global.getType().isa<fir::BoxType>())
      tyAttr = tyAttr.cast<mlir::LLVM::LLVMPointerType>().getElementType();
    auto loc = global.getLoc();
    mlir::Attribute initAttr{};
    if (global.initVal())
      initAttr = global.initVal().getValue();
    auto linkage = convertLinkage(global.linkName());
    auto isConst = global.constant().hasValue();
    auto g = rewriter.create<mlir::LLVM::GlobalOp>(
        loc, tyAttr, isConst, linkage, global.sym_name(), initAttr);
    auto &gr = g.getInitializerRegion();
    rewriter.inlineRegionBefore(global.region(), gr, gr.end());
    if (!gr.empty()) {
      // Replace insert_on_range with a constant dense attribute if the
      // initialization is on the full range.
      auto insertOnRangeOps = gr.front().getOps<fir::InsertOnRangeOp>();
      for (auto insertOp : insertOnRangeOps) {
        if (!isFullRange(insertOp.coor(), insertOp.getType()))
          continue;
        // The inserted value must be an arith constant, possibly behind a
        // single fir.convert. getDefiningOp() returns null for values not
        // produced by an operation, so use the null-tolerant casts and leave
        // the operation untouched when no constant is found.
        auto *definingOp = insertOp.val().getDefiningOp();
        auto constant =
            mlir::dyn_cast_or_null<mlir::arith::ConstantOp>(definingOp);
        if (!constant) {
          auto convertOp = mlir::dyn_cast_or_null<fir::ConvertOp>(definingOp);
          if (!convertOp)
            continue;
          constant = mlir::dyn_cast_or_null<mlir::arith::ConstantOp>(
              convertOp.value().getDefiningOp());
          if (!constant)
            continue;
        }
        auto seqTyAttr = convertType(insertOp.getType());
        mlir::Type vecType = mlir::VectorType::get(
            insertOp.getType().getShape(), constant.getType());
        auto denseAttr = mlir::DenseElementsAttr::get(
            vecType.cast<mlir::ShapedType>(), constant.value());
        rewriter.setInsertionPointAfter(insertOp);
        rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>(
            insertOp, seqTyAttr, denseAttr);
      }
    }
    rewriter.eraseOp(global);
    return mlir::success();
  }

  /// Return true when \p indexes selects the whole of \p seqTy.
  ///
  /// \p indexes is read as consecutive (lower, upper) pairs, one pair per
  /// extent of \p seqTy. The range is full when every lower bound is 0 and
  /// every upper bound equals the matching extent minus one.
  bool isFullRange(mlir::ArrayAttr indexes, fir::SequenceType seqTy) const {
    auto extents = seqTy.getShape();
    // Require exactly two indices per extent. Comparing `size / 2` would let
    // an odd-sized array through and read past its end in the loop below.
    if (indexes.size() != 2 * extents.size())
      return false;
    for (unsigned i = 0; i < indexes.size(); i += 2) {
      if (indexes[i].cast<mlir::IntegerAttr>().getInt() != 0)
        return false;
      if (indexes[i + 1].cast<mlir::IntegerAttr>().getInt() !=
          extents[i / 2] - 1)
        return false;
    }
    return true;
  }

  /// Map the optional FIR linkage name to an LLVM dialect linkage.
  ///
  /// Recognized names are "internal", "linkonce", "common" and "weak"; an
  /// absent or unrecognized name yields external linkage.
  mlir::LLVM::Linkage convertLinkage(Optional<StringRef> optLinkage) const {
    if (optLinkage.hasValue()) {
      auto name = optLinkage.getValue();
      if (name == "internal")
        return mlir::LLVM::Linkage::Internal;
      if (name == "linkonce")
        return mlir::LLVM::Linkage::Linkonce;
      if (name == "common")
        return mlir::LLVM::Linkage::Common;
      if (name == "weak")
        return mlir::LLVM::Linkage::Weak;
    }
    return mlir::LLVM::Linkage::External;
  }
};
/// Rewrite `fir::UndefOp` as the LLVM IR dialect `undef` operation.
struct UndefOpConversion : public FIROpConversion<fir::UndefOp> {
  using FIROpConversion::FIROpConversion;

  /// Replace \p undef with an `mlir::LLVM::UndefOp` of the converted type.
  mlir::LogicalResult
  matchAndRewrite(fir::UndefOp undef, OpAdaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Type llvmTy = convertType(undef.getType());
    rewriter.replaceOpWithNewOp<mlir::LLVM::UndefOp>(undef, llvmTy);
    return mlir::success();
  }
};
} // namespace
namespace {
/// Pass that converts a module from the FIR dialect to the LLVM IR dialect.
///
/// The FIR conversion patterns defined in this file are applied together with
/// the Std-to-LLVM and Arithmetic-to-LLVM conversion patterns, as a full
/// conversion whose only legal targets are the LLVM dialect and the module
/// operation.
///
/// This pass is not complete yet. We are upstreaming it in small patches.
class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
public:
  /// Return the module this pass runs on.
  mlir::ModuleOp getModule() { return getOperation(); }

  /// Run the full conversion; emit an error and signal pass failure when the
  /// conversion does not succeed.
  void runOnOperation() override final {
    mlir::ModuleOp mod = getModule();
    auto *ctx = mod.getContext();
    fir::LLVMTypeConverter typeConverter{mod};

    // Only the LLVM dialect is legal after conversion. The module operation is
    // marked legal as well, which a full conversion requires.
    mlir::ConversionTarget target{*ctx};
    target.addLegalDialect<mlir::LLVM::LLVMDialect>();
    target.addLegalOp<mlir::ModuleOp>();

    // Gather the FIR patterns, then the Std and Arithmetic lowering patterns.
    mlir::OwningRewritePatternList pattern(ctx);
    pattern.insert<AddrOfOpConversion, HasValueOpConversion, GlobalOpConversion,
                   UndefOpConversion>(typeConverter);
    mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
    mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
                                                            pattern);

    if (mlir::succeeded(
            mlir::applyFullConversion(mod, target, std::move(pattern))))
      return;

    // The failure is reported against an unknown location.
    auto loc = mlir::UnknownLoc::get(ctx);
    mlir::emitError(loc, "error in converting to LLVM-IR dialect\n");
    signalPassFailure();
  }
};
} // namespace
/// Create a new instance of the FIR to LLVM IR dialect conversion pass.
std::unique_ptr<mlir::Pass> fir::createFIRToLLVMPass() {
  std::unique_ptr<mlir::Pass> pass = std::make_unique<FIRToLLVMLowering>();
  return pass;
}