Files
clang-p2996/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
Alex Zinenko 1b101038dc [mlir] Turn Linalg to LLVM into a partial conversion
Historically, Linalg To LLVM conversion subsumed numerous other conversions,
including (affine) loop lowerings to CFG and conversions from the Standard and
Vector dialects to the LLVM dialect. This was due to the insufficient support
for partial conversions in the infrastructure that essentially required
conversions that involve type change (in this case, !linalg.range to
!llvm.struct) to be performed in a single conversion sweep. This is no longer
the case so remove the subsumed conversions and run them as separate passes
when necessary.

Depends On D95317

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D96008
2021-02-05 14:31:19 +01:00

238 lines
8.9 KiB
C++

//===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ErrorHandling.h"
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::LLVM;
using namespace mlir::linalg;
// Shorthand aliases for EDSC builders of LLVM dialect operations. Each alias
// binds either `ValueBuilder` or `OperationBuilder` to one LLVM dialect op so
// that the op can be created with a function-call-like expression, e.g.
// `Value v = llvm_undef(type);`.
// NOTE(review): in the code visible in this file only `llvm_undef` and
// `llvm_insertvalue` are referenced (by RangeOpConversion); the remaining
// aliases look unused and could probably be dropped -- confirm before removing.
using llvm_add = ValueBuilder<LLVM::AddOp>;
using llvm_bitcast = ValueBuilder<LLVM::BitcastOp>;
using llvm_constant = ValueBuilder<LLVM::ConstantOp>;
using llvm_extractvalue = ValueBuilder<LLVM::ExtractValueOp>;
using llvm_gep = ValueBuilder<LLVM::GEPOp>;
using llvm_insertvalue = ValueBuilder<LLVM::InsertValueOp>;
using llvm_call = OperationBuilder<LLVM::CallOp>;
using llvm_icmp = ValueBuilder<LLVM::ICmpOp>;
using llvm_load = ValueBuilder<LLVM::LoadOp>;
using llvm_store = OperationBuilder<LLVM::StoreOp>;
using llvm_select = ValueBuilder<LLVM::SelectOp>;
using llvm_mul = ValueBuilder<LLVM::MulOp>;
using llvm_ptrtoint = ValueBuilder<LLVM::PtrToIntOp>;
using llvm_sub = ValueBuilder<LLVM::SubOp>;
using llvm_undef = ValueBuilder<LLVM::UndefOp>;
using llvm_urem = ValueBuilder<LLVM::URemOp>;
using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
using llvm_return = OperationBuilder<LLVM::ReturnOp>;
/// Returns the LLVM dialect pointer type whose pointee is the converted
/// element type of `containerType`.
///
/// `T` may be any type exposing `getElementType()`. The element type is run
/// through `lowering` before being wrapped into a pointer type.
template <typename T>
static Type getPtrToElementType(T containerType, LLVMTypeConverter &lowering) {
  Type convertedElementType =
      lowering.convertType(containerType.getElementType());
  return LLVMPointerType::get(convertedElementType);
}
/// Builds the LLVM dialect type that represents a linalg range descriptor.
///
/// The descriptor is a literal struct made of three 64-bit integers which
/// hold, in this order, the two range bounds and the step:
///
///   struct {
///     int64_t min;
///     int64_t max;
///     int64_t step;
///   };
///
/// The 64-bit integer type is obtained by converting the builtin `i64` type
/// through `converter`.
static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
  MLIRContext *ctx = t.getContext();
  Type i64 = converter.convertType(IntegerType::get(ctx, 64));
  return LLVMStructType::getLiteral(ctx, {i64, i64, i64});
}
namespace {
/// EDSC-compatible wrapper for MemRefDescriptor.
class BaseViewConversionHelper {
public:
BaseViewConversionHelper(Type type)
: d(MemRefDescriptor::undef(rewriter(), loc(), type)) {}
BaseViewConversionHelper(Value v) : d(v) {}
/// Wrappers around MemRefDescriptor that use EDSC builder and location.
Value allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
void setAllocatedPtr(Value v) { d.setAllocatedPtr(rewriter(), loc(), v); }
Value alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
void setAlignedPtr(Value v) { d.setAlignedPtr(rewriter(), loc(), v); }
Value offset() { return d.offset(rewriter(), loc()); }
void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); }
Value size(unsigned i) { return d.size(rewriter(), loc(), i); }
void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); }
void setConstantSize(unsigned i, int64_t v) {
d.setConstantSize(rewriter(), loc(), i, v);
}
Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); }
void setConstantStride(unsigned i, int64_t v) {
d.setConstantStride(rewriter(), loc(), i, v);
}
operator Value() { return d; }
private:
OpBuilder &rewriter() { return ScopedContext::getBuilderRef(); }
Location loc() { return ScopedContext::getLocation(); }
MemRefDescriptor d;
};
// Lowers linalg RangeOp to a freshly built range descriptor: an undefined
// struct of the converted range type into which the `min`, `max` and `step`
// operands are inserted at positions 0, 1 and 2 respectively.
class RangeOpConversion : public ConvertOpToLLVMPattern<RangeOp> {
public:
  using ConvertOpToLLVMPattern<RangeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(RangeOp rangeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type descriptorType = convertRangeType(
        rangeOp.getType().cast<RangeType>(), *getTypeConverter());

    // The EDSC value builders below pick up the builder and the location
    // from this scoped context.
    edsc::ScopedContext context(rewriter, rangeOp->getLoc());

    // Fill the aggregate field by field, in declaration order.
    RangeOpAdaptor adaptor(operands);
    Value fields[] = {adaptor.min(), adaptor.max(), adaptor.step()};
    Value descriptor = llvm_undef(descriptorType);
    int64_t position = 0;
    for (Value field : fields) {
      descriptor = llvm_insertvalue(descriptor, field,
                                    rewriter.getI64ArrayAttr(position));
      ++position;
    }

    rewriter.replaceOp(rangeOp, descriptor);
    return success();
  }
};
// Lowers linalg ReshapeOp to a new view descriptor of the result rank.
// The pattern only applies when the result MemRef type has a static shape and
// its strides can be computed and are all static; otherwise it fails to match.
// Pointers and offset are copied from the source descriptor, while sizes and
// strides are set as constants taken from the result type.
class ReshapeOpConversion : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType resultType = reshapeOp.getResultType();
    if (!resultType.hasStaticShape())
      return failure();

    // Bail out if the layout cannot be expressed as strides, or if any of
    // the strides is only known dynamically.
    int64_t resultOffset;
    SmallVector<int64_t, 4> resultStrides;
    if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
      return failure();
    for (int64_t stride : resultStrides)
      if (ShapedType::isDynamicStrideOrOffset(stride))
        return failure();

    edsc::ScopedContext context(rewriter, reshapeOp->getLoc());
    ReshapeOpAdaptor adaptor(operands);
    BaseViewConversionHelper sourceDesc(adaptor.src());
    BaseViewConversionHelper resultDesc(typeConverter->convertType(resultType));

    // The reshaped view aliases the source buffer: same pointers and offset.
    resultDesc.setAllocatedPtr(sourceDesc.allocatedPtr());
    resultDesc.setAlignedPtr(sourceDesc.alignedPtr());
    resultDesc.setOffset(sourceDesc.offset());

    // Sizes and strides are fully static, so materialize them as constants.
    ArrayRef<int64_t> resultShape = resultType.getShape();
    for (unsigned dim = 0, rank = resultShape.size(); dim < rank; ++dim)
      resultDesc.setConstantSize(dim, resultShape[dim]);
    for (unsigned dim = 0, rank = resultStrides.size(); dim < rank; ++dim)
      resultDesc.setConstantStride(dim, resultStrides[dim]);

    rewriter.replaceOp(reshapeOp, {resultDesc});
    return success();
  }
};
// Lowers linalg::YieldOp to an LLVM::ReturnOp that forwards the (already
// converted) operands of the yield unchanged.
class YieldOpConversion : public ConvertOpToLLVMPattern<linalg::YieldOp> {
public:
  using ConvertOpToLLVMPattern<linalg::YieldOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(linalg::YieldOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // This pattern never fails: every yield is replaced by a return.
    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
    return success();
  }
};
} // namespace
/// Adds the Linalg-to-LLVM conversion patterns (for RangeOp, ReshapeOp and
/// linalg::YieldOp) to `patterns`, and registers on `converter` the type
/// conversion that maps the linalg range type to its LLVM descriptor struct.
void mlir::populateLinalgToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  patterns.insert<RangeOpConversion, ReshapeOpConversion, YieldOpConversion>(
      converter);

  // Teach the converter how to translate the linalg-specific range type. The
  // callback refers back to `converter` by reference.
  converter.addConversion([&converter](RangeType rangeType) {
    return convertRangeType(rangeType, converter);
  });
}
namespace {
/// Pass that converts Linalg operations to the LLVM dialect. It derives from
/// the ConvertLinalgToLLVMBase CRTP base and only overrides runOnOperation,
/// which is defined out of line below.
struct ConvertLinalgToLLVMPass
: public ConvertLinalgToLLVMBase<ConvertLinalgToLLVMPass> {
/// Runs the partial conversion on the operation this pass is scheduled on.
void runOnOperation() override;
};
} // namespace
void ConvertLinalgToLLVMPass::runOnOperation() {
auto module = getOperation();
// Convert to the LLVM IR dialect using the converter defined above.
OwningRewritePatternList patterns;
LLVMTypeConverter converter(&getContext());
populateLinalgToLLVMConversionPatterns(converter, patterns);
LLVMConversionTarget target(getContext());
target.addIllegalOp<RangeOp>();
target.addLegalOp<ModuleOp, ModuleTerminatorOp, LLVM::DialectCastOp>();
if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();
}
/// Factory for the Linalg-to-LLVM conversion pass; returns a newly allocated
/// ConvertLinalgToLLVMPass owned by the caller.
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertLinalgToLLVMPass() {
  return std::make_unique<ConvertLinalgToLLVMPass>();
}