This revision adds support for generating utilities for passes, such as options/statistics/etc., that can be inferred from the tablegen definition. This removes additional boilerplate from the pass, and also makes it easier to remove the reliance on the pass registry to provide certain things (e.g., the pass argument). Differential Revision: https://reviews.llvm.org/D76659
594 lines · 23 KiB · C++
//===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
|
|
|
|
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
|
|
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
|
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
|
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
|
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
|
#include "mlir/Dialect/Linalg/Passes.h"
|
|
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
|
|
#include "mlir/IR/AffineExpr.h"
|
|
#include "mlir/IR/AffineMap.h"
|
|
#include "mlir/IR/Attributes.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/MLIRContext.h"
|
|
#include "mlir/IR/Module.h"
|
|
#include "mlir/IR/Operation.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/StandardTypes.h"
|
|
#include "mlir/IR/Types.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "mlir/Transforms/Passes.h"
|
|
#include "llvm/ADT/SetVector.h"
|
|
#include "llvm/IR/DerivedTypes.h"
|
|
#include "llvm/IR/Module.h"
|
|
#include "llvm/IR/Type.h"
|
|
#include "llvm/Support/Allocator.h"
|
|
#include "llvm/Support/ErrorHandling.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::edsc;
|
|
using namespace mlir::edsc::intrinsics;
|
|
using namespace mlir::LLVM;
|
|
using namespace mlir::linalg;
|
|
|
|
using llvm_add = ValueBuilder<LLVM::AddOp>;
|
|
using llvm_bitcast = ValueBuilder<LLVM::BitcastOp>;
|
|
using llvm_constant = ValueBuilder<LLVM::ConstantOp>;
|
|
using llvm_extractvalue = ValueBuilder<LLVM::ExtractValueOp>;
|
|
using llvm_gep = ValueBuilder<LLVM::GEPOp>;
|
|
using llvm_insertvalue = ValueBuilder<LLVM::InsertValueOp>;
|
|
using llvm_call = OperationBuilder<LLVM::CallOp>;
|
|
using llvm_icmp = ValueBuilder<LLVM::ICmpOp>;
|
|
using llvm_load = ValueBuilder<LLVM::LoadOp>;
|
|
using llvm_store = OperationBuilder<LLVM::StoreOp>;
|
|
using llvm_select = ValueBuilder<LLVM::SelectOp>;
|
|
using llvm_mul = ValueBuilder<LLVM::MulOp>;
|
|
using llvm_ptrtoint = ValueBuilder<LLVM::PtrToIntOp>;
|
|
using llvm_sub = ValueBuilder<LLVM::SubOp>;
|
|
using llvm_undef = ValueBuilder<LLVM::UndefOp>;
|
|
using llvm_urem = ValueBuilder<LLVM::URemOp>;
|
|
using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
|
|
using llvm_return = OperationBuilder<LLVM::ReturnOp>;
|
|
|
|
template <typename T>
|
|
static LLVMType getPtrToElementType(T containerType,
|
|
LLVMTypeConverter &lowering) {
|
|
return lowering.convertType(containerType.getElementType())
|
|
.template cast<LLVMType>()
|
|
.getPointerTo();
|
|
}
|
|
|
|
/// Convert the given range descriptor type to the LLVMIR dialect.
|
|
/// Range descriptor contains the range bounds and the step as 64-bit integers.
|
|
///
|
|
/// struct {
|
|
/// int64_t min;
|
|
/// int64_t max;
|
|
/// int64_t step;
|
|
/// };
|
|
static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
|
|
auto *context = t.getContext();
|
|
auto int64Ty = converter.convertType(IntegerType::get(64, context))
|
|
.cast<LLVM::LLVMType>();
|
|
return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
|
|
}
|
|
|
|
namespace {
|
|
/// EDSC-compatible wrapper for MemRefDescriptor.
|
|
class BaseViewConversionHelper {
|
|
public:
|
|
BaseViewConversionHelper(Type type)
|
|
: d(MemRefDescriptor::undef(rewriter(), loc(), type)) {}
|
|
|
|
BaseViewConversionHelper(Value v) : d(v) {}
|
|
|
|
/// Wrappers around MemRefDescriptor that use EDSC builder and location.
|
|
Value allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
|
|
void setAllocatedPtr(Value v) { d.setAllocatedPtr(rewriter(), loc(), v); }
|
|
Value alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
|
|
void setAlignedPtr(Value v) { d.setAlignedPtr(rewriter(), loc(), v); }
|
|
Value offset() { return d.offset(rewriter(), loc()); }
|
|
void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); }
|
|
Value size(unsigned i) { return d.size(rewriter(), loc(), i); }
|
|
void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); }
|
|
void setConstantSize(unsigned i, int64_t v) {
|
|
d.setConstantSize(rewriter(), loc(), i, v);
|
|
}
|
|
Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
|
|
void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); }
|
|
void setConstantStride(unsigned i, int64_t v) {
|
|
d.setConstantStride(rewriter(), loc(), i, v);
|
|
}
|
|
|
|
operator Value() { return d; }
|
|
|
|
private:
|
|
OpBuilder &rewriter() { return ScopedContext::getBuilder(); }
|
|
Location loc() { return ScopedContext::getLocation(); }
|
|
|
|
MemRefDescriptor d;
|
|
};
|
|
|
|
// RangeOp creates a new range descriptor.
|
|
class RangeOpConversion : public ConvertToLLVMPattern {
|
|
public:
|
|
explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
|
|
: ConvertToLLVMPattern(RangeOp::getOperationName(), context, lowering_) {}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto rangeOp = cast<RangeOp>(op);
|
|
auto rangeDescriptorTy =
|
|
convertRangeType(rangeOp.getType().cast<RangeType>(), typeConverter);
|
|
|
|
edsc::ScopedContext context(rewriter, op->getLoc());
|
|
|
|
// Fill in an aggregate value of the descriptor.
|
|
RangeOpOperandAdaptor adaptor(operands);
|
|
Value desc = llvm_undef(rangeDescriptorTy);
|
|
desc = llvm_insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0));
|
|
desc = llvm_insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1));
|
|
desc = llvm_insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2));
|
|
rewriter.replaceOp(op, desc);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// ReshapeOp creates a new view descriptor of the proper rank.
|
|
// For now, the only conversion supported is for target MemRef with static sizes
|
|
// and strides.
|
|
class ReshapeOpConversion : public ConvertToLLVMPattern {
|
|
public:
|
|
explicit ReshapeOpConversion(MLIRContext *context,
|
|
LLVMTypeConverter &lowering_)
|
|
: ConvertToLLVMPattern(ReshapeOp::getOperationName(), context,
|
|
lowering_) {}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto reshapeOp = cast<ReshapeOp>(op);
|
|
MemRefType dstType = reshapeOp.getResult().getType().cast<MemRefType>();
|
|
|
|
if (!dstType.hasStaticShape())
|
|
return failure();
|
|
|
|
int64_t offset;
|
|
SmallVector<int64_t, 4> strides;
|
|
auto res = getStridesAndOffset(dstType, strides, offset);
|
|
if (failed(res) || llvm::any_of(strides, [](int64_t val) {
|
|
return ShapedType::isDynamicStrideOrOffset(val);
|
|
}))
|
|
return failure();
|
|
|
|
edsc::ScopedContext context(rewriter, op->getLoc());
|
|
ReshapeOpOperandAdaptor adaptor(operands);
|
|
BaseViewConversionHelper baseDesc(adaptor.view());
|
|
BaseViewConversionHelper desc(typeConverter.convertType(dstType));
|
|
desc.setAllocatedPtr(baseDesc.allocatedPtr());
|
|
desc.setAlignedPtr(baseDesc.alignedPtr());
|
|
desc.setOffset(baseDesc.offset());
|
|
for (auto en : llvm::enumerate(dstType.getShape()))
|
|
desc.setConstantSize(en.index(), en.value());
|
|
for (auto en : llvm::enumerate(strides))
|
|
desc.setConstantStride(en.index(), en.value());
|
|
rewriter.replaceOp(op, {desc});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Conversion pattern that transforms a linalg.slice op into:
|
|
/// 1. An "undef" value for the ViewDescriptor.
|
|
/// 2. Updates to the ViewDescriptor to introduce the data ptr, offset, size
|
|
/// and stride corresponding to the region of memory within the bounds of
|
|
/// the parent view.
|
|
/// The linalg.slice op is replaced by the alloca'ed pointer.
|
|
class SliceOpConversion : public ConvertToLLVMPattern {
|
|
public:
|
|
explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
|
|
: ConvertToLLVMPattern(SliceOp::getOperationName(), context, lowering_) {}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
edsc::ScopedContext context(rewriter, op->getLoc());
|
|
SliceOpOperandAdaptor adaptor(operands);
|
|
BaseViewConversionHelper baseDesc(adaptor.view());
|
|
|
|
auto sliceOp = cast<SliceOp>(op);
|
|
auto memRefType = sliceOp.getBaseViewType();
|
|
auto int64Ty = typeConverter.convertType(rewriter.getIntegerType(64))
|
|
.cast<LLVM::LLVMType>();
|
|
|
|
BaseViewConversionHelper desc(
|
|
typeConverter.convertType(sliceOp.getShapedType()));
|
|
|
|
// TODO(ntv): extract sizes and emit asserts.
|
|
SmallVector<Value, 4> strides(memRefType.getRank());
|
|
for (int i = 0, e = memRefType.getRank(); i < e; ++i)
|
|
strides[i] = baseDesc.stride(i);
|
|
|
|
auto pos = [&rewriter](ArrayRef<int64_t> values) {
|
|
return rewriter.getI64ArrayAttr(values);
|
|
};
|
|
|
|
// Compute base offset.
|
|
Value baseOffset = baseDesc.offset();
|
|
for (int i = 0, e = memRefType.getRank(); i < e; ++i) {
|
|
Value indexing = adaptor.indexings()[i];
|
|
Value min = indexing;
|
|
if (sliceOp.indexing(i).getType().isa<RangeType>())
|
|
min = llvm_extractvalue(int64Ty, indexing, pos(0));
|
|
baseOffset = llvm_add(baseOffset, llvm_mul(min, strides[i]));
|
|
}
|
|
|
|
// Insert the base and aligned pointers.
|
|
desc.setAllocatedPtr(baseDesc.allocatedPtr());
|
|
desc.setAlignedPtr(baseDesc.alignedPtr());
|
|
|
|
// Insert base offset.
|
|
desc.setOffset(baseOffset);
|
|
|
|
// Corner case, no sizes or strides: early return the descriptor.
|
|
if (sliceOp.getShapedType().getRank() == 0)
|
|
return rewriter.replaceOp(op, {desc}), success();
|
|
|
|
Value zero = llvm_constant(
|
|
int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
|
|
// Compute and insert view sizes (max - min along the range) and strides.
|
|
// Skip the non-range operands as they will be projected away from the view.
|
|
int numNewDims = 0;
|
|
for (auto en : llvm::enumerate(sliceOp.indexings())) {
|
|
Value indexing = en.value();
|
|
if (indexing.getType().isa<RangeType>()) {
|
|
int rank = en.index();
|
|
Value rangeDescriptor = adaptor.indexings()[rank];
|
|
Value min = llvm_extractvalue(int64Ty, rangeDescriptor, pos(0));
|
|
Value max = llvm_extractvalue(int64Ty, rangeDescriptor, pos(1));
|
|
Value step = llvm_extractvalue(int64Ty, rangeDescriptor, pos(2));
|
|
Value baseSize = baseDesc.size(rank);
|
|
|
|
// Bound upper by base view upper bound.
|
|
max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
|
|
baseSize);
|
|
Value size = llvm_sub(max, min);
|
|
// Bound lower by zero.
|
|
size =
|
|
llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
|
|
Value stride = llvm_mul(strides[rank], step);
|
|
desc.setSize(numNewDims, size);
|
|
desc.setStride(numNewDims, stride);
|
|
++numNewDims;
|
|
}
|
|
}
|
|
|
|
rewriter.replaceOp(op, {desc});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Conversion pattern that transforms a linalg.transpose op into:
|
|
/// 1. A function entry `alloca` operation to allocate a ViewDescriptor.
|
|
/// 2. A load of the ViewDescriptor from the pointer allocated in 1.
|
|
/// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
|
|
/// and stride. Size and stride are permutations of the original values.
|
|
/// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
|
|
/// The linalg.transpose op is replaced by the alloca'ed pointer.
|
|
class TransposeOpConversion : public ConvertToLLVMPattern {
|
|
public:
|
|
explicit TransposeOpConversion(MLIRContext *context,
|
|
LLVMTypeConverter &lowering_)
|
|
: ConvertToLLVMPattern(TransposeOp::getOperationName(), context,
|
|
lowering_) {}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
// Initialize the common boilerplate and alloca at the top of the FuncOp.
|
|
edsc::ScopedContext context(rewriter, op->getLoc());
|
|
TransposeOpOperandAdaptor adaptor(operands);
|
|
BaseViewConversionHelper baseDesc(adaptor.view());
|
|
|
|
auto transposeOp = cast<TransposeOp>(op);
|
|
// No permutation, early exit.
|
|
if (transposeOp.permutation().isIdentity())
|
|
return rewriter.replaceOp(op, {baseDesc}), success();
|
|
|
|
BaseViewConversionHelper desc(
|
|
typeConverter.convertType(transposeOp.getShapedType()));
|
|
|
|
// Copy the base and aligned pointers from the old descriptor to the new
|
|
// one.
|
|
desc.setAllocatedPtr(baseDesc.allocatedPtr());
|
|
desc.setAlignedPtr(baseDesc.alignedPtr());
|
|
|
|
// Copy the offset pointer from the old descriptor to the new one.
|
|
desc.setOffset(baseDesc.offset());
|
|
|
|
// Iterate over the dimensions and apply size/stride permutation.
|
|
for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) {
|
|
int sourcePos = en.index();
|
|
int targetPos = en.value().cast<AffineDimExpr>().getPosition();
|
|
desc.setSize(targetPos, baseDesc.size(sourcePos));
|
|
desc.setStride(targetPos, baseDesc.stride(sourcePos));
|
|
}
|
|
|
|
rewriter.replaceOp(op, {desc});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// YieldOp produces an LLVM::ReturnOp carrying the already-converted operands.
class YieldOpConversion : public ConvertToLLVMPattern {
public:
  explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : ConvertToLLVMPattern(YieldOp::getOperationName(), context, lowering_) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Direct 1:1 replacement; `operands` are the remapped yield values.
    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
    return success();
  }
};
|
|
} // namespace
|
|
|
|
template <typename LinalgOp>
|
|
static SmallVector<Type, 4> ExtractOperandTypes(Operation *op) {
|
|
return SmallVector<Type, 4>{op->getOperandTypes()};
|
|
}
|
|
|
|
template <>
|
|
SmallVector<Type, 4> ExtractOperandTypes<IndexedGenericOp>(Operation *op) {
|
|
auto ctx = op->getContext();
|
|
auto indexedGenericOp = cast<IndexedGenericOp>(op);
|
|
auto numLoops = indexedGenericOp.getNumLoops();
|
|
|
|
SmallVector<Type, 4> result;
|
|
result.reserve(numLoops + op->getNumOperands());
|
|
for (unsigned i = 0; i < numLoops; ++i) {
|
|
result.push_back(IndexType::get(ctx));
|
|
}
|
|
for (auto type : op->getOperandTypes()) {
|
|
result.push_back(type);
|
|
}
|
|
return result;
|
|
}
|
|
|
|
// Get a SymbolRefAttr containing the library function name for the LinalgOp.
|
|
// If the library function does not exist, insert a declaration.
|
|
template <typename LinalgOp>
|
|
static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
|
|
PatternRewriter &rewriter) {
|
|
auto linalgOp = cast<LinalgOp>(op);
|
|
auto fnName = linalgOp.getLibraryCallName();
|
|
if (fnName.empty()) {
|
|
op->emitWarning("No library call defined for: ") << *op;
|
|
return {};
|
|
}
|
|
|
|
// fnName is a dynamic std::String, unique it via a SymbolRefAttr.
|
|
FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
|
|
auto module = op->getParentOfType<ModuleOp>();
|
|
if (module.lookupSymbol(fnName)) {
|
|
return fnNameAttr;
|
|
}
|
|
|
|
SmallVector<Type, 4> inputTypes(ExtractOperandTypes<LinalgOp>(op));
|
|
assert(op->getNumResults() == 0 &&
|
|
"Library call for linalg operation can be generated only for ops that "
|
|
"have void return types");
|
|
auto libFnType = FunctionType::get(inputTypes, {}, rewriter.getContext());
|
|
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
// Insert before module terminator.
|
|
rewriter.setInsertionPoint(module.getBody(),
|
|
std::prev(module.getBody()->end()));
|
|
rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType,
|
|
ArrayRef<NamedAttribute>{});
|
|
return fnNameAttr;
|
|
}
|
|
|
|
namespace {
|
|
|
|
// LinalgOpConversion<LinalgOp> creates a new call to the
|
|
// `LinalgOp::getLibraryCallName()` function.
|
|
// The implementation of the function can be either in the same module or in an
|
|
// externally linked library.
|
|
template <typename LinalgOp>
|
|
class LinalgOpConversion : public OpRewritePattern<LinalgOp> {
|
|
public:
|
|
using OpRewritePattern<LinalgOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(LinalgOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
auto libraryCallName = getLibraryCallSymbolRef<LinalgOp>(op, rewriter);
|
|
if (!libraryCallName)
|
|
return failure();
|
|
|
|
rewriter.replaceOpWithNewOp<mlir::CallOp>(
|
|
op, libraryCallName.getValue(), ArrayRef<Type>{}, op.getOperands());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Conversion pattern specialization for CopyOp. This kicks in when both input
|
|
/// and output permutations are left unspecified or are the identity.
|
|
template <> class LinalgOpConversion<CopyOp> : public OpRewritePattern<CopyOp> {
|
|
public:
|
|
using OpRewritePattern<CopyOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(CopyOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
auto inputPerm = op.inputPermutation();
|
|
if (inputPerm.hasValue() && !inputPerm->isIdentity())
|
|
return failure();
|
|
auto outputPerm = op.outputPermutation();
|
|
if (outputPerm.hasValue() && !outputPerm->isIdentity())
|
|
return failure();
|
|
|
|
auto libraryCallName = getLibraryCallSymbolRef<CopyOp>(op, rewriter);
|
|
if (!libraryCallName)
|
|
return failure();
|
|
|
|
rewriter.replaceOpWithNewOp<mlir::CallOp>(
|
|
op, libraryCallName.getValue(), ArrayRef<Type>{}, op.getOperands());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Conversion pattern specialization for IndexedGenericOp.
|
|
template <>
|
|
class LinalgOpConversion<IndexedGenericOp>
|
|
: public OpRewritePattern<IndexedGenericOp> {
|
|
public:
|
|
using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(IndexedGenericOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
auto libraryCallName =
|
|
getLibraryCallSymbolRef<IndexedGenericOp>(op, rewriter);
|
|
if (!libraryCallName)
|
|
return failure();
|
|
|
|
// TODO(pifon, ntv): Use induction variables values instead of zeros, when
|
|
// IndexedGenericOp is tiled.
|
|
auto zero = rewriter.create<mlir::ConstantOp>(
|
|
op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
|
|
auto indexedGenericOp = cast<IndexedGenericOp>(op);
|
|
auto numLoops = indexedGenericOp.getNumLoops();
|
|
SmallVector<Value, 4> operands;
|
|
operands.reserve(numLoops + op.getNumOperands());
|
|
for (unsigned i = 0; i < numLoops; ++i) {
|
|
operands.push_back(zero);
|
|
}
|
|
for (auto operand : op.getOperands()) {
|
|
operands.push_back(operand);
|
|
}
|
|
rewriter.replaceOpWithNewOp<mlir::CallOp>(op, libraryCallName.getValue(),
|
|
ArrayRef<Type>{}, operands);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// A non-conversion rewrite pattern kicks in to convert CopyOp with
/// permutations into a sequence of TransposeOp and permutation-free CopyOp.
/// This interplays together with TransposeOpConversion and
/// LinalgConversion<CopyOp> to create a path to the LLVM dialect.
class CopyTransposeConversion : public OpRewritePattern<CopyOp> {
public:
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CopyOp op,
                                PatternRewriter &rewriter) const override {
    Value in = op.input(), out = op.output();

    // If either inputPerm or outputPerm are non-identities, insert transposes.
    auto inputPerm = op.inputPermutation();
    if (inputPerm.hasValue() && !inputPerm->isIdentity())
      in = rewriter.create<linalg::TransposeOp>(op.getLoc(), in,
                                                AffineMapAttr::get(*inputPerm));
    auto outputPerm = op.outputPermutation();
    if (outputPerm.hasValue() && !outputPerm->isIdentity())
      out = rewriter.create<linalg::TransposeOp>(
          op.getLoc(), out, AffineMapAttr::get(*outputPerm));

    // If nothing was transposed, fail and let the conversion kick in.
    if (in == op.input() && out == op.output())
      return failure();

    // Replace with a copy between the (possibly transposed) views; the new
    // CopyOp carries no permutations.
    rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
    return success();
  }
};
|
|
|
|
/// Populate the given list with patterns that convert from Linalg to Standard.
/// CopyTransposeConversion is registered alongside the per-op library-call
/// conversions so permuted copies are first decomposed into transposes.
static void
populateLinalgToStandardConversionPatterns(OwningRewritePatternList &patterns,
                                           MLIRContext *ctx) {
  // TODO(ntv) ConvOp conversion needs to export a descriptor with relevant
  // attribute values such as kernel striding and dilation.
  // clang-format off
  patterns.insert<
      CopyTransposeConversion,
      LinalgOpConversion<ConvOp>,
      LinalgOpConversion<PoolingMaxOp>,
      LinalgOpConversion<PoolingMinOp>,
      LinalgOpConversion<PoolingSumOp>,
      LinalgOpConversion<CopyOp>,
      LinalgOpConversion<DotOp>,
      LinalgOpConversion<FillOp>,
      LinalgOpConversion<GenericOp>,
      LinalgOpConversion<IndexedGenericOp>,
      LinalgOpConversion<MatmulOp>,
      LinalgOpConversion<MatvecOp>>(ctx);
  // clang-format on
}
|
|
|
|
} // namespace
|
|
|
|
/// Populate the given list with patterns that convert from Linalg to LLVM.
/// Also registers the RangeType-to-LLVM-struct type conversion on `converter`.
void mlir::populateLinalgToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
    MLIRContext *ctx) {
  patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion,
                  TransposeOpConversion, YieldOpConversion>(ctx, converter);

  // Populate the type conversions for the linalg types.
  converter.addConversion(
      [&](RangeType type) { return convertRangeType(type, converter); });
}
|
|
|
|
namespace {
|
|
/// Module pass lowering Linalg (and the Affine/Loop/Std/Vector constructs it
/// produces) to the LLVM dialect; the work happens in runOnModule().
struct ConvertLinalgToLLVMPass : public ModulePass<ConvertLinalgToLLVMPass> {
  /// Include the generated pass utilities.
#define GEN_PASS_ConvertLinalgToLLVM
#include "mlir/Conversion/Passes.h.inc"

  void runOnModule() override;
};
|
|
} // namespace
|
|
|
|
void ConvertLinalgToLLVMPass::runOnModule() {
|
|
auto module = getModule();
|
|
|
|
// Convert to the LLVM IR dialect using the converter defined above.
|
|
OwningRewritePatternList patterns;
|
|
LLVMTypeConverter converter(&getContext());
|
|
populateAffineToStdConversionPatterns(patterns, &getContext());
|
|
populateLoopToStdConversionPatterns(patterns, &getContext());
|
|
populateStdToLLVMConversionPatterns(converter, patterns, /*useAlloca=*/false,
|
|
/*emitCWrappers=*/true);
|
|
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
|
|
populateVectorToLLVMConversionPatterns(converter, patterns);
|
|
populateLinalgToStandardConversionPatterns(patterns, &getContext());
|
|
populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
|
|
|
|
LLVMConversionTarget target(getContext());
|
|
target.addDynamicallyLegalOp<FuncOp>(
|
|
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
|
|
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
|
if (failed(applyFullConversion(module, target, patterns, &converter)))
|
|
signalPassFailure();
|
|
}
|
|
|
|
/// Creates an instance of the Linalg-to-LLVM lowering pass.
std::unique_ptr<OpPassBase<ModuleOp>> mlir::createConvertLinalgToLLVMPass() {
  return std::make_unique<ConvertLinalgToLLVMPass>();
}
|