## Introduction
This proposal describes a new op called `alloca_scope`, to be added to the `std`
dialect (and later moved to the `memref` dialect).
## Motivation
Alloca operations are easy to misuse, especially when one relies on them while writing
rewriting/conversion passes. For example, consider two independent dialects:
one defines an op that wants to allocate on the stack, and the other defines a
construct that corresponds to some form of looping:
```
dialect1.looping_op {
  %x = dialect2.stack_allocating_op
}
```
Since the dialects might not know about each other, they are going to define their
lowerings to std/scf/etc. independently:
```
scf.for … {
  %x_temp = std.alloca …
  … // do some domain-specific work using %x_temp buffer
  … // and store the result into %result
  %x = %result
}
```
Later on, the scf ops and `std.alloca` are going to be lowered to the LLVM dialect
using a combination of `llvm.alloca` and unstructured control flow.
At this point the allocation backing `%x_temp` is either going to be optimized away
by LLVM (for example using mem2reg) or, in the worst case, performed as an independent
stack allocation on each iteration of the loop. While the LLVM optimizations are
likely to succeed, they are not guaranteed to do so, which leaves room for
surprising issues such as unexpected stack growth.
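To make the failure mode concrete, here is a rough sketch (block arguments, types,
and operand definitions are illustrative and elided, not taken from an actual
lowering) of what the resulting LLVM dialect IR can look like; the `llvm.alloca`
sits inside the loop body, so nothing structurally prevents it from executing on
every iteration:
```
^loop_header(%i: i64):
  %cond = llvm.icmp "slt" %i, %ub : i64
  llvm.cond_br %cond, ^loop_body(%i : i64), ^exit
^loop_body(%i: i64):
  // re-executed on every iteration unless LLVM hoists or removes it
  %x_temp = llvm.alloca %one x f32 : (i64) -> !llvm.ptr<f32>
  … // domain-specific work using %x_temp
  %next = llvm.add %i, %step : i64
  llvm.br ^loop_header(%next : i64)
^exit:
  …
```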
## Proposal
We propose a new operation, `alloca_scope`, that defines a finer-grained allocation
scope for alloca-allocated memory:
```
alloca_scope {
  %x_temp = alloca …
  …
}
```
Here the lifetime of `%x_temp` is bound to the narrow annotated region within the
`alloca_scope`. Moreover, one can also return values out of the
`alloca_scope` with an accompanying `alloca_scope.return` op (which behaves
similarly to `scf.yield`):
```
%result = alloca_scope {
  %x_temp = alloca …
  …
  alloca_scope.return %myvalue
}
```
Under the hood, the `alloca_scope` is going to be lowered to a combination of
`llvm.intr.stacksave` and `llvm.intr.stackrestore` that are invoked
automatically as control flow enters and leaves the body of the `alloca_scope`.
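For example, the region body from the snippet above would conceptually lower to
something like the following (types elided, in the same schematic style as the
other examples; the exact IR produced by the conversion may differ):
```
%sp = llvm.intr.stacksave
%x_temp = llvm.alloca …
… // body of the alloca_scope
llvm.intr.stackrestore %sp
```
The same restore has to be emitted on every path that leaves the region, which the
structured op makes straightforward to guarantee.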
The key value of the new op is to allow deterministic, guaranteed stack use
through an explicit annotation in the code that is finer-grained than the
function-level scope of the `AutomaticAllocationScope` interface. `alloca_scope`
can be inserted at arbitrary locations and doesn’t require non-trivial
transformations such as outlining.
## Which dialect
Before the memref dialect is split out, `alloca_scope` can temporarily reside in the
`std` dialect, and later be moved to `memref` together with the rest of the
memory-related operations.
## Implementation
An implementation of the op is available [here](https://reviews.llvm.org/D97768).
Original commits:
* Add initial scaffolding for alloca_scope op
* Add alloca_scope.return op
* Add no region arguments and variadic results
* Add op descriptions
* Add failing test case
* Add another failing test
* Initial implementation of lowering for std.alloca_scope
* Fix backticks
* Fix getSuccessorRegions implementation
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D97768
//===- StandardToLLVM.cpp - Standard to LLVM dialect conversion -----------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements a pass to convert MLIR standard and builtin dialects
|
|
// into the LLVM IR dialect.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "../PassDetail.h"
|
|
#include "mlir/Analysis/DataLayoutAnalysis.h"
|
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
|
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/Math/IR/Math.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
#include "mlir/IR/Attributes.h"
|
|
#include "mlir/IR/BlockAndValueMapping.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/MLIRContext.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
#include "mlir/Interfaces/DataLayoutInterfaces.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
#include "mlir/Support/MathExtras.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "mlir/Transforms/Passes.h"
|
|
#include "mlir/Transforms/Utils.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
#include "llvm/IR/DerivedTypes.h"
|
|
#include "llvm/IR/IRBuilder.h"
|
|
#include "llvm/IR/Type.h"
|
|
#include "llvm/Support/CommandLine.h"
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
#include <functional>
|
|
|
|
using namespace mlir;
|
|
|
|
#define PASS_NAME "convert-std-to-llvm"
|
|
|
|
// Extract an LLVM IR type from the LLVM IR dialect type.
|
|
static Type unwrap(Type type) {
|
|
if (!type)
|
|
return nullptr;
|
|
auto *mlirContext = type.getContext();
|
|
if (!LLVM::isCompatibleType(type))
|
|
emitError(UnknownLoc::get(mlirContext),
|
|
"conversion resulted in a non-LLVM type ")
|
|
<< type;
|
|
return type;
|
|
}
|
|
|
|
/// Callback to convert function argument types. It converts a MemRef function
|
|
/// argument to a list of non-aggregate types containing descriptor
|
|
/// information, and an UnrankedmemRef function argument to a list containing
|
|
/// the rank and a pointer to a descriptor struct.
|
|
LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
|
|
Type type,
|
|
SmallVectorImpl<Type> &result) {
|
|
if (auto memref = type.dyn_cast<MemRefType>()) {
|
|
// In signatures, Memref descriptors are expanded into lists of
|
|
// non-aggregate values.
|
|
auto converted =
|
|
converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true);
|
|
if (converted.empty())
|
|
return failure();
|
|
result.append(converted.begin(), converted.end());
|
|
return success();
|
|
}
|
|
if (type.isa<UnrankedMemRefType>()) {
|
|
auto converted = converter.getUnrankedMemRefDescriptorFields();
|
|
if (converted.empty())
|
|
return failure();
|
|
result.append(converted.begin(), converted.end());
|
|
return success();
|
|
}
|
|
auto converted = converter.convertType(type);
|
|
if (!converted)
|
|
return failure();
|
|
result.push_back(converted);
|
|
return success();
|
|
}
|
|
|
|
/// Callback to convert function argument types. It converts MemRef function
|
|
/// arguments to bare pointers to the MemRef element type.
|
|
LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
|
|
Type type,
|
|
SmallVectorImpl<Type> &result) {
|
|
auto llvmTy = converter.convertCallingConventionType(type);
|
|
if (!llvmTy)
|
|
return failure();
|
|
|
|
result.push_back(llvmTy);
|
|
return success();
|
|
}
|
|
|
|
/// Create an LLVMTypeConverter using default LowerToLLVMOptions.
|
|
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
|
|
const DataLayoutAnalysis *analysis)
|
|
: LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}
|
|
|
|
/// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
|
|
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
|
|
const LowerToLLVMOptions &options,
|
|
const DataLayoutAnalysis *analysis)
|
|
: llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), options(options),
|
|
dataLayoutAnalysis(analysis) {
|
|
assert(llvmDialect && "LLVM IR dialect is not registered");
|
|
|
|
// Register conversions for the builtin types.
|
|
addConversion([&](ComplexType type) { return convertComplexType(type); });
|
|
addConversion([&](FloatType type) { return convertFloatType(type); });
|
|
addConversion([&](FunctionType type) { return convertFunctionType(type); });
|
|
addConversion([&](IndexType type) { return convertIndexType(type); });
|
|
addConversion([&](IntegerType type) { return convertIntegerType(type); });
|
|
addConversion([&](MemRefType type) { return convertMemRefType(type); });
|
|
addConversion(
|
|
[&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
|
|
addConversion([&](VectorType type) { return convertVectorType(type); });
|
|
|
|
// LLVM-compatible types are legal, so add a pass-through conversion.
|
|
addConversion([](Type type) {
|
|
return LLVM::isCompatibleType(type) ? llvm::Optional<Type>(type)
|
|
: llvm::None;
|
|
});
|
|
|
|
// Materialization for memrefs creates descriptor structs from individual
|
|
// values constituting them, when descriptors are used, i.e. more than one
|
|
// value represents a memref.
|
|
addArgumentMaterialization(
|
|
[&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
|
|
Location loc) -> Optional<Value> {
|
|
if (inputs.size() == 1)
|
|
return llvm::None;
|
|
return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
|
|
inputs);
|
|
});
|
|
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
|
|
ValueRange inputs,
|
|
Location loc) -> Optional<Value> {
|
|
if (inputs.size() == 1)
|
|
return llvm::None;
|
|
return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
|
|
});
|
|
// Add generic source and target materializations to handle cases where
|
|
// non-LLVM types persist after an LLVM conversion.
|
|
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
|
|
ValueRange inputs,
|
|
Location loc) -> Optional<Value> {
|
|
if (inputs.size() != 1)
|
|
return llvm::None;
|
|
// FIXME: These should check LLVM::DialectCastOp can actually be constructed
|
|
// from the input and result.
|
|
return builder.create<LLVM::DialectCastOp>(loc, resultType, inputs[0])
|
|
.getResult();
|
|
});
|
|
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
|
|
ValueRange inputs,
|
|
Location loc) -> Optional<Value> {
|
|
if (inputs.size() != 1)
|
|
return llvm::None;
|
|
// FIXME: These should check LLVM::DialectCastOp can actually be constructed
|
|
// from the input and result.
|
|
return builder.create<LLVM::DialectCastOp>(loc, resultType, inputs[0])
|
|
.getResult();
|
|
});
|
|
}
|
|
|
|
/// Returns the MLIR context.
|
|
MLIRContext &LLVMTypeConverter::getContext() {
|
|
return *getDialect()->getContext();
|
|
}
|
|
|
|
Type LLVMTypeConverter::getIndexType() {
|
|
return IntegerType::get(&getContext(), getIndexTypeBitwidth());
|
|
}
|
|
|
|
unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) {
|
|
return options.dataLayout.getPointerSizeInBits(addressSpace);
|
|
}
|
|
|
|
Type LLVMTypeConverter::convertIndexType(IndexType type) {
|
|
return getIndexType();
|
|
}
|
|
|
|
Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
|
|
return IntegerType::get(&getContext(), type.getWidth());
|
|
}
|
|
|
|
Type LLVMTypeConverter::convertFloatType(FloatType type) { return type; }
|
|
|
|
// Convert a `ComplexType` to an LLVM type. The result is a complex number
|
|
// struct with entries for the
|
|
// 1. real part and for the
|
|
// 2. imaginary part.
|
|
static constexpr unsigned kRealPosInComplexNumberStruct = 0;
|
|
static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
|
|
Type LLVMTypeConverter::convertComplexType(ComplexType type) {
|
|
auto elementType = convertType(type.getElementType());
|
|
return LLVM::LLVMStructType::getLiteral(&getContext(),
|
|
{elementType, elementType});
|
|
}
|
|
|
|
// Except for signatures, MLIR function types are converted into LLVM
|
|
// pointer-to-function types.
|
|
Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
|
|
SignatureConversion conversion(type.getNumInputs());
|
|
Type converted =
|
|
convertFunctionSignature(type, /*isVariadic=*/false, conversion);
|
|
return LLVM::LLVMPointerType::get(converted);
|
|
}
|
|
|
|
// Function types are converted to LLVM Function types by recursively converting
|
|
// argument and result types. If MLIR Function has zero results, the LLVM
|
|
// Function has one VoidType result. If MLIR Function has more than one result,
|
|
// they are into an LLVM StructType in their order of appearance.
|
|
Type LLVMTypeConverter::convertFunctionSignature(
|
|
FunctionType funcTy, bool isVariadic,
|
|
LLVMTypeConverter::SignatureConversion &result) {
|
|
// Select the argument converter depending on the calling convention.
|
|
auto funcArgConverter = options.useBarePtrCallConv
|
|
? barePtrFuncArgTypeConverter
|
|
: structFuncArgTypeConverter;
|
|
// Convert argument types one by one and check for errors.
|
|
for (auto &en : llvm::enumerate(funcTy.getInputs())) {
|
|
Type type = en.value();
|
|
SmallVector<Type, 8> converted;
|
|
if (failed(funcArgConverter(*this, type, converted)))
|
|
return {};
|
|
result.addInputs(en.index(), converted);
|
|
}
|
|
|
|
SmallVector<Type, 8> argTypes;
|
|
argTypes.reserve(llvm::size(result.getConvertedTypes()));
|
|
for (Type type : result.getConvertedTypes())
|
|
argTypes.push_back(unwrap(type));
|
|
|
|
// If function does not return anything, create the void result type,
|
|
// if it returns on element, convert it, otherwise pack the result types into
|
|
// a struct.
|
|
Type resultType = funcTy.getNumResults() == 0
|
|
? LLVM::LLVMVoidType::get(&getContext())
|
|
: unwrap(packFunctionResults(funcTy.getResults()));
|
|
if (!resultType)
|
|
return {};
|
|
return LLVM::LLVMFunctionType::get(resultType, argTypes, isVariadic);
|
|
}
|
|
|
|
/// Converts the function type to a C-compatible format, in particular using
|
|
/// pointers to memref descriptors for arguments.
|
|
std::pair<Type, bool>
|
|
LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
|
|
SmallVector<Type, 4> inputs;
|
|
bool resultIsNowArg = false;
|
|
|
|
Type resultType = type.getNumResults() == 0
|
|
? LLVM::LLVMVoidType::get(&getContext())
|
|
: unwrap(packFunctionResults(type.getResults()));
|
|
if (!resultType)
|
|
return {};
|
|
|
|
if (auto structType = resultType.dyn_cast<LLVM::LLVMStructType>()) {
|
|
// Struct types cannot be safely returned via C interface. Make this a
|
|
// pointer argument, instead.
|
|
inputs.push_back(LLVM::LLVMPointerType::get(structType));
|
|
resultType = LLVM::LLVMVoidType::get(&getContext());
|
|
resultIsNowArg = true;
|
|
}
|
|
|
|
for (Type t : type.getInputs()) {
|
|
auto converted = convertType(t);
|
|
if (!converted || !LLVM::isCompatibleType(converted))
|
|
return {};
|
|
if (t.isa<MemRefType, UnrankedMemRefType>())
|
|
converted = LLVM::LLVMPointerType::get(converted);
|
|
inputs.push_back(converted);
|
|
}
|
|
|
|
return {LLVM::LLVMFunctionType::get(resultType, inputs), resultIsNowArg};
|
|
}
|
|
|
|
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor = 0;
|
|
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor = 1;
|
|
static constexpr unsigned kOffsetPosInMemRefDescriptor = 2;
|
|
static constexpr unsigned kSizePosInMemRefDescriptor = 3;
|
|
static constexpr unsigned kStridePosInMemRefDescriptor = 4;
|
|
|
|
/// Convert a memref type into a list of LLVM IR types that will form the
|
|
/// memref descriptor. The result contains the following types:
|
|
/// 1. The pointer to the allocated data buffer, followed by
|
|
/// 2. The pointer to the aligned data buffer, followed by
|
|
/// 3. A lowered `index`-type integer containing the distance between the
|
|
/// beginning of the buffer and the first element to be accessed through the
|
|
/// view, followed by
|
|
/// 4. An array containing as many `index`-type integers as the rank of the
|
|
/// MemRef: the array represents the size, in number of elements, of the memref
|
|
/// along the given dimension. For constant MemRef dimensions, the
|
|
/// corresponding size entry is a constant whose runtime value must match the
|
|
/// static value, followed by
|
|
/// 5. A second array containing as many `index`-type integers as the rank of
|
|
/// the MemRef: the second array represents the "stride" (in tensor abstraction
|
|
/// sense), i.e. the number of consecutive elements of the underlying buffer.
|
|
/// TODO: add assertions for the static cases.
|
|
///
|
|
/// If `unpackAggregates` is set to true, the arrays described in (4) and (5)
|
|
/// are expanded into individual index-type elements.
|
|
///
|
|
/// template <typename Elem, typename Index, size_t Rank>
|
|
/// struct {
|
|
/// Elem *allocatedPtr;
|
|
/// Elem *alignedPtr;
|
|
/// Index offset;
|
|
/// Index sizes[Rank]; // omitted when rank == 0
|
|
/// Index strides[Rank]; // omitted when rank == 0
|
|
/// };
|
|
SmallVector<Type, 5>
|
|
LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
|
|
bool unpackAggregates) {
|
|
assert(isStrided(type) &&
|
|
"Non-strided layout maps must have been normalized away");
|
|
|
|
Type elementType = unwrap(convertType(type.getElementType()));
|
|
if (!elementType)
|
|
return {};
|
|
auto ptrTy =
|
|
LLVM::LLVMPointerType::get(elementType, type.getMemorySpaceAsInt());
|
|
auto indexTy = getIndexType();
|
|
|
|
SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
|
|
auto rank = type.getRank();
|
|
if (rank == 0)
|
|
return results;
|
|
|
|
if (unpackAggregates)
|
|
results.insert(results.end(), 2 * rank, indexTy);
|
|
else
|
|
results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank));
|
|
return results;
|
|
}
|
|
|
|
unsigned LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type,
|
|
const DataLayout &layout) {
|
|
// Compute the descriptor size given that of its components indicated above.
|
|
unsigned space = type.getMemorySpaceAsInt();
|
|
return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) +
|
|
(1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType());
|
|
}
|
|
|
|
/// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
|
|
/// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
|
|
Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
|
|
// When converting a MemRefType to a struct with descriptor fields, do not
|
|
// unpack the `sizes` and `strides` arrays.
|
|
SmallVector<Type, 5> types =
|
|
getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
|
|
if (types.empty())
|
|
return {};
|
|
return LLVM::LLVMStructType::getLiteral(&getContext(), types);
|
|
}
|
|
|
|
static constexpr unsigned kRankInUnrankedMemRefDescriptor = 0;
|
|
static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1;
|
|
|
|
/// Convert an unranked memref type into a list of non-aggregate LLVM IR types
|
|
/// that will form the unranked memref descriptor. In particular, the fields
|
|
/// for an unranked memref descriptor are:
|
|
/// 1. index-typed rank, the dynamic rank of this MemRef
|
|
/// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
|
|
/// stack allocated (alloca) copy of a MemRef descriptor that got casted to
|
|
/// be unranked.
|
|
SmallVector<Type, 2> LLVMTypeConverter::getUnrankedMemRefDescriptorFields() {
|
|
return {getIndexType(),
|
|
LLVM::LLVMPointerType::get(IntegerType::get(&getContext(), 8))};
|
|
}
|
|
|
|
unsigned
|
|
LLVMTypeConverter::getUnrankedMemRefDescriptorSize(UnrankedMemRefType type,
|
|
const DataLayout &layout) {
|
|
// Compute the descriptor size given that of its components indicated above.
|
|
unsigned space = type.getMemorySpaceAsInt();
|
|
return layout.getTypeSize(getIndexType()) +
|
|
llvm::divideCeil(getPointerBitwidth(space), 8);
|
|
}
|
|
|
|
Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
|
|
if (!convertType(type.getElementType()))
|
|
return {};
|
|
return LLVM::LLVMStructType::getLiteral(&getContext(),
|
|
getUnrankedMemRefDescriptorFields());
|
|
}
|
|
|
|
/// Convert a memref type to a bare pointer to the memref element type.
|
|
Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
|
|
if (type.isa<UnrankedMemRefType>())
|
|
// Unranked memref is not supported in the bare pointer calling convention.
|
|
return {};
|
|
|
|
// Check that the memref has static shape, strides and offset. Otherwise, it
|
|
// cannot be lowered to a bare pointer.
|
|
auto memrefTy = type.cast<MemRefType>();
|
|
if (!memrefTy.hasStaticShape())
|
|
return {};
|
|
|
|
int64_t offset = 0;
|
|
SmallVector<int64_t, 4> strides;
|
|
if (failed(getStridesAndOffset(memrefTy, strides, offset)))
|
|
return {};
|
|
|
|
for (int64_t stride : strides)
|
|
if (ShapedType::isDynamicStrideOrOffset(stride))
|
|
return {};
|
|
|
|
if (ShapedType::isDynamicStrideOrOffset(offset))
|
|
return {};
|
|
|
|
Type elementType = unwrap(convertType(type.getElementType()));
|
|
if (!elementType)
|
|
return {};
|
|
return LLVM::LLVMPointerType::get(elementType, type.getMemorySpaceAsInt());
|
|
}
|
|
|
|
/// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type
|
|
/// when n > 1. For example, `vector<4 x f32>` remains as is while,
|
|
/// `vector<4x8x16xf32>` converts to `!llvm.array<4xarray<8 x vector<16xf32>>>`.
|
|
Type LLVMTypeConverter::convertVectorType(VectorType type) {
|
|
auto elementType = unwrap(convertType(type.getElementType()));
|
|
if (!elementType)
|
|
return {};
|
|
Type vectorType = VectorType::get(type.getShape().back(), elementType);
|
|
assert(LLVM::isCompatibleVectorType(vectorType) &&
|
|
"expected vector type compatible with the LLVM dialect");
|
|
auto shape = type.getShape();
|
|
for (int i = shape.size() - 2; i >= 0; --i)
|
|
vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
|
|
return vectorType;
|
|
}
|
|
|
|
/// Convert a type in the context of the default or bare pointer calling
|
|
/// convention. Calling convention sensitive types, such as MemRefType and
|
|
/// UnrankedMemRefType, are converted following the specific rules for the
|
|
/// calling convention. Calling convention independent types are converted
|
|
/// following the default LLVM type conversions.
|
|
Type LLVMTypeConverter::convertCallingConventionType(Type type) {
|
|
if (options.useBarePtrCallConv)
|
|
if (auto memrefTy = type.dyn_cast<BaseMemRefType>())
|
|
return convertMemRefToBarePtr(memrefTy);
|
|
|
|
return convertType(type);
|
|
}
|
|
|
|
/// Promote the bare pointers in 'values' that resulted from memrefs to
|
|
/// descriptors. 'stdTypes' holds they types of 'values' before the conversion
|
|
/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
|
|
void LLVMTypeConverter::promoteBarePtrsToDescriptors(
|
|
ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
|
|
SmallVectorImpl<Value> &values) {
|
|
assert(stdTypes.size() == values.size() &&
|
|
"The number of types and values doesn't match");
|
|
for (unsigned i = 0, end = values.size(); i < end; ++i)
|
|
if (auto memrefTy = stdTypes[i].dyn_cast<MemRefType>())
|
|
values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
|
|
memrefTy, values[i]);
|
|
}
|
|
|
|
ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
|
|
MLIRContext *context,
|
|
LLVMTypeConverter &typeConverter,
|
|
PatternBenefit benefit)
|
|
: ConversionPattern(typeConverter, rootOpName, benefit, context) {}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// StructBuilder implementation
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
StructBuilder::StructBuilder(Value v) : value(v), structType(v.getType()) {
|
|
assert(value != nullptr && "value cannot be null");
|
|
assert(LLVM::isCompatibleType(structType) && "expected llvm type");
|
|
}
|
|
|
|
Value StructBuilder::extractPtr(OpBuilder &builder, Location loc,
|
|
unsigned pos) {
|
|
Type type = structType.cast<LLVM::LLVMStructType>().getBody()[pos];
|
|
return builder.create<LLVM::ExtractValueOp>(loc, type, value,
|
|
builder.getI64ArrayAttr(pos));
|
|
}
|
|
|
|
void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos,
|
|
Value ptr) {
|
|
value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
|
|
builder.getI64ArrayAttr(pos));
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ComplexStructBuilder implementation
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder,
|
|
Location loc, Type type) {
|
|
Value val = builder.create<LLVM::UndefOp>(loc, type);
|
|
return ComplexStructBuilder(val);
|
|
}
|
|
|
|
void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc,
|
|
Value real) {
|
|
setPtr(builder, loc, kRealPosInComplexNumberStruct, real);
|
|
}
|
|
|
|
Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) {
|
|
return extractPtr(builder, loc, kRealPosInComplexNumberStruct);
|
|
}
|
|
|
|
void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc,
|
|
Value imaginary) {
|
|
setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary);
|
|
}
|
|
|
|
Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
|
|
return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MemRefDescriptor implementation
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Construct a helper for the given descriptor value.
|
|
MemRefDescriptor::MemRefDescriptor(Value descriptor)
|
|
: StructBuilder(descriptor) {
|
|
assert(value != nullptr && "value cannot be null");
|
|
indexType = value.getType()
|
|
.cast<LLVM::LLVMStructType>()
|
|
.getBody()[kOffsetPosInMemRefDescriptor];
|
|
}
|
|
|
|
/// Builds IR creating an `undef` value of the descriptor type.
|
|
MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
|
|
Type descriptorType) {
|
|
|
|
Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
|
|
return MemRefDescriptor(descriptor);
|
|
}
|
|
|
|
/// Builds IR creating a MemRef descriptor that represents `type` and
|
|
/// populates it with static shape and stride information extracted from the
|
|
/// type.
|
|
MemRefDescriptor
|
|
MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
|
|
LLVMTypeConverter &typeConverter,
|
|
MemRefType type, Value memory) {
|
|
assert(type.hasStaticShape() && "unexpected dynamic shape");
|
|
|
|
// Extract all strides and offsets and verify they are static.
|
|
int64_t offset;
|
|
SmallVector<int64_t, 4> strides;
|
|
auto result = getStridesAndOffset(type, strides, offset);
|
|
(void)result;
|
|
assert(succeeded(result) && "unexpected failure in stride computation");
|
|
assert(!MemRefType::isDynamicStrideOrOffset(offset) &&
|
|
"expected static offset");
|
|
assert(!llvm::any_of(strides, [](int64_t stride) {
|
|
return MemRefType::isDynamicStrideOrOffset(stride);
|
|
}) && "expected static strides");
|
|
|
|
auto convertedType = typeConverter.convertType(type);
|
|
assert(convertedType && "unexpected failure in memref type conversion");
|
|
|
|
auto descr = MemRefDescriptor::undef(builder, loc, convertedType);
|
|
descr.setAllocatedPtr(builder, loc, memory);
|
|
descr.setAlignedPtr(builder, loc, memory);
|
|
descr.setConstantOffset(builder, loc, offset);
|
|
|
|
// Fill in sizes and strides
|
|
for (unsigned i = 0, e = type.getRank(); i != e; ++i) {
|
|
descr.setConstantSize(builder, loc, i, type.getDimSize(i));
|
|
descr.setConstantStride(builder, loc, i, strides[i]);
|
|
}
|
|
return descr;
|
|
}
|
|
|
|
/// Builds IR extracting the allocated pointer from the descriptor.
|
|
Value MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) {
|
|
return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor);
|
|
}
|
|
|
|
/// Builds IR inserting the allocated pointer into the descriptor.
|
|
void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
|
|
Value ptr) {
|
|
setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr);
|
|
}
|
|
|
|
/// Builds IR extracting the aligned pointer from the descriptor.
|
|
Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) {
|
|
return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor);
|
|
}
|
|
|
|
/// Builds IR inserting the aligned pointer into the descriptor.
|
|
void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
|
|
Value ptr) {
|
|
setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr);
|
|
}
|
|
|
|
// Creates a constant Op producing a value of `resultType` from an index-typed
|
|
// integer attribute.
|
|
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
|
|
Type resultType, int64_t value) {
|
|
return builder.create<LLVM::ConstantOp>(
|
|
loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
|
|
}
|
|
|
|
/// Builds IR extracting the offset from the descriptor.
|
|
Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
|
|
return builder.create<LLVM::ExtractValueOp>(
|
|
loc, indexType, value,
|
|
builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
|
|
}
|
|
|
|
/// Builds IR inserting the offset into the descriptor.
|
|
void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
|
|
Value offset) {
|
|
value = builder.create<LLVM::InsertValueOp>(
|
|
loc, structType, value, offset,
|
|
builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
|
|
}
|
|
|
|
/// Builds IR inserting the offset into the descriptor.
|
|
void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc,
|
|
uint64_t offset) {
|
|
setOffset(builder, loc,
|
|
createIndexAttrConstant(builder, loc, indexType, offset));
|
|
}
|
|
|
|
/// Builds IR extracting the pos-th size from the descriptor.
|
|
Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
|
|
return builder.create<LLVM::ExtractValueOp>(
|
|
loc, indexType, value,
|
|
builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
|
|
}
|
|
|
|
Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
|
|
int64_t rank) {
|
|
auto indexPtrTy = LLVM::LLVMPointerType::get(indexType);
|
|
auto arrayTy = LLVM::LLVMArrayType::get(indexType, rank);
|
|
auto arrayPtrTy = LLVM::LLVMPointerType::get(arrayTy);
|
|
|
|
// Copy size values to stack-allocated memory.
|
|
auto zero = createIndexAttrConstant(builder, loc, indexType, 0);
|
|
auto one = createIndexAttrConstant(builder, loc, indexType, 1);
|
|
auto sizes = builder.create<LLVM::ExtractValueOp>(
|
|
loc, arrayTy, value,
|
|
builder.getI64ArrayAttr({kSizePosInMemRefDescriptor}));
|
|
auto sizesPtr =
|
|
builder.create<LLVM::AllocaOp>(loc, arrayPtrTy, one, /*alignment=*/0);
|
|
builder.create<LLVM::StoreOp>(loc, sizes, sizesPtr);
|
|
|
|
// Load an return size value of interest.
|
|
auto resultPtr = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizesPtr,
|
|
ValueRange({zero, pos}));
|
|
return builder.create<LLVM::LoadOp>(loc, resultPtr);
|
|
}
|
|
|
|
/// Builds IR inserting the pos-th size into the descriptor
|
|
void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
|
|
Value size) {
|
|
value = builder.create<LLVM::InsertValueOp>(
|
|
loc, structType, value, size,
|
|
builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
|
|
}
|
|
|
|
void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc,
|
|
unsigned pos, uint64_t size) {
|
|
setSize(builder, loc, pos,
|
|
createIndexAttrConstant(builder, loc, indexType, size));
|
|
}
|
|
|
|
/// Builds IR extracting the pos-th stride from the descriptor.
|
|
Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) {
|
|
return builder.create<LLVM::ExtractValueOp>(
|
|
loc, indexType, value,
|
|
builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
|
|
}
|
|
|
|
/// Builds IR inserting the pos-th stride into the descriptor
|
|
void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
|
|
Value stride) {
|
|
value = builder.create<LLVM::InsertValueOp>(
|
|
loc, structType, value, stride,
|
|
builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
|
|
}
|
|
|
|
void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
|
|
unsigned pos, uint64_t stride) {
|
|
setStride(builder, loc, pos,
|
|
createIndexAttrConstant(builder, loc, indexType, stride));
|
|
}
|
|
|
|
LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() {
|
|
return value.getType()
|
|
.cast<LLVM::LLVMStructType>()
|
|
.getBody()[kAlignedPtrPosInMemRefDescriptor]
|
|
.cast<LLVM::LLVMPointerType>();
|
|
}
|
|
|
|
/// Creates a MemRef descriptor structure from a list of individual values
|
|
/// composing that descriptor, in the following order:
|
|
/// - allocated pointer;
|
|
/// - aligned pointer;
|
|
/// - offset;
|
|
/// - <rank> sizes;
|
|
/// - <rank> shapes;
|
|
/// where <rank> is the MemRef rank as provided in `type`.
|
|
Value MemRefDescriptor::pack(OpBuilder &builder, Location loc,
|
|
LLVMTypeConverter &converter, MemRefType type,
|
|
ValueRange values) {
|
|
Type llvmType = converter.convertType(type);
|
|
auto d = MemRefDescriptor::undef(builder, loc, llvmType);
|
|
|
|
d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]);
|
|
d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]);
|
|
d.setOffset(builder, loc, values[kOffsetPosInMemRefDescriptor]);
|
|
|
|
int64_t rank = type.getRank();
|
|
for (unsigned i = 0; i < rank; ++i) {
|
|
d.setSize(builder, loc, i, values[kSizePosInMemRefDescriptor + i]);
|
|
d.setStride(builder, loc, i, values[kSizePosInMemRefDescriptor + rank + i]);
|
|
}
|
|
|
|
return d;
|
|
}
|
|
|
|
/// Builds IR extracting individual elements of a MemRef descriptor structure
|
|
/// and returning them as `results` list.
|
|
void MemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed,
|
|
MemRefType type,
|
|
SmallVectorImpl<Value> &results) {
|
|
int64_t rank = type.getRank();
|
|
results.reserve(results.size() + getNumUnpackedValues(type));
|
|
|
|
MemRefDescriptor d(packed);
|
|
results.push_back(d.allocatedPtr(builder, loc));
|
|
results.push_back(d.alignedPtr(builder, loc));
|
|
results.push_back(d.offset(builder, loc));
|
|
for (int64_t i = 0; i < rank; ++i)
|
|
results.push_back(d.size(builder, loc, i));
|
|
for (int64_t i = 0; i < rank; ++i)
|
|
results.push_back(d.stride(builder, loc, i));
|
|
}
|
|
|
|
/// Returns the number of non-aggregate values that would be produced by
|
|
/// `unpack`.
|
|
unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) {
|
|
// Two pointers, offset, <rank> sizes, <rank> shapes.
|
|
return 3 + 2 * type.getRank();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MemRefDescriptorView implementation.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
MemRefDescriptorView::MemRefDescriptorView(ValueRange range)
|
|
: rank((range.size() - kSizePosInMemRefDescriptor) / 2), elements(range) {}
|
|
|
|
Value MemRefDescriptorView::allocatedPtr() {
|
|
return elements[kAllocatedPtrPosInMemRefDescriptor];
|
|
}
|
|
|
|
Value MemRefDescriptorView::alignedPtr() {
|
|
return elements[kAlignedPtrPosInMemRefDescriptor];
|
|
}
|
|
|
|
Value MemRefDescriptorView::offset() {
|
|
return elements[kOffsetPosInMemRefDescriptor];
|
|
}
|
|
|
|
Value MemRefDescriptorView::size(unsigned pos) {
|
|
return elements[kSizePosInMemRefDescriptor + pos];
|
|
}
|
|
|
|
Value MemRefDescriptorView::stride(unsigned pos) {
|
|
return elements[kSizePosInMemRefDescriptor + rank + pos];
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// UnrankedMemRefDescriptor implementation
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Construct a helper for the given descriptor value.
|
|
UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor)
|
|
: StructBuilder(descriptor) {}
|
|
|
|
/// Builds IR creating an `undef` value of the descriptor type.
|
|
UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder,
|
|
Location loc,
|
|
Type descriptorType) {
|
|
Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
|
|
return UnrankedMemRefDescriptor(descriptor);
|
|
}
|
|
Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) {
|
|
return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor);
|
|
}
|
|
void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc,
|
|
Value v) {
|
|
setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v);
|
|
}
|
|
Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder,
|
|
Location loc) {
|
|
return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor);
|
|
}
|
|
void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder,
|
|
Location loc, Value v) {
|
|
setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v);
|
|
}
|
|
|
|
/// Builds IR populating an unranked MemRef descriptor structure from a list
|
|
/// of individual constituent values in the following order:
|
|
/// - rank of the memref;
|
|
/// - pointer to the memref descriptor.
|
|
Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc,
|
|
LLVMTypeConverter &converter,
|
|
UnrankedMemRefType type,
|
|
ValueRange values) {
|
|
Type llvmType = converter.convertType(type);
|
|
auto d = UnrankedMemRefDescriptor::undef(builder, loc, llvmType);
|
|
|
|
d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]);
|
|
d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]);
|
|
return d;
|
|
}
|
|
|
|
/// Builds IR extracting individual elements that compose an unranked memref
|
|
/// descriptor and returns them as `results` list.
|
|
void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
|
|
Value packed,
|
|
SmallVectorImpl<Value> &results) {
|
|
UnrankedMemRefDescriptor d(packed);
|
|
results.reserve(results.size() + 2);
|
|
results.push_back(d.rank(builder, loc));
|
|
results.push_back(d.memRefDescPtr(builder, loc));
|
|
}
|
|
|
|
void UnrankedMemRefDescriptor::computeSizes(
|
|
OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
|
|
ArrayRef<UnrankedMemRefDescriptor> values, SmallVectorImpl<Value> &sizes) {
|
|
if (values.empty())
|
|
return;
|
|
|
|
// Cache the index type.
|
|
Type indexType = typeConverter.getIndexType();
|
|
|
|
// Initialize shared constants.
|
|
Value one = createIndexAttrConstant(builder, loc, indexType, 1);
|
|
Value two = createIndexAttrConstant(builder, loc, indexType, 2);
|
|
Value pointerSize = createIndexAttrConstant(
|
|
builder, loc, indexType, ceilDiv(typeConverter.getPointerBitwidth(), 8));
|
|
Value indexSize =
|
|
createIndexAttrConstant(builder, loc, indexType,
|
|
ceilDiv(typeConverter.getIndexTypeBitwidth(), 8));
|
|
|
|
sizes.reserve(sizes.size() + values.size());
|
|
for (UnrankedMemRefDescriptor desc : values) {
|
|
// Emit IR computing the memory necessary to store the descriptor. This
|
|
// assumes the descriptor to be
|
|
// { type*, type*, index, index[rank], index[rank] }
|
|
// and densely packed, so the total size is
|
|
// 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
|
|
// TODO: consider including the actual size (including eventual padding due
|
|
// to data layout) into the unranked descriptor.
|
|
Value doublePointerSize =
|
|
builder.create<LLVM::MulOp>(loc, indexType, two, pointerSize);
|
|
|
|
// (1 + 2 * rank) * sizeof(index)
|
|
Value rank = desc.rank(builder, loc);
|
|
Value doubleRank = builder.create<LLVM::MulOp>(loc, indexType, two, rank);
|
|
Value doubleRankIncremented =
|
|
builder.create<LLVM::AddOp>(loc, indexType, doubleRank, one);
|
|
Value rankIndexSize = builder.create<LLVM::MulOp>(
|
|
loc, indexType, doubleRankIncremented, indexSize);
|
|
|
|
// Total allocation size.
|
|
Value allocationSize = builder.create<LLVM::AddOp>(
|
|
loc, indexType, doublePointerSize, rankIndexSize);
|
|
sizes.push_back(allocationSize);
|
|
}
|
|
}
|
|
|
|
Value UnrankedMemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc,
|
|
Value memRefDescPtr,
|
|
Type elemPtrPtrType) {
|
|
|
|
Value elementPtrPtr =
|
|
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
|
|
return builder.create<LLVM::LoadOp>(loc, elementPtrPtr);
|
|
}
|
|
|
|
void UnrankedMemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
|
|
Value memRefDescPtr,
|
|
Type elemPtrPtrType,
|
|
Value allocatedPtr) {
|
|
Value elementPtrPtr =
|
|
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
|
|
builder.create<LLVM::StoreOp>(loc, allocatedPtr, elementPtrPtr);
|
|
}
|
|
|
|
Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc,
|
|
LLVMTypeConverter &typeConverter,
|
|
Value memRefDescPtr,
|
|
Type elemPtrPtrType) {
|
|
Value elementPtrPtr =
|
|
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
|
|
|
|
Value one =
|
|
createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1);
|
|
Value alignedGep = builder.create<LLVM::GEPOp>(
|
|
loc, elemPtrPtrType, elementPtrPtr, ValueRange({one}));
|
|
return builder.create<LLVM::LoadOp>(loc, alignedGep);
|
|
}
|
|
|
|
void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
|
|
LLVMTypeConverter &typeConverter,
|
|
Value memRefDescPtr,
|
|
Type elemPtrPtrType,
|
|
Value alignedPtr) {
|
|
Value elementPtrPtr =
|
|
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
|
|
|
|
Value one =
|
|
createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1);
|
|
Value alignedGep = builder.create<LLVM::GEPOp>(
|
|
loc, elemPtrPtrType, elementPtrPtr, ValueRange({one}));
|
|
builder.create<LLVM::StoreOp>(loc, alignedPtr, alignedGep);
|
|
}
|
|
|
|
Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
|
|
LLVMTypeConverter &typeConverter,
|
|
Value memRefDescPtr,
|
|
Type elemPtrPtrType) {
|
|
Value elementPtrPtr =
|
|
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
|
|
|
|
Value two =
|
|
createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2);
|
|
Value offsetGep = builder.create<LLVM::GEPOp>(
|
|
loc, elemPtrPtrType, elementPtrPtr, ValueRange({two}));
|
|
offsetGep = builder.create<LLVM::BitcastOp>(
|
|
loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep);
|
|
return builder.create<LLVM::LoadOp>(loc, offsetGep);
|
|
}
|
|
|
|
void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
|
|
LLVMTypeConverter &typeConverter,
|
|
Value memRefDescPtr,
|
|
Type elemPtrPtrType, Value offset) {
|
|
Value elementPtrPtr =
|
|
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
|
|
|
|
Value two =
|
|
createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2);
|
|
Value offsetGep = builder.create<LLVM::GEPOp>(
|
|
loc, elemPtrPtrType, elementPtrPtr, ValueRange({two}));
|
|
offsetGep = builder.create<LLVM::BitcastOp>(
|
|
loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep);
|
|
builder.create<LLVM::StoreOp>(loc, offset, offsetGep);
|
|
}
|
|
|
|
Value UnrankedMemRefDescriptor::sizeBasePtr(
|
|
OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
|
|
Value memRefDescPtr, LLVM::LLVMPointerType elemPtrPtrType) {
|
|
Type elemPtrTy = elemPtrPtrType.getElementType();
|
|
Type indexTy = typeConverter.getIndexType();
|
|
Type structPtrTy =
|
|
LLVM::LLVMPointerType::get(LLVM::LLVMStructType::getLiteral(
|
|
indexTy.getContext(), {elemPtrTy, elemPtrTy, indexTy, indexTy}));
|
|
Value structPtr =
|
|
builder.create<LLVM::BitcastOp>(loc, structPtrTy, memRefDescPtr);
|
|
|
|
Type int32_type = unwrap(typeConverter.convertType(builder.getI32Type()));
|
|
Value zero =
|
|
createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0);
|
|
Value three = builder.create<LLVM::ConstantOp>(loc, int32_type,
|
|
builder.getI32IntegerAttr(3));
|
|
return builder.create<LLVM::GEPOp>(loc, LLVM::LLVMPointerType::get(indexTy),
|
|
structPtr, ValueRange({zero, three}));
|
|
}
|
|
|
|
Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc,
|
|
LLVMTypeConverter typeConverter,
|
|
Value sizeBasePtr, Value index) {
|
|
Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
|
|
Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
|
|
ValueRange({index}));
|
|
return builder.create<LLVM::LoadOp>(loc, sizeStoreGep);
|
|
}
|
|
|
|
void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
|
|
LLVMTypeConverter typeConverter,
|
|
Value sizeBasePtr, Value index,
|
|
Value size) {
|
|
Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
|
|
Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
|
|
ValueRange({index}));
|
|
builder.create<LLVM::StoreOp>(loc, size, sizeStoreGep);
|
|
}
|
|
|
|
Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc,
|
|
LLVMTypeConverter &typeConverter,
|
|
Value sizeBasePtr, Value rank) {
|
|
Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
|
|
return builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
|
|
ValueRange({rank}));
|
|
}
|
|
|
|
Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc,
|
|
LLVMTypeConverter typeConverter,
|
|
Value strideBasePtr, Value index,
|
|
Value stride) {
|
|
Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
|
|
Value strideStoreGep = builder.create<LLVM::GEPOp>(
|
|
loc, indexPtrTy, strideBasePtr, ValueRange({index}));
|
|
return builder.create<LLVM::LoadOp>(loc, strideStoreGep);
|
|
}
|
|
|
|
void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc,
|
|
LLVMTypeConverter typeConverter,
|
|
Value strideBasePtr, Value index,
|
|
Value stride) {
|
|
Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
|
|
Value strideStoreGep = builder.create<LLVM::GEPOp>(
|
|
loc, indexPtrTy, strideBasePtr, ValueRange({index}));
|
|
builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep);
|
|
}
|
|
|
|
LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const {
|
|
return static_cast<LLVMTypeConverter *>(
|
|
ConversionPattern::getTypeConverter());
|
|
}
|
|
|
|
LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
|
|
return *getTypeConverter()->getDialect();
|
|
}
|
|
|
|
Type ConvertToLLVMPattern::getIndexType() const {
|
|
return getTypeConverter()->getIndexType();
|
|
}
|
|
|
|
Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
|
|
return IntegerType::get(&getTypeConverter()->getContext(),
|
|
getTypeConverter()->getPointerBitwidth(addressSpace));
|
|
}
|
|
|
|
Type ConvertToLLVMPattern::getVoidType() const {
|
|
return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext());
|
|
}
|
|
|
|
Type ConvertToLLVMPattern::getVoidPtrType() const {
|
|
return LLVM::LLVMPointerType::get(
|
|
IntegerType::get(&getTypeConverter()->getContext(), 8));
|
|
}
|
|
|
|
Value ConvertToLLVMPattern::createIndexConstant(
|
|
ConversionPatternRewriter &builder, Location loc, uint64_t value) const {
|
|
return createIndexAttrConstant(builder, loc, getIndexType(), value);
|
|
}
|
|
|
|
Value ConvertToLLVMPattern::getStridedElementPtr(
|
|
Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
int64_t offset;
|
|
SmallVector<int64_t, 4> strides;
|
|
auto successStrides = getStridesAndOffset(type, strides, offset);
|
|
assert(succeeded(successStrides) && "unexpected non-strided memref");
|
|
(void)successStrides;
|
|
|
|
MemRefDescriptor memRefDescriptor(memRefDesc);
|
|
Value base = memRefDescriptor.alignedPtr(rewriter, loc);
|
|
|
|
Value index;
|
|
if (offset != 0) // Skip if offset is zero.
|
|
index = MemRefType::isDynamicStrideOrOffset(offset)
|
|
? memRefDescriptor.offset(rewriter, loc)
|
|
: createIndexConstant(rewriter, loc, offset);
|
|
|
|
for (int i = 0, e = indices.size(); i < e; ++i) {
|
|
Value increment = indices[i];
|
|
if (strides[i] != 1) { // Skip if stride is 1.
|
|
Value stride = MemRefType::isDynamicStrideOrOffset(strides[i])
|
|
? memRefDescriptor.stride(rewriter, loc, i)
|
|
: createIndexConstant(rewriter, loc, strides[i]);
|
|
increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
|
|
}
|
|
index =
|
|
index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
|
|
}
|
|
|
|
Type elementPtrType = memRefDescriptor.getElementPtrType();
|
|
return index ? rewriter.create<LLVM::GEPOp>(loc, elementPtrType, base, index)
|
|
: base;
|
|
}
|
|
|
|
// Check if the MemRefType `type` is supported by the lowering. We currently
|
|
// only support memrefs with identity maps.
|
|
bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps(
|
|
MemRefType type) const {
|
|
if (!typeConverter->convertType(type.getElementType()))
|
|
return false;
|
|
return type.getAffineMaps().empty() ||
|
|
llvm::all_of(type.getAffineMaps(),
|
|
[](AffineMap map) { return map.isIdentity(); });
|
|
}
|
|
|
|
Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
|
|
auto elementType = type.getElementType();
|
|
auto structElementType = unwrap(typeConverter->convertType(elementType));
|
|
return LLVM::LLVMPointerType::get(structElementType,
|
|
type.getMemorySpaceAsInt());
|
|
}
|
|
|
|
void ConvertToLLVMPattern::getMemRefDescriptorSizes(
|
|
Location loc, MemRefType memRefType, ValueRange dynamicSizes,
|
|
ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes,
|
|
SmallVectorImpl<Value> &strides, Value &sizeBytes) const {
|
|
assert(isConvertibleAndHasIdentityMaps(memRefType) &&
|
|
"layout maps must have been normalized away");
|
|
assert(count(memRefType.getShape(), ShapedType::kDynamicSize) ==
|
|
static_cast<ssize_t>(dynamicSizes.size()) &&
|
|
"dynamicSizes size doesn't match dynamic sizes count in memref shape");
|
|
|
|
sizes.reserve(memRefType.getRank());
|
|
unsigned dynamicIndex = 0;
|
|
for (int64_t size : memRefType.getShape()) {
|
|
sizes.push_back(size == ShapedType::kDynamicSize
|
|
? dynamicSizes[dynamicIndex++]
|
|
: createIndexConstant(rewriter, loc, size));
|
|
}
|
|
|
|
// Strides: iterate sizes in reverse order and multiply.
|
|
int64_t stride = 1;
|
|
Value runningStride = createIndexConstant(rewriter, loc, 1);
|
|
strides.resize(memRefType.getRank());
|
|
for (auto i = memRefType.getRank(); i-- > 0;) {
|
|
strides[i] = runningStride;
|
|
|
|
int64_t size = memRefType.getShape()[i];
|
|
if (size == 0)
|
|
continue;
|
|
bool useSizeAsStride = stride == 1;
|
|
if (size == ShapedType::kDynamicSize)
|
|
stride = ShapedType::kDynamicSize;
|
|
if (stride != ShapedType::kDynamicSize)
|
|
stride *= size;
|
|
|
|
if (useSizeAsStride)
|
|
runningStride = sizes[i];
|
|
else if (stride == ShapedType::kDynamicSize)
|
|
runningStride =
|
|
rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
|
|
else
|
|
runningStride = createIndexConstant(rewriter, loc, stride);
|
|
}
|
|
|
|
// Buffer size in bytes.
|
|
Type elementPtrType = getElementPtrType(memRefType);
|
|
Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
|
|
Value gepPtr = rewriter.create<LLVM::GEPOp>(
|
|
loc, elementPtrType, ArrayRef<Value>{nullPtr, runningStride});
|
|
sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
|
|
}
|
|
|
|
Value ConvertToLLVMPattern::getSizeInBytes(
|
|
Location loc, Type type, ConversionPatternRewriter &rewriter) const {
|
|
// Compute the size of an individual element. This emits the MLIR equivalent
|
|
// of the following sizeof(...) implementation in LLVM IR:
|
|
// %0 = getelementptr %elementType* null, %indexType 1
|
|
// %1 = ptrtoint %elementType* %0 to %indexType
|
|
// which is a common pattern of getting the size of a type in bytes.
|
|
auto convertedPtrType =
|
|
LLVM::LLVMPointerType::get(typeConverter->convertType(type));
|
|
auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
|
|
auto gep = rewriter.create<LLVM::GEPOp>(
|
|
loc, convertedPtrType,
|
|
ArrayRef<Value>{nullPtr, createIndexConstant(rewriter, loc, 1)});
|
|
return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
|
|
}
|
|
|
|
Value ConvertToLLVMPattern::getNumElements(
|
|
Location loc, ArrayRef<Value> shape,
|
|
ConversionPatternRewriter &rewriter) const {
|
|
// Compute the total number of memref elements.
|
|
Value numElements =
|
|
shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front();
|
|
for (unsigned i = 1, e = shape.size(); i < e; ++i)
|
|
numElements = rewriter.create<LLVM::MulOp>(loc, numElements, shape[i]);
|
|
return numElements;
|
|
}
|
|
|
|
/// Creates and populates the memref descriptor struct given all its fields.
|
|
MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
|
|
Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
|
|
ArrayRef<Value> sizes, ArrayRef<Value> strides,
|
|
ConversionPatternRewriter &rewriter) const {
|
|
auto structType = typeConverter->convertType(memRefType);
|
|
auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
|
|
|
|
// Field 1: Allocated pointer, used for malloc/free.
|
|
memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
|
|
|
|
// Field 2: Actual aligned pointer to payload.
|
|
memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
|
|
|
|
// Field 3: Offset in aligned pointer.
|
|
memRefDescriptor.setOffset(rewriter, loc,
|
|
createIndexConstant(rewriter, loc, 0));
|
|
|
|
// Fields 4: Sizes.
|
|
for (auto en : llvm::enumerate(sizes))
|
|
memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
|
|
|
|
// Field 5: Strides.
|
|
for (auto en : llvm::enumerate(strides))
|
|
memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
|
|
|
|
return memRefDescriptor;
|
|
}
|
|
|
|
/// Only retain those attributes that are not constructed by
|
|
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
|
|
/// attributes.
|
|
static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
|
|
bool filterArgAttrs,
|
|
SmallVectorImpl<NamedAttribute> &result) {
|
|
for (const auto &attr : attrs) {
|
|
if (attr.first == SymbolTable::getSymbolAttrName() ||
|
|
attr.first == function_like_impl::getTypeAttrName() ||
|
|
attr.first == "std.varargs" ||
|
|
(filterArgAttrs &&
|
|
attr.first == function_like_impl::getArgDictAttrName()))
|
|
continue;
|
|
result.push_back(attr);
|
|
}
|
|
}
|
|
|
|
/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
|
|
/// arguments instead of unpacked arguments. This function can be called from C
|
|
/// by passing a pointer to a C struct corresponding to a memref descriptor.
|
|
/// Similarly, returned memrefs are passed via pointers to a C struct that is
|
|
/// passed as additional argument.
|
|
/// Internally, the auxiliary function unpacks the descriptor into individual
|
|
/// components and forwards them to `newFuncOp` and forwards the results to
|
|
/// the extra arguments.
|
|
static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
|
|
LLVMTypeConverter &typeConverter,
|
|
FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
|
|
auto type = funcOp.getType();
|
|
SmallVector<NamedAttribute, 4> attributes;
|
|
filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/false,
|
|
attributes);
|
|
Type wrapperFuncType;
|
|
bool resultIsNowArg;
|
|
std::tie(wrapperFuncType, resultIsNowArg) =
|
|
typeConverter.convertFunctionTypeCWrapper(type);
|
|
auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
|
|
loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
|
|
wrapperFuncType, LLVM::Linkage::External, attributes);
|
|
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock());
|
|
|
|
SmallVector<Value, 8> args;
|
|
size_t argOffset = resultIsNowArg ? 1 : 0;
|
|
for (auto &en : llvm::enumerate(type.getInputs())) {
|
|
Value arg = wrapperFuncOp.getArgument(en.index() + argOffset);
|
|
if (auto memrefType = en.value().dyn_cast<MemRefType>()) {
|
|
Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
|
|
MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
|
|
continue;
|
|
}
|
|
if (en.value().isa<UnrankedMemRefType>()) {
|
|
Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
|
|
UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
|
|
continue;
|
|
}
|
|
|
|
args.push_back(arg);
|
|
}
|
|
|
|
auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
|
|
|
|
if (resultIsNowArg) {
|
|
rewriter.create<LLVM::StoreOp>(loc, call.getResult(0),
|
|
wrapperFuncOp.getArgument(0));
|
|
rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
|
|
} else {
|
|
rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
|
|
}
|
|
}
|
|
|
|
/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
|
|
/// arguments instead of unpacked arguments. Creates a body for the (external)
|
|
/// `newFuncOp` that allocates a memref descriptor on stack, packs the
|
|
/// individual arguments into this descriptor and passes a pointer to it into
|
|
/// the auxiliary function. If the result of the function cannot be directly
|
|
/// returned, we write it to a special first argument that provides a pointer
|
|
/// to a corresponding struct. This auxiliary external function is now
|
|
/// compatible with functions defined in C using pointers to C structs
|
|
/// corresponding to a memref descriptor.
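/// For illustration (hypothetical function name): for an external
/// `func @foo(memref<?xf32>)`, the body built here for the converted `foo`
/// packs the unpacked arguments back into a descriptor allocated on the stack
/// and calls `_mlir_ciface_foo` with a pointer to that descriptor.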
|
|
static void wrapExternalFunction(OpBuilder &builder, Location loc,
|
|
LLVMTypeConverter &typeConverter,
|
|
FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
|
|
OpBuilder::InsertionGuard guard(builder);
|
|
|
|
Type wrapperType;
|
|
bool resultIsNowArg;
|
|
std::tie(wrapperType, resultIsNowArg) =
|
|
typeConverter.convertFunctionTypeCWrapper(funcOp.getType());
|
|
// This conversion can only fail if it could not convert one of the argument
|
|
// types. But since it has been applied to a non-wrapper function before, it
|
|
// should have failed earlier and not reach this point at all.
|
|
assert(wrapperType && "unexpected type conversion failure");
|
|
|
|
SmallVector<NamedAttribute, 4> attributes;
|
|
filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/false,
|
|
attributes);
|
|
|
|
// Create the auxiliary function.
|
|
auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
|
|
loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
|
|
wrapperType, LLVM::Linkage::External, attributes);
|
|
|
|
builder.setInsertionPointToStart(newFuncOp.addEntryBlock());
|
|
|
|
// Get a ValueRange containing arguments.
|
|
FunctionType type = funcOp.getType();
|
|
SmallVector<Value, 8> args;
|
|
args.reserve(type.getNumInputs());
|
|
ValueRange wrapperArgsRange(newFuncOp.getArguments());
|
|
|
|
if (resultIsNowArg) {
|
|
// Allocate the struct on the stack and pass the pointer.
|
|
Type resultType =
|
|
wrapperType.cast<LLVM::LLVMFunctionType>().getParamType(0);
|
|
Value one = builder.create<LLVM::ConstantOp>(
|
|
loc, typeConverter.convertType(builder.getIndexType()),
|
|
builder.getIntegerAttr(builder.getIndexType(), 1));
|
|
Value result = builder.create<LLVM::AllocaOp>(loc, resultType, one);
|
|
args.push_back(result);
|
|
}
|
|
|
|
// Iterate over the inputs of the original function and pack values into
|
|
// memref descriptors if the original type is a memref.
|
|
for (auto &en : llvm::enumerate(type.getInputs())) {
|
|
Value arg;
|
|
int numToDrop = 1;
|
|
auto memRefType = en.value().dyn_cast<MemRefType>();
|
|
auto unrankedMemRefType = en.value().dyn_cast<UnrankedMemRefType>();
|
|
if (memRefType || unrankedMemRefType) {
|
|
numToDrop = memRefType
|
|
? MemRefDescriptor::getNumUnpackedValues(memRefType)
|
|
: UnrankedMemRefDescriptor::getNumUnpackedValues();
|
|
Value packed =
|
|
memRefType
|
|
? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
|
|
wrapperArgsRange.take_front(numToDrop))
|
|
: UnrankedMemRefDescriptor::pack(
|
|
builder, loc, typeConverter, unrankedMemRefType,
|
|
wrapperArgsRange.take_front(numToDrop));
|
|
|
|
auto ptrTy = LLVM::LLVMPointerType::get(packed.getType());
|
|
Value one = builder.create<LLVM::ConstantOp>(
|
|
loc, typeConverter.convertType(builder.getIndexType()),
|
|
builder.getIntegerAttr(builder.getIndexType(), 1));
|
|
Value allocated =
|
|
builder.create<LLVM::AllocaOp>(loc, ptrTy, one, /*alignment=*/0);
|
|
builder.create<LLVM::StoreOp>(loc, packed, allocated);
|
|
arg = allocated;
|
|
} else {
|
|
arg = wrapperArgsRange[0];
|
|
}
|
|
|
|
args.push_back(arg);
|
|
wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
|
|
}
|
|
assert(wrapperArgsRange.empty() && "did not map some of the arguments");
|
|
|
|
auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
|
|
|
|
if (resultIsNowArg) {
|
|
Value result = builder.create<LLVM::LoadOp>(loc, args.front());
|
|
builder.create<LLVM::ReturnOp>(loc, ValueRange{result});
|
|
} else {
|
|
builder.create<LLVM::ReturnOp>(loc, call.getResults());
|
|
}
|
|
}
|
|
|
|
namespace {
|
|
|
|
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
|
|
protected:
|
|
using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;
|
|
|
|
// Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
|
|
// to this legalization pattern.
|
|
LLVM::LLVMFuncOp
|
|
convertFuncOpToLLVMFuncOp(FuncOp funcOp,
|
|
ConversionPatternRewriter &rewriter) const {
|
|
// Convert the original function arguments. They are converted using the
|
|
// LLVMTypeConverter provided to this legalization pattern.
|
|
auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("std.varargs");
|
|
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
|
|
auto llvmType = getTypeConverter()->convertFunctionSignature(
|
|
funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
|
|
if (!llvmType)
|
|
return nullptr;
|
|
|
|
// Propagate argument attributes to all converted arguments obtained after
|
|
// converting a given original argument.
|
|
SmallVector<NamedAttribute, 4> attributes;
|
|
filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true,
|
|
attributes);
|
|
if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
|
|
SmallVector<Attribute, 4> newArgAttrs(
|
|
llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
|
|
for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
|
|
auto mapping = result.getInputMapping(i);
|
|
assert(mapping.hasValue() &&
|
|
"unexpected deletion of function argument");
|
|
for (size_t j = 0; j < mapping->size; ++j)
|
|
newArgAttrs[mapping->inputNo + j] = argAttrDicts[i];
|
|
}
|
|
attributes.push_back(
|
|
rewriter.getNamedAttr(function_like_impl::getArgDictAttrName(),
|
|
rewriter.getArrayAttr(newArgAttrs)));
|
|
}
|
|
|
|
// Create an LLVM function, use external linkage by default until MLIR
|
|
// functions have linkage.
|
|
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
|
|
funcOp.getLoc(), funcOp.getName(), llvmType, LLVM::Linkage::External,
|
|
attributes);
|
|
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
|
|
newFuncOp.end());
|
|
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
|
|
&result)))
|
|
return nullptr;
|
|
|
|
return newFuncOp;
|
|
}
|
|
};
|
|
|
|
/// FuncOp legalization pattern that converts MemRef arguments to pointers to
|
|
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
|
|
/// information.
|
|
static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface";
|
|
struct FuncOpConversion : public FuncOpConversionBase {
|
|
FuncOpConversion(LLVMTypeConverter &converter)
|
|
: FuncOpConversionBase(converter) {}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
|
|
if (!newFuncOp)
|
|
return failure();
|
|
|
|
if (getTypeConverter()->getOptions().emitCWrappers ||
|
|
funcOp->getAttrOfType<UnitAttr>(kEmitIfaceAttrName)) {
|
|
if (newFuncOp.isExternal())
|
|
wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(),
|
|
funcOp, newFuncOp);
|
|
else
|
|
wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(),
|
|
funcOp, newFuncOp);
|
|
}
|
|
|
|
rewriter.eraseOp(funcOp);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// FuncOp legalization pattern that converts MemRef arguments to bare pointers
|
|
/// to the MemRef element type. This will impact the calling convention and ABI.
|
|
struct BarePtrFuncOpConversion : public FuncOpConversionBase {
|
|
using FuncOpConversionBase::FuncOpConversionBase;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
// Store the type of memref-typed arguments before the conversion so that we
|
|
// can promote them to MemRef descriptor at the beginning of the function.
|
|
SmallVector<Type, 8> oldArgTypes =
|
|
llvm::to_vector<8>(funcOp.getType().getInputs());
|
|
|
|
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
|
|
if (!newFuncOp)
|
|
return failure();
|
|
if (newFuncOp.getBody().empty()) {
|
|
rewriter.eraseOp(funcOp);
|
|
return success();
|
|
}
|
|
|
|
// Promote bare pointers from memref arguments to memref descriptors at the
|
|
// beginning of the function so that all the memrefs in the function have a
|
|
// uniform representation.
|
|
Block *entryBlock = &newFuncOp.getBody().front();
|
|
auto blockArgs = entryBlock->getArguments();
|
|
assert(blockArgs.size() == oldArgTypes.size() &&
|
|
"The number of arguments and types doesn't match");
|
|
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
rewriter.setInsertionPointToStart(entryBlock);
|
|
for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
|
|
BlockArgument arg = std::get<0>(it);
|
|
Type argTy = std::get<1>(it);
|
|
|
|
// Unranked memrefs are not supported in the bare pointer calling
|
|
// convention. We should have bailed out before in the presence of
|
|
// unranked memrefs.
|
|
assert(!argTy.isa<UnrankedMemRefType>() &&
|
|
"Unranked memref is not supported");
|
|
auto memrefTy = argTy.dyn_cast<MemRefType>();
|
|
if (!memrefTy)
|
|
continue;
|
|
|
|
// Replace barePtr with a placeholder (undef), promote barePtr to a ranked
|
|
// or unranked memref descriptor and replace placeholder with the last
|
|
// instruction of the memref descriptor.
|
|
// TODO: The placeholder is needed to avoid replacing barePtr uses in the
|
|
// MemRef descriptor instructions. We may want to have a utility in the
|
|
// rewriter to properly handle this use case.
|
|
Location loc = funcOp.getLoc();
|
|
auto placeholder = rewriter.create<LLVM::UndefOp>(loc, memrefTy);
|
|
rewriter.replaceUsesOfBlockArgument(arg, placeholder);
|
|
|
|
Value desc = MemRefDescriptor::fromStaticShape(
|
|
rewriter, loc, *getTypeConverter(), memrefTy, arg);
|
|
rewriter.replaceOp(placeholder, {desc});
|
|
}
|
|
|
|
rewriter.eraseOp(funcOp);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
//////////////// Support for Lowering operations on n-D vectors ////////////////
|
|
// Helper struct to "unroll" operations on n-D vectors in terms of operations on
|
|
// 1-D LLVM vectors.
|
|
struct NDVectorTypeInfo {
|
|
// LLVM array struct which encodes n-D vectors.
|
|
Type llvmNDVectorTy;
|
|
// LLVM vector type which encodes the inner 1-D vector type.
|
|
Type llvm1DVectorTy;
|
|
// Multiplicity of llvmNDVectorTy to llvm1DVectorTy.
|
|
SmallVector<int64_t, 4> arraySizes;
|
|
};
|
|
} // namespace
|
|
|
|
// For >1-D vector types, extracts the necessary information to iterate over all
|
|
// 1-D subvectors in the underlying LLVM representation of the n-D vector.
|
|
// Iterates on the llvm array type until we hit a non-array type (which is
|
|
// asserted to be an llvm vector type).
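// For illustration: vector<2x3x8xf32> converts to
// !llvm.array<2 x array<3 x vector<8xf32>>>, so arraySizes is [2, 3] and
// llvm1DVectorTy is vector<8xf32>.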
|
|
static NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
|
|
LLVMTypeConverter &converter) {
|
|
assert(vectorType.getRank() > 1 && "expected >1D vector type");
|
|
NDVectorTypeInfo info;
|
|
info.llvmNDVectorTy = converter.convertType(vectorType);
|
|
if (!info.llvmNDVectorTy || !LLVM::isCompatibleType(info.llvmNDVectorTy)) {
|
|
info.llvmNDVectorTy = nullptr;
|
|
return info;
|
|
}
|
|
info.arraySizes.reserve(vectorType.getRank() - 1);
|
|
auto llvmTy = info.llvmNDVectorTy;
|
|
while (llvmTy.isa<LLVM::LLVMArrayType>()) {
|
|
info.arraySizes.push_back(
|
|
llvmTy.cast<LLVM::LLVMArrayType>().getNumElements());
|
|
llvmTy = llvmTy.cast<LLVM::LLVMArrayType>().getElementType();
|
|
}
|
|
if (!LLVM::isCompatibleVectorType(llvmTy))
|
|
return info;
|
|
info.llvm1DVectorTy = llvmTy;
|
|
return info;
|
|
}
|
|
|
|
// Express `linearIndex` in terms of coordinates of `basis`.
|
|
// Returns the empty vector when linearIndex is out of the range [0, P) where
|
|
// P is the product of all the basis coordinates.
|
|
//
|
|
// Prerequisites:
|
|
// Basis is an array of nonnegative integers (signed type inherited from
|
|
// vector shape type).
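//
// Worked example (for illustration): with basis [2, 3] (a 2x3 row-major
// space), linearIndex 4 gives 4 % 3 = 1, 4 / 3 = 1, then 1 % 2 = 1,
// 1 / 2 = 0; reversing yields [1, 1]. linearIndex 6 does not reduce to zero
// and returns the empty vector.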
|
|
static SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
|
|
unsigned linearIndex) {
|
|
SmallVector<int64_t, 4> res;
|
|
res.reserve(basis.size());
|
|
for (unsigned basisElement : llvm::reverse(basis)) {
|
|
res.push_back(linearIndex % basisElement);
|
|
linearIndex = linearIndex / basisElement;
|
|
}
|
|
if (linearIndex > 0)
|
|
return {};
|
|
std::reverse(res.begin(), res.end());
|
|
return res;
|
|
}
|
|
|
|
// Iterate over the linear index space, convert each linear index to coordinates
|
|
// in the basis given by `info.arraySizes`, and invoke `fun` at each position.
|
|
template <typename Lambda>
|
|
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
|
|
Lambda fun) {
|
|
unsigned ub = 1;
|
|
for (auto s : info.arraySizes)
|
|
ub *= s;
|
|
for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
|
|
auto coords = getCoordinates(info.arraySizes, linearIndex);
|
|
// Linear index is out of bounds, we are done.
|
|
if (coords.empty())
|
|
break;
|
|
assert(coords.size() == info.arraySizes.size());
|
|
auto position = builder.getI64ArrayAttr(coords);
|
|
fun(position);
|
|
}
|
|
}
|
|
////////////// End Support for Lowering operations on n-D vectors //////////////
|
|
|
|
/// Replaces the given operation "op" with a new operation of type "targetOp"
|
|
/// and given operands.
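/// If `op` has more than one result, the converted results are packed into an
/// LLVM struct by the type converter and unpacked again with
/// llvm.extractvalue after the new operation is created.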
|
|
LogicalResult LLVM::detail::oneToOneRewrite(
|
|
Operation *op, StringRef targetOp, ValueRange operands,
|
|
LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
|
|
unsigned numResults = op->getNumResults();
|
|
|
|
Type packedType;
|
|
if (numResults != 0) {
|
|
packedType = typeConverter.packFunctionResults(op->getResultTypes());
|
|
if (!packedType)
|
|
return failure();
|
|
}
|
|
|
|
// Create the operation through state since we don't know its C++ type.
|
|
OperationState state(op->getLoc(), targetOp);
|
|
state.addTypes(packedType);
|
|
state.addOperands(operands);
|
|
state.addAttributes(op->getAttrs());
|
|
Operation *newOp = rewriter.createOperation(state);
|
|
|
|
// If the operation produced 0 or 1 result, return them immediately.
|
|
if (numResults == 0)
|
|
return rewriter.eraseOp(op), success();
|
|
if (numResults == 1)
|
|
return rewriter.replaceOp(op, newOp->getResult(0)), success();
|
|
|
|
// Otherwise, it had been converted to an operation producing a structure.
|
|
// Extract individual results from the structure and return them as a list.
|
|
SmallVector<Value, 4> results;
|
|
results.reserve(numResults);
|
|
for (unsigned i = 0; i < numResults; ++i) {
|
|
auto type = typeConverter.convertType(op->getResult(i).getType());
|
|
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
|
|
op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i)));
|
|
}
|
|
rewriter.replaceOp(op, results);
|
|
return success();
|
|
}
|
|
|
|
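/// Unrolls an elementwise operation on an n-D vector into calls of
/// `createOperand` on each 1-D subvector: every subvector of every operand is
/// extracted with llvm.extractvalue, `createOperand` builds the 1-D result,
/// and the results are reassembled with llvm.insertvalue.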
static LogicalResult handleMultidimensionalVectors(
|
|
Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
|
|
std::function<Value(Type, ValueRange)> createOperand,
|
|
ConversionPatternRewriter &rewriter) {
|
|
auto resultNDVectorType = op->getResult(0).getType().cast<VectorType>();
|
|
|
|
SmallVector<Type> operand1DVectorTypes;
|
|
for (Value operand : op->getOperands()) {
|
|
auto operandNDVectorType = operand.getType().cast<VectorType>();
|
|
auto operandTypeInfo =
|
|
extractNDVectorTypeInfo(operandNDVectorType, typeConverter);
|
|
operand1DVectorTypes.push_back(operandTypeInfo.llvm1DVectorTy);
|
|
}
|
|
auto resultTypeInfo =
|
|
extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
|
|
auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
|
|
auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy;
|
|
auto loc = op->getLoc();
|
|
Value desc = rewriter.create<LLVM::UndefOp>(loc, resultNDVectoryTy);
|
|
nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayAttr position) {
|
|
// For this unrolled `position` corresponding to the `linearIndex`^th
|
|
// element, extract the operand vectors.
|
|
SmallVector<Value, 4> extractedOperands;
|
|
for (auto operand : llvm::enumerate(operands)) {
|
|
extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
|
|
loc, operand1DVectorTypes[operand.index()], operand.value(),
|
|
position));
|
|
}
|
|
Value newVal = createOperand(result1DVectorTy, extractedOperands);
|
|
desc = rewriter.create<LLVM::InsertValueOp>(loc, resultNDVectoryTy, desc,
|
|
newVal, position);
|
|
});
|
|
rewriter.replaceOp(op, desc);
|
|
return success();
|
|
}
|
|
|
|
LogicalResult LLVM::detail::vectorOneToOneRewrite(
|
|
Operation *op, StringRef targetOp, ValueRange operands,
|
|
LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
|
|
assert(!operands.empty());
|
|
|
|
// Cannot convert ops if their operands are not of LLVM type.
|
|
if (!llvm::all_of(operands.getTypes(),
|
|
[](Type t) { return isCompatibleType(t); }))
|
|
return failure();
|
|
|
|
auto llvmNDVectorTy = operands[0].getType();
|
|
if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>())
|
|
return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);
|
|
|
|
auto callback = [op, targetOp, &rewriter](Type llvm1DVectorTy,
|
|
ValueRange operands) {
|
|
OperationState state(op->getLoc(), targetOp);
|
|
state.addTypes(llvm1DVectorTy);
|
|
state.addOperands(operands);
|
|
state.addAttributes(op->getAttrs());
|
|
return rewriter.createOperation(state)->getResult(0);
|
|
};
|
|
|
|
return handleMultidimensionalVectors(op, operands, typeConverter, callback,
|
|
rewriter);
|
|
}
|
|
|
|
namespace {
|
|
// Straightforward lowerings.
|
|
using AbsFOpLowering = VectorConvertToLLVMPattern<AbsFOp, LLVM::FAbsOp>;
|
|
using AddFOpLowering = VectorConvertToLLVMPattern<AddFOp, LLVM::FAddOp>;
|
|
using AddIOpLowering = VectorConvertToLLVMPattern<AddIOp, LLVM::AddOp>;
|
|
using AndOpLowering = VectorConvertToLLVMPattern<AndOp, LLVM::AndOp>;
|
|
using CeilFOpLowering = VectorConvertToLLVMPattern<CeilFOp, LLVM::FCeilOp>;
|
|
using CopySignOpLowering =
|
|
VectorConvertToLLVMPattern<CopySignOp, LLVM::CopySignOp>;
|
|
using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>;
|
|
using DivFOpLowering = VectorConvertToLLVMPattern<DivFOp, LLVM::FDivOp>;
|
|
using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
|
|
using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
|
|
using FPExtOpLowering = VectorConvertToLLVMPattern<FPExtOp, LLVM::FPExtOp>;
|
|
using FPToSIOpLowering = VectorConvertToLLVMPattern<FPToSIOp, LLVM::FPToSIOp>;
|
|
using FPToUIOpLowering = VectorConvertToLLVMPattern<FPToUIOp, LLVM::FPToUIOp>;
|
|
using FPTruncOpLowering = VectorConvertToLLVMPattern<FPTruncOp, LLVM::FPTruncOp>;
|
|
using FloorFOpLowering = VectorConvertToLLVMPattern<FloorFOp, LLVM::FFloorOp>;
|
|
using FmaFOpLowering = VectorConvertToLLVMPattern<FmaFOp, LLVM::FMAOp>;
|
|
using Log10OpLowering =
|
|
VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
|
|
using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
|
|
using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
|
|
using MulFOpLowering = VectorConvertToLLVMPattern<MulFOp, LLVM::FMulOp>;
|
|
using MulIOpLowering = VectorConvertToLLVMPattern<MulIOp, LLVM::MulOp>;
|
|
using NegFOpLowering = VectorConvertToLLVMPattern<NegFOp, LLVM::FNegOp>;
|
|
using OrOpLowering = VectorConvertToLLVMPattern<OrOp, LLVM::OrOp>;
|
|
using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
|
|
using RemFOpLowering = VectorConvertToLLVMPattern<RemFOp, LLVM::FRemOp>;
|
|
using SIToFPOpLowering = VectorConvertToLLVMPattern<SIToFPOp, LLVM::SIToFPOp>;
|
|
using SelectOpLowering = VectorConvertToLLVMPattern<SelectOp, LLVM::SelectOp>;
|
|
using SignExtendIOpLowering =
|
|
VectorConvertToLLVMPattern<SignExtendIOp, LLVM::SExtOp>;
|
|
using ShiftLeftOpLowering =
|
|
OneToOneConvertToLLVMPattern<ShiftLeftOp, LLVM::ShlOp>;
|
|
using SignedDivIOpLowering =
|
|
VectorConvertToLLVMPattern<SignedDivIOp, LLVM::SDivOp>;
|
|
using SignedRemIOpLowering =
|
|
VectorConvertToLLVMPattern<SignedRemIOp, LLVM::SRemOp>;
|
|
using SignedShiftRightOpLowering =
|
|
OneToOneConvertToLLVMPattern<SignedShiftRightOp, LLVM::AShrOp>;
|
|
using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
|
|
using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
|
|
using SubFOpLowering = VectorConvertToLLVMPattern<SubFOp, LLVM::FSubOp>;
|
|
using SubIOpLowering = VectorConvertToLLVMPattern<SubIOp, LLVM::SubOp>;
|
|
using TruncateIOpLowering = VectorConvertToLLVMPattern<TruncateIOp, LLVM::TruncOp>;
|
|
using UIToFPOpLowering = VectorConvertToLLVMPattern<UIToFPOp, LLVM::UIToFPOp>;
|
|
using UnsignedDivIOpLowering =
|
|
VectorConvertToLLVMPattern<UnsignedDivIOp, LLVM::UDivOp>;
|
|
using UnsignedRemIOpLowering =
|
|
VectorConvertToLLVMPattern<UnsignedRemIOp, LLVM::URemOp>;
|
|
using UnsignedShiftRightOpLowering =
|
|
OneToOneConvertToLLVMPattern<UnsignedShiftRightOp, LLVM::LShrOp>;
|
|
using XOrOpLowering = VectorConvertToLLVMPattern<XOrOp, LLVM::XOrOp>;
|
|
using ZeroExtendIOpLowering =
|
|
VectorConvertToLLVMPattern<ZeroExtendIOp, LLVM::ZExtOp>;
|
|
|
|
/// Lower `std.assert`. The default lowering calls the `abort` function if the
|
|
/// assertion is violated and has no effect otherwise. The failure message is
|
|
/// ignored by the default lowering but should be propagated by any custom
|
|
/// lowering.
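/// A sketch of the produced control flow (illustrative, following the rewrite
/// below):
///   llvm.cond_br %cond, ^continuation, ^failure
///   ^failure: llvm.call @abort(); llvm.unreachable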
|
|
struct AssertOpLowering : public ConvertOpToLLVMPattern<AssertOp> {
|
|
using ConvertOpToLLVMPattern<AssertOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(AssertOp op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto loc = op.getLoc();
|
|
AssertOp::Adaptor transformed(operands);
|
|
|
|
// Insert the `abort` declaration if necessary.
|
|
auto module = op->getParentOfType<ModuleOp>();
|
|
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
|
|
if (!abortFunc) {
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
rewriter.setInsertionPointToStart(module.getBody());
|
|
auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
|
|
abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
|
|
"abort", abortFuncTy);
|
|
}
|
|
|
|
// Split block at `assert` operation.
|
|
Block *opBlock = rewriter.getInsertionBlock();
|
|
auto opPosition = rewriter.getInsertionPoint();
|
|
Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
|
|
|
|
// Generate IR to call `abort`.
|
|
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
|
|
rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None);
|
|
rewriter.create<LLVM::UnreachableOp>(loc);
|
|
|
|
// Generate assertion test.
|
|
rewriter.setInsertionPointToEnd(opBlock);
|
|
rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
|
|
op, transformed.arg(), continuationBlock, failureBlock);
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
|
|
using ConvertOpToLLVMPattern<ConstantOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
// If constant refers to a function, convert it to "addressof".
|
|
if (auto symbolRef = op.getValue().dyn_cast<FlatSymbolRefAttr>()) {
|
|
auto type = typeConverter->convertType(op.getResult().getType());
|
|
if (!type || !LLVM::isCompatibleType(type))
|
|
return rewriter.notifyMatchFailure(op, "failed to convert result type");
|
|
|
|
auto newOp = rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type,
|
|
symbolRef.getValue());
|
|
for (const NamedAttribute &attr : op->getAttrs()) {
|
|
if (attr.first.strref() == "value")
|
|
continue;
|
|
newOp->setAttr(attr.first, attr.second);
|
|
}
|
|
rewriter.replaceOp(op, newOp->getResults());
|
|
return success();
|
|
}
|
|
|
|
// Calling into other scopes (non-flat reference) is not supported in LLVM.
|
|
if (op.getValue().isa<SymbolRefAttr>())
|
|
return rewriter.notifyMatchFailure(
|
|
op, "referring to a symbol outside of the current module");
|
|
|
|
return LLVM::detail::oneToOneRewrite(
|
|
op, LLVM::ConstantOp::getOperationName(), operands, *getTypeConverter(),
|
|
rewriter);
|
|
}
|
|
};
|
|
|
|
struct AllocOpLowering : public AllocLikeOpLLVMLowering {
|
|
AllocOpLowering(LLVMTypeConverter &converter)
|
|
: AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
|
|
converter) {}
|
|
|
|
std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
|
|
Location loc, Value sizeBytes,
|
|
Operation *op) const override {
|
|
// Heap allocations.
|
|
memref::AllocOp allocOp = cast<memref::AllocOp>(op);
|
|
MemRefType memRefType = allocOp.getType();
|
|
|
|
Value alignment;
|
|
if (auto alignmentAttr = allocOp.alignment()) {
|
|
alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
|
|
} else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
|
|
// In the case where no alignment is specified, we may want to override
|
|
// `malloc's` behavior. `malloc` typically aligns at the size of the
|
|
// biggest scalar on a target HW. For non-scalars, use the natural
|
|
// alignment of the LLVM type given by the LLVM DataLayout.
|
|
alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
|
|
}
|
|
|
|
if (alignment) {
|
|
// Adjust the allocation size to consider alignment.
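// Requesting `sizeBytes + alignment` bytes guarantees that an address
// satisfying the alignment exists within the allocated block, with enough
// room left for the payload.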
|
|
sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
|
|
}
|
|
|
|
// Allocate the underlying buffer and store a pointer to it in the MemRef
|
|
// descriptor.
|
|
Type elementPtrType = this->getElementPtrType(memRefType);
|
|
auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
|
|
allocOp->getParentOfType<ModuleOp>(), getIndexType());
|
|
auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
|
|
getVoidPtrType());
|
|
Value allocatedPtr =
|
|
rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
|
|
|
|
Value alignedPtr = allocatedPtr;
|
|
if (alignment) {
|
|
// Compute the aligned type pointer.
|
|
Value allocatedInt =
|
|
rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
|
|
Value alignmentInt =
|
|
createAligned(rewriter, loc, allocatedInt, alignment);
|
|
alignedPtr =
|
|
rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
|
|
}
|
|
|
|
return std::make_tuple(allocatedPtr, alignedPtr);
|
|
}
|
|
};
|
|
|
|
struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
|
|
AlignedAllocOpLowering(LLVMTypeConverter &converter)
|
|
: AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
|
|
converter) {}
|
|
|
|
/// Returns the memref's element size in bytes using the data layout active at
|
|
/// `op`.
|
|
// TODO: there are other places where this is used. Expose publicly?
|
|
unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const {
|
|
const DataLayout *layout = &defaultLayout;
|
|
if (const DataLayoutAnalysis *analysis =
|
|
getTypeConverter()->getDataLayoutAnalysis()) {
|
|
layout = &analysis->getAbove(op);
|
|
}
|
|
Type elementType = memRefType.getElementType();
|
|
if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
|
|
return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
|
|
*layout);
|
|
if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
|
|
return getTypeConverter()->getUnrankedMemRefDescriptorSize(
|
|
memRefElementType, *layout);
|
|
return layout->getTypeSize(elementType);
|
|
}
|
|
|
|
/// Returns true if the memref size in bytes is known to be a multiple of
|
|
/// factor assuming the data layout active at `op`.
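/// For illustration (hypothetical shape): for memref<4x?xf32> and factor 16,
/// the static contribution is 4 elements * 4 bytes = 16, so the total size is
/// a multiple of 16 for any value of the dynamic dimension and this returns
/// true.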
|
|
bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor,
|
|
Operation *op) const {
|
|
uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
|
|
for (unsigned i = 0, e = type.getRank(); i < e; i++) {
|
|
if (type.isDynamic(type.getDimSize(i)))
|
|
continue;
|
|
sizeDivisor = sizeDivisor * type.getDimSize(i);
|
|
}
|
|
return sizeDivisor % factor == 0;
|
|
}
|
|
|
|
/// Returns the alignment to be used for the allocation call itself.
|
|
/// aligned_alloc requires the alignment to be a power of two and the
/// allocation size to be a multiple of the alignment.
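/// For illustration (hypothetical case): an f32 memref with no alignment
/// attribute yields max(16, PowerOf2Ceil(4)) = 16, i.e. the minimum
/// aligned_alloc alignment defined below.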
|
|
int64_t getAllocationAlignment(memref::AllocOp allocOp) const {
|
|
if (Optional<uint64_t> alignment = allocOp.alignment())
|
|
return *alignment;
|
|
|
|
// Whenever we don't have alignment set, we will use an alignment
|
|
// consistent with the element type; since the alignment has to be a power of
// two, we bump the element size to the next power of two if it isn't one.
|
|
auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp);
|
|
return std::max(kMinAlignedAllocAlignment,
|
|
llvm::PowerOf2Ceil(eltSizeBytes));
|
|
}
|
|
|
|
std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
|
|
Location loc, Value sizeBytes,
|
|
Operation *op) const override {
|
|
// Heap allocations.
|
|
memref::AllocOp allocOp = cast<memref::AllocOp>(op);
|
|
MemRefType memRefType = allocOp.getType();
|
|
int64_t alignment = getAllocationAlignment(allocOp);
|
|
Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
|
|
|
|
// aligned_alloc requires size to be a multiple of alignment; we will pad
|
|
// the size to the next multiple if necessary.
|
|
if (!isMemRefSizeMultipleOf(memRefType, alignment, op))
|
|
sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
|
|
|
|
Type elementPtrType = this->getElementPtrType(memRefType);
|
|
auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
|
|
allocOp->getParentOfType<ModuleOp>(), getIndexType());
|
|
auto results =
|
|
createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
|
|
getVoidPtrType());
|
|
Value allocatedPtr =
|
|
rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
|
|
|
|
return std::make_tuple(allocatedPtr, allocatedPtr);
|
|
}
|
|
|
|
/// The minimum alignment to use with aligned_alloc (has to be a power of 2).
|
|
static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
|
|
|
|
/// Default layout to use in absence of the corresponding analysis.
|
|
DataLayout defaultLayout;
|
|
};
|
|
|
|
// Out of line definition, required till C++17.
|
|
constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;
|
|
|
|
struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
|
|
AllocaOpLowering(LLVMTypeConverter &converter)
|
|
: AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
|
|
converter) {}
|
|
|
|
/// Allocates the underlying buffer with an `llvm.alloca`. Since alloca yields
/// a pointer to the element type directly, the allocated and aligned pointers
/// returned here are the same value.
|
|
std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
|
|
Location loc, Value sizeBytes,
|
|
Operation *op) const override {
|
|
|
|
// With alloca, one gets a pointer to the element type right away; no separate
// aligned pointer needs to be computed for stack allocations.
|
|
auto allocaOp = cast<memref::AllocaOp>(op);
|
|
auto elementPtrType = this->getElementPtrType(allocaOp.getType());
|
|
|
|
auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
|
|
loc, elementPtrType, sizeBytes,
|
|
allocaOp.alignment() ? *allocaOp.alignment() : 0);
|
|
|
|
return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
|
|
}
|
|
};
|
|
|
|
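/// Lowers memref.alloca_scope by inlining its region between a stack-save and
/// a stack-restore. A sketch of the produced control flow (illustrative,
/// following the rewrite below):
///   <current block>: %sp = llvm.intr.stacksave; br ^body
///   ^body: ... inlined region ...; llvm.intr.stackrestore %sp; br ^continue
///   ^continue(<results>): <remaining ops>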
struct AllocaScopeOpLowering
|
|
: public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
|
|
using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
Location loc = allocaScopeOp.getLoc();
|
|
|
|
// Split the current block before the AllocaScopeOp to create the inlining
|
|
// point.
|
|
auto *currentBlock = rewriter.getInsertionBlock();
|
|
auto *remainingOpsBlock =
|
|
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
|
|
Block *continueBlock;
|
|
if (allocaScopeOp.getNumResults() == 0) {
|
|
continueBlock = remainingOpsBlock;
|
|
} else {
|
|
continueBlock = rewriter.createBlock(remainingOpsBlock,
|
|
allocaScopeOp.getResultTypes());
|
|
rewriter.create<BranchOp>(loc, remainingOpsBlock);
|
|
}
|
|
|
|
// Inline body region.
|
|
Block *beforeBody = &allocaScopeOp.bodyRegion().front();
|
|
Block *afterBody = &allocaScopeOp.bodyRegion().back();
|
|
rewriter.inlineRegionBefore(allocaScopeOp.bodyRegion(), continueBlock);
|
|
|
|
// Save stack and then branch into the body of the region.
|
|
rewriter.setInsertionPointToEnd(currentBlock);
|
|
auto stackSaveOp =
|
|
rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
|
|
rewriter.create<BranchOp>(loc, beforeBody);
|
|
|
|
// Replace the alloca_scope return with a branch that jumps out of the body.
|
|
// Stack restore before leaving the body region.
|
|
rewriter.setInsertionPointToEnd(afterBody);
|
|
auto returnOp =
|
|
cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
|
|
auto branchOp = rewriter.replaceOpWithNewOp<BranchOp>(
|
|
returnOp, continueBlock, returnOp.results());
|
|
|
|
// Insert stack restore before jumping out of the body of the region.
|
|
rewriter.setInsertionPoint(branchOp);
|
|
rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
|
|
|
|
// Replace the op with the values returned from the body region.
|
|
rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Copies the shaped descriptor part to (if `toDynamic` is set) or from
|
|
/// (otherwise) the dynamically allocated memory for any operands that were
|
|
/// unranked descriptors originally.
|
|
static LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
|
|
LLVMTypeConverter &typeConverter,
|
|
TypeRange origTypes,
|
|
SmallVectorImpl<Value> &operands,
|
|
bool toDynamic) {
|
|
assert(origTypes.size() == operands.size() &&
|
|
"expected as may original types as operands");
|
|
|
|
// Find operands of unranked memref type and store them.
|
|
SmallVector<UnrankedMemRefDescriptor, 4> unrankedMemrefs;
|
|
for (unsigned i = 0, e = operands.size(); i < e; ++i)
|
|
if (origTypes[i].isa<UnrankedMemRefType>())
|
|
unrankedMemrefs.emplace_back(operands[i]);
|
|
|
|
if (unrankedMemrefs.empty())
|
|
return success();
|
|
|
|
// Compute allocation sizes.
|
|
SmallVector<Value, 4> sizes;
|
|
UnrankedMemRefDescriptor::computeSizes(builder, loc, typeConverter,
|
|
unrankedMemrefs, sizes);
|
|
|
|
// Get frequently used types.
|
|
MLIRContext *context = builder.getContext();
|
|
Type voidPtrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
|
|
auto i1Type = IntegerType::get(context, 1);
|
|
Type indexType = typeConverter.getIndexType();
|
|
|
|
// Find the malloc and free, or declare them if necessary.
|
|
auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
|
|
LLVM::LLVMFuncOp freeFunc, mallocFunc;
|
|
if (toDynamic)
|
|
mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
|
|
if (!toDynamic)
|
|
freeFunc = LLVM::lookupOrCreateFreeFn(module);
|
|
|
|
// Initialize shared constants.
|
|
Value zero =
|
|
builder.create<LLVM::ConstantOp>(loc, i1Type, builder.getBoolAttr(false));
|
|
|
|
unsigned unrankedMemrefPos = 0;
|
|
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
|
|
Type type = origTypes[i];
|
|
if (!type.isa<UnrankedMemRefType>())
|
|
continue;
|
|
Value allocationSize = sizes[unrankedMemrefPos++];
|
|
UnrankedMemRefDescriptor desc(operands[i]);
|
|
|
|
// Allocate memory, copy, and free the source if necessary.
|
|
Value memory =
|
|
toDynamic
|
|
? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
|
|
.getResult(0)
|
|
: builder.create<LLVM::AllocaOp>(loc, voidPtrType, allocationSize,
|
|
/*alignment=*/0);
|
|
|
|
Value source = desc.memRefDescPtr(builder, loc);
|
|
builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, zero);
|
|
if (!toDynamic)
|
|
builder.create<LLVM::CallOp>(loc, freeFunc, source);
|
|
|
|
// Create a new descriptor. The same descriptor can be returned multiple
|
|
// times, attempting to modify its pointer can lead to memory leaks
|
|
// (allocated twice and overwritten) or double frees (the caller does not
|
|
// know if the descriptor points to the same memory).
|
|
Type descriptorType = typeConverter.convertType(type);
|
|
if (!descriptorType)
|
|
return failure();
|
|
auto updatedDesc =
|
|
UnrankedMemRefDescriptor::undef(builder, loc, descriptorType);
|
|
Value rank = desc.rank(builder, loc);
|
|
updatedDesc.setRank(builder, loc, rank);
|
|
updatedDesc.setMemRefDescPtr(builder, loc, memory);
|
|
|
|
operands[i] = updatedDesc;
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
// A CallOp automatically promotes MemRefType to a sequence of alloca/store and
|
|
// passes the pointer to the MemRef across function boundaries.
|
|
template <typename CallOpType>
|
|
struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
|
|
using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
|
|
using Super = CallOpInterfaceLowering<CallOpType>;
|
|
using Base = ConvertOpToLLVMPattern<CallOpType>;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(CallOpType callOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
typename CallOpType::Adaptor transformed(operands);
|
|
|
|
// Pack the result types into a struct.
|
|
Type packedResult = nullptr;
|
|
unsigned numResults = callOp.getNumResults();
|
|
auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
|
|
|
|
if (numResults != 0) {
|
|
if (!(packedResult =
|
|
this->getTypeConverter()->packFunctionResults(resultTypes)))
|
|
return failure();
|
|
}
|
|
|
|
auto promoted = this->getTypeConverter()->promoteOperands(
|
|
callOp.getLoc(), /*opOperands=*/callOp->getOperands(), operands,
|
|
rewriter);
|
|
auto newOp = rewriter.create<LLVM::CallOp>(
|
|
callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
|
|
promoted, callOp->getAttrs());
|
|
|
|
SmallVector<Value, 4> results;
|
|
if (numResults < 2) {
|
|
// If < 2 results, packing did not do anything and we can just return.
|
|
results.append(newOp.result_begin(), newOp.result_end());
|
|
} else {
|
|
// Otherwise, it had been converted to an operation producing a structure.
|
|
// Extract individual results from the structure and return them as a list.
|
|
results.reserve(numResults);
|
|
for (unsigned i = 0; i < numResults; ++i) {
|
|
auto type =
|
|
this->typeConverter->convertType(callOp.getResult(i).getType());
|
|
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
|
|
callOp.getLoc(), type, newOp->getResult(0),
|
|
rewriter.getI64ArrayAttr(i)));
|
|
}
|
|
}
|
|
|
|
if (this->getTypeConverter()->getOptions().useBarePtrCallConv) {
|
|
// For the bare-ptr calling convention, promote memref results to
|
|
// descriptors.
|
|
assert(results.size() == resultTypes.size() &&
|
|
"The number of arguments and types doesn't match");
|
|
this->getTypeConverter()->promoteBarePtrsToDescriptors(
|
|
rewriter, callOp.getLoc(), resultTypes, results);
|
|
} else if (failed(copyUnrankedDescriptors(rewriter, callOp.getLoc(),
|
|
*this->getTypeConverter(),
|
|
resultTypes, results,
|
|
/*toDynamic=*/false))) {
|
|
return failure();
|
|
}
|
|
|
|
rewriter.replaceOp(callOp, results);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct CallOpLowering : public CallOpInterfaceLowering<CallOp> {
|
|
using Super::Super;
|
|
};
|
|
|
|
struct CallIndirectOpLowering : public CallOpInterfaceLowering<CallIndirectOp> {
|
|
using Super::Super;
|
|
};
|
|
|
|
// A `dealloc` is converted into a call to `free` on the underlying data buffer.
|
|
// The memref descriptor being an SSA value, there is no need to clean it up
|
|
// in any way.
|
|
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
|
|
using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
|
|
|
|
explicit DeallocOpLowering(LLVMTypeConverter &converter)
|
|
: ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(memref::DeallocOp op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
assert(operands.size() == 1 && "dealloc takes one operand");
|
|
memref::DeallocOp::Adaptor transformed(operands);
|
|
|
|
// Insert the `free` declaration if it is not already present.
|
|
auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
|
|
MemRefDescriptor memref(transformed.memref());
|
|
Value casted = rewriter.create<LLVM::BitcastOp>(
|
|
op.getLoc(), getVoidPtrType(),
|
|
memref.allocatedPtr(rewriter, op.getLoc()));
|
|
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
|
|
op, TypeRange(), rewriter.getSymbolRefAttr(freeFunc), casted);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Returns the LLVM type of the global variable given the memref type `type`.
|
|
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
|
|
LLVMTypeConverter &typeConverter) {
|
|
// LLVM type for a global memref will be a multi-dimensional array. For
|
|
// declarations or uninitialized global memrefs, we can potentially flatten
|
|
// this to a 1D array. However, for memref.global's with an initial value,
|
|
// we do not intend to flatten the ElementsAttribute when going from std ->
|
|
// LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
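// For illustration: a global of type memref<4x8xf32> becomes
// !llvm.array<4 x array<8 x f32>>.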
|
|
Type elementType = unwrap(typeConverter.convertType(type.getElementType()));
|
|
Type arrayTy = elementType;
|
|
// Shape has the outermost dim at index 0, so we need to walk it backwards.
|
|
for (int64_t dim : llvm::reverse(type.getShape()))
|
|
arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
|
|
return arrayTy;
|
|
}
|
|
|
|
/// GlobalMemrefOp is lowered to a LLVM Global Variable.
|
|
struct GlobalMemrefOpLowering
|
|
: public ConvertOpToLLVMPattern<memref::GlobalOp> {
|
|
using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(memref::GlobalOp global, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
MemRefType type = global.type().cast<MemRefType>();
|
|
if (!isConvertibleAndHasIdentityMaps(type))
|
|
return failure();
|
|
|
|
Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
|
|
|
|
LLVM::Linkage linkage =
|
|
global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
|
|
|
|
Attribute initialValue = nullptr;
|
|
if (!global.isExternal() && !global.isUninitialized()) {
|
|
auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
|
|
initialValue = elementsAttr;
|
|
|
|
// For scalar memrefs, the global variable created is of the element type,
|
|
// so unpack the elements attribute to extract the value.
|
|
if (type.getRank() == 0)
|
|
initialValue = elementsAttr.getValue({});
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
|
|
global, arrayTy, global.constant(), linkage, global.sym_name(),
|
|
initialValue, /*alignment=*/0, type.getMemorySpaceAsInt());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
|
|
/// the first element stashed into the descriptor. It builds on
/// `AllocLikeOpLowering` to reuse the MemRef descriptor construction.
|
|
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
|
|
GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
|
|
: AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
|
|
converter) {}
|
|
|
|
/// Buffer "allocation" for memref.get_global op is getting the address of
|
|
/// the global variable referenced.
|
|
std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
|
|
Location loc, Value sizeBytes,
|
|
Operation *op) const override {
|
|
auto getGlobalOp = cast<memref::GetGlobalOp>(op);
|
|
MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
|
|
unsigned memSpace = type.getMemorySpaceAsInt();
|
|
|
|
Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
|
|
auto addressOf = rewriter.create<LLVM::AddressOfOp>(
|
|
loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());
|
|
|
|
// Get the address of the first element in the array by creating a GEP with
|
|
// the address of the GV as the base, and (rank + 1) number of 0 indices.
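// E.g. (illustrative), for a rank-2 global this emits a GEP with indices
// [0, 0, 0], yielding a pointer to the element type.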
|
|
Type elementType =
|
|
unwrap(typeConverter->convertType(type.getElementType()));
|
|
Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);
|
|
|
|
SmallVector<Value, 4> operands = {addressOf};
|
|
operands.insert(operands.end(), type.getRank() + 1,
|
|
createIndexConstant(rewriter, loc, 0));
|
|
auto gep = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, operands);
|
|
|
|
// We do not expect the memref obtained using `memref.get_global` to be
|
|
// ever deallocated. Set the allocated pointer to a known bad value to
|
|
// help debug if that ever happens.
|
|
auto intPtrType = getIntPtrType(memSpace);
|
|
Value deadBeefConst =
|
|
createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
|
|
auto deadBeefPtr =
|
|
rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
|
|
|
|
// Both allocated and aligned pointers are the same. We could potentially stash
|
|
// a nullptr for the allocated pointer since we do not expect any dealloc.
|
|
return std::make_tuple(deadBeefPtr, gep);
|
|
}
|
|
};
|
|
|
|
// An `expm1` is converted into `exp - 1`.
|
|
struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
|
|
using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(math::ExpM1Op op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
math::ExpM1Op::Adaptor transformed(operands);
|
|
auto operandType = transformed.operand().getType();
|
|
|
|
if (!operandType || !LLVM::isCompatibleType(operandType))
|
|
return failure();
|
|
|
|
auto loc = op.getLoc();
|
|
auto resultType = op.getResult().getType();
|
|
auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
|
|
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
|
|
|
|
if (!operandType.isa<LLVM::LLVMArrayType>()) {
|
|
LLVM::ConstantOp one;
|
|
if (LLVM::isCompatibleVectorType(operandType)) {
|
|
one = rewriter.create<LLVM::ConstantOp>(
|
|
loc, operandType,
|
|
SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
|
|
} else {
|
|
one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
|
|
}
|
|
auto exp = rewriter.create<LLVM::ExpOp>(loc, transformed.operand());
|
|
rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
|
|
return success();
|
|
}
|
|
|
|
auto vectorType = resultType.dyn_cast<VectorType>();
|
|
if (!vectorType)
|
|
return rewriter.notifyMatchFailure(op, "expected vector result type");
|
|
|
|
return handleMultidimensionalVectors(
|
|
op.getOperation(), operands, *getTypeConverter(),
|
|
[&](Type llvm1DVectorTy, ValueRange operands) {
|
|
auto splatAttr = SplatElementsAttr::get(
|
|
mlir::VectorType::get(
|
|
{LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
|
|
floatType),
|
|
floatOne);
|
|
auto one =
|
|
rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
|
|
auto exp =
|
|
rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
|
|
return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
|
|
},
|
|
rewriter);
|
|
}
|
|
};
|
|
|
|
// A `log1p` is converted into `log(1 + ...)`.
|
|
struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
|
|
using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(math::Log1pOp op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
math::Log1pOp::Adaptor transformed(operands);
|
|
auto operandType = transformed.operand().getType();
|
|
|
|
if (!operandType || !LLVM::isCompatibleType(operandType))
|
|
return rewriter.notifyMatchFailure(op, "unsupported operand type");
|
|
|
|
auto loc = op.getLoc();
|
|
auto resultType = op.getResult().getType();
|
|
auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
|
|
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
|
|
|
|
if (!operandType.isa<LLVM::LLVMArrayType>()) {
|
|
LLVM::ConstantOp one =
|
|
LLVM::isCompatibleVectorType(operandType)
|
|
? rewriter.create<LLVM::ConstantOp>(
|
|
loc, operandType,
|
|
SplatElementsAttr::get(resultType.cast<ShapedType>(),
|
|
floatOne))
|
|
: rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
|
|
|
|
auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
|
|
transformed.operand());
|
|
rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
|
|
return success();
|
|
}
|
|
|
|
auto vectorType = resultType.dyn_cast<VectorType>();
|
|
if (!vectorType)
|
|
return rewriter.notifyMatchFailure(op, "expected vector result type");
|
|
|
|
return handleMultidimensionalVectors(
|
|
op.getOperation(), operands, *getTypeConverter(),
|
|
[&](Type llvm1DVectorTy, ValueRange operands) {
|
|
auto splatAttr = SplatElementsAttr::get(
|
|
mlir::VectorType::get(
|
|
{LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
|
|
floatType),
|
|
floatOne);
|
|
auto one =
|
|
rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
|
|
auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
|
|
operands[0]);
|
|
return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
|
|
},
|
|
rewriter);
|
|
}
|
|
};
|
|
|
|
// A `rsqrt` is converted into `1 / sqrt`.
|
|
struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
|
|
using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(math::RsqrtOp op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
math::RsqrtOp::Adaptor transformed(operands);
|
|
auto operandType = transformed.operand().getType();
|
|
|
|
if (!operandType || !LLVM::isCompatibleType(operandType))
|
|
return failure();
|
|
|
|
auto loc = op.getLoc();
|
|
auto resultType = op.getResult().getType();
|
|
auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
|
|
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
|
|
|
|
if (!operandType.isa<LLVM::LLVMArrayType>()) {
|
|
LLVM::ConstantOp one;
|
|
if (LLVM::isCompatibleVectorType(operandType)) {
|
|
one = rewriter.create<LLVM::ConstantOp>(
|
|
loc, operandType,
|
|
SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
|
|
} else {
|
|
one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
|
|
}
|
|
auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand());
|
|
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
|
|
return success();
|
|
}
|
|
|
|
auto vectorType = resultType.dyn_cast<VectorType>();
|
|
if (!vectorType)
|
|
return failure();
|
|
|
|
return handleMultidimensionalVectors(
|
|
op.getOperation(), operands, *getTypeConverter(),
|
|
[&](Type llvm1DVectorTy, ValueRange operands) {
|
|
auto splatAttr = SplatElementsAttr::get(
|
|
mlir::VectorType::get(
|
|
{LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
|
|
floatType),
|
|
floatOne);
|
|
auto one =
|
|
rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
|
|
auto sqrt =
|
|
rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
|
|
return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
|
|
},
|
|
rewriter);
|
|
}
|
|
};
|
|
|
|
struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
|
|
using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult match(memref::CastOp memRefCastOp) const override {
|
|
Type srcType = memRefCastOp.getOperand().getType();
|
|
Type dstType = memRefCastOp.getType();
|
|
|
|
// memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
// used for type erasure. For now it must preserve the underlying element type
// and requires the source and result types to have the same rank. Therefore,
|
|
// perform a sanity check that the underlying structs are the same. Once op
|
|
// semantics are relaxed we can revisit.
|
|
if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
|
|
return success(typeConverter->convertType(srcType) ==
|
|
typeConverter->convertType(dstType));
|
|
|
|
// At least one of the operands is an unranked memref type.
|
|
assert(srcType.isa<UnrankedMemRefType>() ||
|
|
dstType.isa<UnrankedMemRefType>());
|
|
|
|
// Unranked to unranked cast is disallowed
|
|
return !(srcType.isa<UnrankedMemRefType>() &&
|
|
dstType.isa<UnrankedMemRefType>())
|
|
? success()
|
|
: failure();
|
|
}
|
|
|
|
void rewrite(memref::CastOp memRefCastOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
memref::CastOp::Adaptor transformed(operands);
|
|
|
|
auto srcType = memRefCastOp.getOperand().getType();
|
|
auto dstType = memRefCastOp.getType();
|
|
auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
|
|
auto loc = memRefCastOp.getLoc();
|
|
|
|
// For ranked/ranked case, just keep the original descriptor.
|
|
if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
|
|
return rewriter.replaceOp(memRefCastOp, {transformed.source()});
|
|
|
|
if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
|
|
// Casting ranked to unranked memref type
|
|
// Set the rank in the destination from the memref type
|
|
// Allocate space on the stack and copy the src memref descriptor
|
|
// Set the ptr in the destination to the stack space
|
|
auto srcMemRefType = srcType.cast<MemRefType>();
|
|
int64_t rank = srcMemRefType.getRank();
|
|
// ptr = AllocaOp sizeof(MemRefDescriptor)
|
|
auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
|
|
loc, transformed.source(), rewriter);
|
|
// voidptr = BitCastOp srcType* to void*
|
|
auto voidPtr =
|
|
rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
|
|
.getResult();
|
|
// rank = ConstantOp srcRank
|
|
auto rankVal = rewriter.create<LLVM::ConstantOp>(
|
|
loc, typeConverter->convertType(rewriter.getIntegerType(64)),
|
|
rewriter.getI64IntegerAttr(rank));
|
|
// undef = UndefOp
|
|
UnrankedMemRefDescriptor memRefDesc =
|
|
UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
|
|
// d1 = InsertValueOp undef, rank, 0
|
|
memRefDesc.setRank(rewriter, loc, rankVal);
|
|
// d2 = InsertValueOp d1, voidptr, 1
|
|
memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
|
|
rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
|
|
|
|
} else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
|
|
// Casting from unranked type to ranked.
|
|
// The operation is assumed to be doing a correct cast. If the destination
|
|
// type mismatches the runtime type of the source, it is undefined behavior.
|
|
UnrankedMemRefDescriptor memRefDesc(transformed.source());
|
|
// ptr = ExtractValueOp src, 1
|
|
auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
|
|
// castPtr = BitCastOp i8* to structTy*
|
|
auto castPtr =
|
|
rewriter
|
|
.create<LLVM::BitcastOp>(
|
|
loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
|
|
.getResult();
|
|
// struct = LoadOp castPtr
|
|
auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
|
|
rewriter.replaceOp(memRefCastOp, loadOp.getResult());
|
|
} else {
|
|
llvm_unreachable("Unsupported unranked memref to unranked memref cast");
|
|
}
|
|
}
|
|
};
|
|
|
|
/// Extracts allocated, aligned pointers and offset from a ranked or unranked
|
|
/// memref type. In unranked case, the fields are extracted from the underlying
|
|
/// ranked descriptor.
|
|
static void extractPointersAndOffset(Location loc,
|
|
ConversionPatternRewriter &rewriter,
|
|
LLVMTypeConverter &typeConverter,
|
|
Value originalOperand,
|
|
Value convertedOperand,
|
|
Value *allocatedPtr, Value *alignedPtr,
|
|
Value *offset = nullptr) {
|
|
Type operandType = originalOperand.getType();
|
|
if (operandType.isa<MemRefType>()) {
|
|
MemRefDescriptor desc(convertedOperand);
|
|
*allocatedPtr = desc.allocatedPtr(rewriter, loc);
|
|
*alignedPtr = desc.alignedPtr(rewriter, loc);
|
|
if (offset != nullptr)
|
|
*offset = desc.offset(rewriter, loc);
|
|
return;
|
|
}
|
|
|
|
unsigned memorySpace =
|
|
operandType.cast<UnrankedMemRefType>().getMemorySpaceAsInt();
|
|
Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
|
|
Type llvmElementType = unwrap(typeConverter.convertType(elementType));
|
|
Type elementPtrPtrType = LLVM::LLVMPointerType::get(
|
|
LLVM::LLVMPointerType::get(llvmElementType, memorySpace));
|
|
|
|
// Extract pointer to the underlying ranked memref descriptor and cast it to
|
|
// ElemType**.
|
|
UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
|
|
Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);
|
|
|
|
*allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
|
|
rewriter, loc, underlyingDescPtr, elementPtrPtrType);
|
|
*alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
|
|
rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
|
|
if (offset != nullptr) {
|
|
*offset = UnrankedMemRefDescriptor::offset(
|
|
rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
|
|
}
|
|
}
|
|
|
|
struct MemRefReinterpretCastOpLowering
|
|
: public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
|
|
using ConvertOpToLLVMPattern<
|
|
memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(memref::ReinterpretCastOp castOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
memref::ReinterpretCastOp::Adaptor adaptor(operands,
|
|
castOp->getAttrDictionary());
|
|
Type srcType = castOp.source().getType();
|
|
|
|
Value descriptor;
|
|
if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
|
|
adaptor, &descriptor)))
|
|
return failure();
|
|
rewriter.replaceOp(castOp, {descriptor});
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
LogicalResult convertSourceMemRefToDescriptor(
|
|
ConversionPatternRewriter &rewriter, Type srcType,
|
|
memref::ReinterpretCastOp castOp,
|
|
memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
|
|
MemRefType targetMemRefType =
|
|
castOp.getResult().getType().cast<MemRefType>();
|
|
auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
|
|
.dyn_cast_or_null<LLVM::LLVMStructType>();
|
|
if (!llvmTargetDescriptorTy)
|
|
return failure();
|
|
|
|
// Create descriptor.
|
|
Location loc = castOp.getLoc();
|
|
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
|
|
|
|
// Set allocated and aligned pointers.
|
|
Value allocatedPtr, alignedPtr;
|
|
extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
|
|
castOp.source(), adaptor.source(), &allocatedPtr,
|
|
&alignedPtr);
|
|
desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
|
|
desc.setAlignedPtr(rewriter, loc, alignedPtr);
|
|
|
|
// Set offset.
|
|
if (castOp.isDynamicOffset(0))
|
|
desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
|
|
else
|
|
desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
|
|
|
|
// Set sizes and strides.
|
|
unsigned dynSizeId = 0;
|
|
unsigned dynStrideId = 0;
|
|
for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
|
|
if (castOp.isDynamicSize(i))
|
|
desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
|
|
else
|
|
desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
|
|
|
|
if (castOp.isDynamicStride(i))
|
|
desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
|
|
else
|
|
desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
|
|
}
|
|
*descriptor = desc;
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct MemRefReshapeOpLowering
|
|
: public ConvertOpToLLVMPattern<memref::ReshapeOp> {
|
|
using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(memref::ReshapeOp reshapeOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto *op = reshapeOp.getOperation();
|
|
memref::ReshapeOp::Adaptor adaptor(operands, op->getAttrDictionary());
|
|
Type srcType = reshapeOp.source().getType();
|
|
|
|
Value descriptor;
|
|
if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
|
|
adaptor, &descriptor)))
|
|
return failure();
|
|
rewriter.replaceOp(op, {descriptor});
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
LogicalResult
|
|
convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
|
|
Type srcType, memref::ReshapeOp reshapeOp,
|
|
memref::ReshapeOp::Adaptor adaptor,
|
|
Value *descriptor) const {
|
|
// Conversion for statically-known shape args is performed via
|
|
// `memref_reinterpret_cast`.
|
|
auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
|
|
if (shapeMemRefType.hasStaticShape())
|
|
return failure();
|
|
|
|
// The shape is a rank-1 tensor with unknown length.
|
|
Location loc = reshapeOp.getLoc();
|
|
MemRefDescriptor shapeDesc(adaptor.shape());
|
|
Value resultRank = shapeDesc.size(rewriter, loc, 0);
|
|
|
|
// Extract address space and element type.
|
|
auto targetType =
|
|
reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
|
|
unsigned addressSpace = targetType.getMemorySpaceAsInt();
|
|
Type elementType = targetType.getElementType();
|
|
|
|
// Create the unranked memref descriptor that holds the ranked one. The
|
|
// inner descriptor is allocated on stack.
|
|
auto targetDesc = UnrankedMemRefDescriptor::undef(
|
|
rewriter, loc, unwrap(typeConverter->convertType(targetType)));
|
|
targetDesc.setRank(rewriter, loc, resultRank);
|
|
SmallVector<Value, 4> sizes;
|
|
UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
|
|
targetDesc, sizes);
|
|
Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
|
|
loc, getVoidPtrType(), sizes.front(), llvm::None);
|
|
targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
|
|
|
|
// Extract pointers and offset from the source memref.
|
|
Value allocatedPtr, alignedPtr, offset;
|
|
extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
|
|
reshapeOp.source(), adaptor.source(),
|
|
&allocatedPtr, &alignedPtr, &offset);
|
|
|
|
// Set pointers and offset.
|
|
Type llvmElementType = unwrap(typeConverter->convertType(elementType));
|
|
auto elementPtrPtrType = LLVM::LLVMPointerType::get(
|
|
LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
|
|
UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
|
|
elementPtrPtrType, allocatedPtr);
|
|
UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
|
|
underlyingDescPtr,
|
|
elementPtrPtrType, alignedPtr);
|
|
UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
|
|
underlyingDescPtr, elementPtrPtrType,
|
|
offset);
|
|
|
|
// Use the offset pointer as base for further addressing. Copy over the new
|
|
// shape and compute strides. For this, we create a loop from rank-1 to 0.
|
|
Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
|
|
rewriter, loc, *getTypeConverter(), underlyingDescPtr,
|
|
elementPtrPtrType);
|
|
Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
|
|
rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
|
|
Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
|
|
Value oneIndex = createIndexConstant(rewriter, loc, 1);
|
|
Value resultRankMinusOne =
|
|
rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
|
|
|
|
Block *initBlock = rewriter.getInsertionBlock();
|
|
Type indexType = getTypeConverter()->getIndexType();
|
|
Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
|
|
|
|
Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
|
|
{indexType, indexType});
|
|
|
|
// Iterate over the remaining ops in initBlock and move them to condBlock.
|
|
BlockAndValueMapping map;
|
|
for (auto it = remainingOpsIt, e = initBlock->end(); it != e; ++it) {
|
|
rewriter.clone(*it, map);
|
|
rewriter.eraseOp(&*it);
|
|
}
|
|
|
|
rewriter.setInsertionPointToEnd(initBlock);
|
|
rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
|
|
condBlock);
|
|
rewriter.setInsertionPointToStart(condBlock);
|
|
Value indexArg = condBlock->getArgument(0);
|
|
Value strideArg = condBlock->getArgument(1);
|
|
|
|
Value zeroIndex = createIndexConstant(rewriter, loc, 0);
|
|
Value pred = rewriter.create<LLVM::ICmpOp>(
|
|
loc, IntegerType::get(rewriter.getContext(), 1),
|
|
LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
|
|
|
|
Block *bodyBlock =
|
|
rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
|
|
rewriter.setInsertionPointToStart(bodyBlock);
|
|
|
|
// Copy size from shape to descriptor.
|
|
Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
|
|
Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
|
|
loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
|
|
Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
|
|
UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
|
|
targetSizesBase, indexArg, size);
|
|
|
|
// Write stride value and compute next one.
|
|
UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
|
|
targetStridesBase, indexArg, strideArg);
|
|
Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);
|
|
|
|
// Decrement loop counter and branch back.
|
|
Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
|
|
rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
|
|
condBlock);
|
|
|
|
Block *remainder =
|
|
rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
|
|
|
|
// Hook up the cond exit to the remainder.
|
|
rewriter.setInsertionPointToEnd(condBlock);
|
|
rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder,
|
|
llvm::None);
|
|
|
|
// Reset position to beginning of new remainder block.
|
|
rewriter.setInsertionPointToStart(remainder);
|
|
|
|
*descriptor = targetDesc;
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct DialectCastOpLowering
|
|
: public ConvertOpToLLVMPattern<LLVM::DialectCastOp> {
|
|
using ConvertOpToLLVMPattern<LLVM::DialectCastOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(LLVM::DialectCastOp castOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
LLVM::DialectCastOp::Adaptor transformed(operands);
|
|
if (transformed.in().getType() !=
|
|
typeConverter->convertType(castOp.getType())) {
|
|
return failure();
|
|
}
|
|
rewriter.replaceOp(castOp, transformed.in());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// A `dim` is converted to a constant for static sizes and to an access to the
|
|
// size stored in the memref descriptor for dynamic sizes.
|
|
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
|
|
using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(memref::DimOp dimOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
Type operandType = dimOp.memrefOrTensor().getType();
|
|
if (operandType.isa<UnrankedMemRefType>()) {
|
|
rewriter.replaceOp(dimOp, {extractSizeOfUnrankedMemRef(
|
|
operandType, dimOp, operands, rewriter)});
|
|
|
|
return success();
|
|
}
|
|
if (operandType.isa<MemRefType>()) {
|
|
rewriter.replaceOp(dimOp, {extractSizeOfRankedMemRef(
|
|
operandType, dimOp, operands, rewriter)});
|
|
return success();
|
|
}
|
|
return failure();
|
|
}
|
|
|
|
private:
|
|
Value extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
|
|
ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const {
|
|
Location loc = dimOp.getLoc();
|
|
memref::DimOp::Adaptor transformed(operands);
|
|
|
|
auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
|
|
auto scalarMemRefType =
|
|
MemRefType::get({}, unrankedMemRefType.getElementType());
|
|
unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt();
|
|
|
|
// Extract pointer to the underlying ranked descriptor and bitcast it to a
|
|
// memref<element_type> descriptor pointer to minimize the number of GEP
|
|
// operations.
|
|
UnrankedMemRefDescriptor unrankedDesc(transformed.memrefOrTensor());
|
|
Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
|
|
Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
|
|
loc,
|
|
LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
|
|
addressSpace),
|
|
underlyingRankedDesc);
|
|
|
|
// Get pointer to offset field of memref<element_type> descriptor.
|
|
Type indexPtrTy = LLVM::LLVMPointerType::get(
|
|
getTypeConverter()->getIndexType(), addressSpace);
|
|
Value two = rewriter.create<LLVM::ConstantOp>(
|
|
loc, typeConverter->convertType(rewriter.getI32Type()),
|
|
rewriter.getI32IntegerAttr(2));
|
|
Value offsetPtr = rewriter.create<LLVM::GEPOp>(
|
|
loc, indexPtrTy, scalarMemRefDescPtr,
|
|
ValueRange({createIndexConstant(rewriter, loc, 0), two}));
|
|
|
|
// The size value that we have to extract can be obtained using GEPop with
|
|
// `dimOp.index() + 1` index argument.
|
|
Value idxPlusOne = rewriter.create<LLVM::AddOp>(
|
|
loc, createIndexConstant(rewriter, loc, 1), transformed.index());
|
|
Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
|
|
ValueRange({idxPlusOne}));
|
|
return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
|
|
}
|
|
|
|
Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
|
|
ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const {
|
|
Location loc = dimOp.getLoc();
|
|
memref::DimOp::Adaptor transformed(operands);
|
|
// Take advantage if index is constant.
|
|
MemRefType memRefType = operandType.cast<MemRefType>();
|
|
if (Optional<int64_t> index = dimOp.getConstantIndex()) {
|
|
int64_t i = index.getValue();
|
|
if (memRefType.isDynamicDim(i)) {
|
|
// extract dynamic size from the memref descriptor.
|
|
MemRefDescriptor descriptor(transformed.memrefOrTensor());
|
|
return descriptor.size(rewriter, loc, i);
|
|
}
|
|
// Use constant for static size.
|
|
int64_t dimSize = memRefType.getDimSize(i);
|
|
return createIndexConstant(rewriter, loc, dimSize);
|
|
}
|
|
Value index = dimOp.index();
|
|
int64_t rank = memRefType.getRank();
|
|
MemRefDescriptor memrefDescriptor(transformed.memrefOrTensor());
|
|
return memrefDescriptor.size(rewriter, loc, index, rank);
|
|
}
|
|
};
|
|
|
|
struct RankOpLowering : public ConvertOpToLLVMPattern<RankOp> {
|
|
using ConvertOpToLLVMPattern<RankOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(RankOp op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
Location loc = op.getLoc();
|
|
Type operandType = op.memrefOrTensor().getType();
|
|
if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
|
|
UnrankedMemRefDescriptor desc(RankOp::Adaptor(operands).memrefOrTensor());
|
|
rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
|
|
return success();
|
|
}
|
|
if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
|
|
rewriter.replaceOp(
|
|
op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
|
|
return success();
|
|
}
|
|
return failure();
|
|
}
|
|
};
|
|
|
|
// Common base for load and store operations on MemRefs. Restricts the match
|
|
// to supported MemRef types. Provides functionality to emit code accessing a
|
|
// specific element of the underlying data buffer.
|
|
template <typename Derived>
|
|
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
|
|
using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
|
|
using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
|
|
using Base = LoadStoreOpLowering<Derived>;
|
|
|
|
LogicalResult match(Derived op) const override {
|
|
MemRefType type = op.getMemRefType();
|
|
return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
|
|
}
|
|
};
|
|
|
|
// Load operation is lowered to obtaining a pointer to the indexed element
|
|
// and loading it.
|
|
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
|
|
using Base::Base;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(memref::LoadOp loadOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
memref::LoadOp::Adaptor transformed(operands);
|
|
auto type = loadOp.getMemRefType();
|
|
|
|
Value dataPtr =
|
|
getStridedElementPtr(loadOp.getLoc(), type, transformed.memref(),
|
|
transformed.indices(), rewriter);
|
|
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// Store operation is lowered to obtaining a pointer to the indexed element,
|
|
// and storing the given value to it.
|
|
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
|
|
using Base::Base;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(memref::StoreOp op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto type = op.getMemRefType();
|
|
memref::StoreOp::Adaptor transformed(operands);
|
|
|
|
Value dataPtr =
|
|
getStridedElementPtr(op.getLoc(), type, transformed.memref(),
|
|
transformed.indices(), rewriter);
|
|
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
|
|
dataPtr);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// The prefetch operation is lowered in a way similar to the load operation
|
|
// except that the llvm.prefetch operation is used for replacement.
|
|
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
|
|
using Base::Base;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(memref::PrefetchOp prefetchOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
memref::PrefetchOp::Adaptor transformed(operands);
|
|
auto type = prefetchOp.getMemRefType();
|
|
auto loc = prefetchOp.getLoc();
|
|
|
|
Value dataPtr = getStridedElementPtr(loc, type, transformed.memref(),
|
|
transformed.indices(), rewriter);
|
|
|
|
// Replace with llvm.prefetch.
|
|
auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
|
|
auto isWrite = rewriter.create<LLVM::ConstantOp>(
|
|
loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
|
|
auto localityHint = rewriter.create<LLVM::ConstantOp>(
|
|
loc, llvmI32Type,
|
|
rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
|
|
auto isData = rewriter.create<LLVM::ConstantOp>(
|
|
loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));
|
|
|
|
rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
|
|
localityHint, isData);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// The lowering of index_cast becomes an integer conversion since index becomes
|
|
// an integer. If the bit width of the source and target integer types is the
|
|
// same, just erase the cast. If the target type is wider, sign-extend the
|
|
// value, otherwise truncate it.
|
|
struct IndexCastOpLowering : public ConvertOpToLLVMPattern<IndexCastOp> {
|
|
using ConvertOpToLLVMPattern<IndexCastOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(IndexCastOp indexCastOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
IndexCastOpAdaptor transformed(operands);
|
|
|
|
auto targetType =
|
|
typeConverter->convertType(indexCastOp.getResult().getType())
|
|
.cast<IntegerType>();
|
|
auto sourceType = transformed.in().getType().cast<IntegerType>();
|
|
unsigned targetBits = targetType.getWidth();
|
|
unsigned sourceBits = sourceType.getWidth();
|
|
|
|
if (targetBits == sourceBits)
|
|
rewriter.replaceOp(indexCastOp, transformed.in());
|
|
else if (targetBits < sourceBits)
|
|
rewriter.replaceOpWithNewOp<LLVM::TruncOp>(indexCastOp, targetType,
|
|
transformed.in());
|
|
else
|
|
rewriter.replaceOpWithNewOp<LLVM::SExtOp>(indexCastOp, targetType,
|
|
transformed.in());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// Convert std.cmp predicate into the LLVM dialect CmpPredicate. The two
|
|
// enums share the numerical values so just cast.
|
|
template <typename LLVMPredType, typename StdPredType>
|
|
static LLVMPredType convertCmpPredicate(StdPredType pred) {
|
|
return static_cast<LLVMPredType>(pred);
|
|
}
|
|
|
|
struct CmpIOpLowering : public ConvertOpToLLVMPattern<CmpIOp> {
|
|
using ConvertOpToLLVMPattern<CmpIOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(CmpIOp cmpiOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
CmpIOpAdaptor transformed(operands);
|
|
auto operandType = transformed.lhs().getType();
|
|
auto resultType = cmpiOp.getResult().getType();
|
|
|
|
// Handle the scalar and 1D vector cases.
|
|
if (!operandType.isa<LLVM::LLVMArrayType>()) {
|
|
rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
|
|
cmpiOp, typeConverter->convertType(resultType),
|
|
convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()),
|
|
transformed.lhs(), transformed.rhs());
|
|
return success();
|
|
}
|
|
|
|
auto vectorType = resultType.dyn_cast<VectorType>();
|
|
if (!vectorType)
|
|
return rewriter.notifyMatchFailure(cmpiOp, "expected vector result type");
|
|
|
|
return handleMultidimensionalVectors(
|
|
cmpiOp.getOperation(), operands, *getTypeConverter(),
|
|
[&](Type llvm1DVectorTy, ValueRange operands) {
|
|
CmpIOpAdaptor transformed(operands);
|
|
return rewriter.create<LLVM::ICmpOp>(
|
|
cmpiOp.getLoc(), llvm1DVectorTy,
|
|
convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()),
|
|
transformed.lhs(), transformed.rhs());
|
|
},
|
|
rewriter);
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
|
|
using ConvertOpToLLVMPattern<CmpFOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(CmpFOp cmpfOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
CmpFOpAdaptor transformed(operands);
|
|
auto operandType = transformed.lhs().getType();
|
|
auto resultType = cmpfOp.getResult().getType();
|
|
|
|
// Handle the scalar and 1D vector cases.
|
|
if (!operandType.isa<LLVM::LLVMArrayType>()) {
|
|
rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
|
|
cmpfOp, typeConverter->convertType(resultType),
|
|
convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()),
|
|
transformed.lhs(), transformed.rhs());
|
|
return success();
|
|
}
|
|
|
|
auto vectorType = resultType.dyn_cast<VectorType>();
|
|
if (!vectorType)
|
|
return rewriter.notifyMatchFailure(cmpfOp, "expected vector result type");
|
|
|
|
return handleMultidimensionalVectors(
|
|
cmpfOp.getOperation(), operands, *getTypeConverter(),
|
|
[&](Type llvm1DVectorTy, ValueRange operands) {
|
|
CmpFOpAdaptor transformed(operands);
|
|
return rewriter.create<LLVM::FCmpOp>(
|
|
cmpfOp.getLoc(), llvm1DVectorTy,
|
|
convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()),
|
|
transformed.lhs(), transformed.rhs());
|
|
},
|
|
rewriter);
|
|
}
|
|
};
|
|
|
|
// Base class for LLVM IR lowering terminator operations with successors.
|
|
template <typename SourceOp, typename TargetOp>
|
|
struct OneToOneLLVMTerminatorLowering
|
|
: public ConvertOpToLLVMPattern<SourceOp> {
|
|
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
|
|
using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
rewriter.replaceOpWithNewOp<TargetOp>(op, operands, op->getSuccessors(),
|
|
op->getAttrs());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// Special lowering pattern for `ReturnOps`. Unlike all other operations,
|
|
// `ReturnOp` interacts with the function signature and must have as many
|
|
// operands as the function has return values. Because in LLVM IR, functions
|
|
// can only return 0 or 1 value, we pack multiple values into a structure type.
|
|
// Emit `UndefOp` followed by `InsertValueOp`s to create such structure if
|
|
// necessary before returning it
|
|
struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
|
|
using ConvertOpToLLVMPattern<ReturnOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
Location loc = op.getLoc();
|
|
unsigned numArguments = op.getNumOperands();
|
|
SmallVector<Value, 4> updatedOperands;
|
|
|
|
if (getTypeConverter()->getOptions().useBarePtrCallConv) {
|
|
// For the bare-ptr calling convention, extract the aligned pointer to
|
|
// be returned from the memref descriptor.
|
|
for (auto it : llvm::zip(op->getOperands(), operands)) {
|
|
Type oldTy = std::get<0>(it).getType();
|
|
Value newOperand = std::get<1>(it);
|
|
if (oldTy.isa<MemRefType>()) {
|
|
MemRefDescriptor memrefDesc(newOperand);
|
|
newOperand = memrefDesc.alignedPtr(rewriter, loc);
|
|
} else if (oldTy.isa<UnrankedMemRefType>()) {
|
|
// Unranked memref is not supported in the bare pointer calling
|
|
// convention.
|
|
return failure();
|
|
}
|
|
updatedOperands.push_back(newOperand);
|
|
}
|
|
} else {
|
|
updatedOperands = llvm::to_vector<4>(operands);
|
|
(void)copyUnrankedDescriptors(rewriter, loc, *getTypeConverter(),
|
|
op.getOperands().getTypes(),
|
|
updatedOperands,
|
|
/*toDynamic=*/true);
|
|
}
|
|
|
|
// If ReturnOp has 0 or 1 operand, create it and return immediately.
|
|
if (numArguments == 0) {
|
|
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
|
|
op->getAttrs());
|
|
return success();
|
|
}
|
|
if (numArguments == 1) {
|
|
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
|
|
op, TypeRange(), updatedOperands, op->getAttrs());
|
|
return success();
|
|
}
|
|
|
|
// Otherwise, we need to pack the arguments into an LLVM struct type before
|
|
// returning.
|
|
auto packedType = getTypeConverter()->packFunctionResults(
|
|
llvm::to_vector<4>(op.getOperandTypes()));
|
|
|
|
Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
|
|
for (unsigned i = 0; i < numArguments; ++i) {
|
|
packed = rewriter.create<LLVM::InsertValueOp>(
|
|
loc, packedType, packed, updatedOperands[i],
|
|
rewriter.getI64ArrayAttr(i));
|
|
}
|
|
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
|
|
op->getAttrs());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// FIXME: this should be tablegen'ed as well.
|
|
struct BranchOpLowering
|
|
: public OneToOneLLVMTerminatorLowering<BranchOp, LLVM::BrOp> {
|
|
using Super::Super;
|
|
};
|
|
struct CondBranchOpLowering
|
|
: public OneToOneLLVMTerminatorLowering<CondBranchOp, LLVM::CondBrOp> {
|
|
using Super::Super;
|
|
};
|
|
|
|
// The Splat operation is lowered to an insertelement + a shufflevector
|
|
// operation. Splat to only 1-d vector result types are lowered.
|
|
struct SplatOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
|
|
using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(SplatOp splatOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
|
|
if (!resultType || resultType.getRank() != 1)
|
|
return failure();
|
|
|
|
// First insert it into an undef vector so we can shuffle it.
|
|
auto vectorType = typeConverter->convertType(splatOp.getType());
|
|
Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
|
|
auto zero = rewriter.create<LLVM::ConstantOp>(
|
|
splatOp.getLoc(),
|
|
typeConverter->convertType(rewriter.getIntegerType(32)),
|
|
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
|
|
|
|
auto v = rewriter.create<LLVM::InsertElementOp>(
|
|
splatOp.getLoc(), vectorType, undef, splatOp.getOperand(), zero);
|
|
|
|
int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
|
|
SmallVector<int32_t, 4> zeroValues(width, 0);
|
|
|
|
// Shuffle the value across the desired number of elements.
|
|
ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
|
|
rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
|
|
zeroAttrs);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// The Splat operation is lowered to an insertelement + a shufflevector
|
|
// operation. Splat to only 2+-d vector result types are lowered by the
|
|
// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
|
|
struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
|
|
using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(SplatOp splatOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
SplatOp::Adaptor adaptor(operands);
|
|
VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
|
|
if (!resultType || resultType.getRank() == 1)
|
|
return failure();
|
|
|
|
// First insert it into an undef vector so we can shuffle it.
|
|
auto loc = splatOp.getLoc();
|
|
auto vectorTypeInfo =
|
|
extractNDVectorTypeInfo(resultType, *getTypeConverter());
|
|
auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
|
|
auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
|
|
if (!llvmNDVectorTy || !llvm1DVectorTy)
|
|
return failure();
|
|
|
|
// Construct returned value.
|
|
Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);
|
|
|
|
// Construct a 1-D vector with the splatted value that we insert in all the
|
|
// places within the returned descriptor.
|
|
Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
|
|
auto zero = rewriter.create<LLVM::ConstantOp>(
|
|
loc, typeConverter->convertType(rewriter.getIntegerType(32)),
|
|
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
|
|
Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
|
|
adaptor.input(), zero);
|
|
|
|
// Shuffle the value across the desired number of elements.
|
|
int64_t width = resultType.getDimSize(resultType.getRank() - 1);
|
|
SmallVector<int32_t, 4> zeroValues(width, 0);
|
|
ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
|
|
v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);
|
|
|
|
// Iterate of linear index, convert to coords space and insert splatted 1-D
|
|
// vector in each position.
|
|
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
|
|
desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmNDVectorTy, desc, v,
|
|
position);
|
|
});
|
|
rewriter.replaceOp(splatOp, desc);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Helper function extracts int64_t from the assumedArrayAttr of IntegerAttr.
|
|
static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
|
|
return llvm::to_vector<4>(
|
|
llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
|
|
return a.cast<IntegerAttr>().getInt();
|
|
}));
|
|
}
|
|
|
|
/// Conversion pattern that transforms a subview op into:
|
|
/// 1. An `llvm.mlir.undef` operation to create a memref descriptor
|
|
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
|
|
/// and stride.
|
|
/// The subview op is replaced by the descriptor.
|
|
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
|
|
using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(memref::SubViewOp subViewOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto loc = subViewOp.getLoc();
|
|
|
|
auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
|
|
auto sourceElementTy =
|
|
typeConverter->convertType(sourceMemRefType.getElementType());
|
|
|
|
auto viewMemRefType = subViewOp.getType();
|
|
auto inferredType = memref::SubViewOp::inferResultType(
|
|
subViewOp.getSourceType(),
|
|
extractFromI64ArrayAttr(subViewOp.static_offsets()),
|
|
extractFromI64ArrayAttr(subViewOp.static_sizes()),
|
|
extractFromI64ArrayAttr(subViewOp.static_strides()))
|
|
.cast<MemRefType>();
|
|
auto targetElementTy =
|
|
typeConverter->convertType(viewMemRefType.getElementType());
|
|
auto targetDescTy = typeConverter->convertType(viewMemRefType);
|
|
if (!sourceElementTy || !targetDescTy || !targetElementTy ||
|
|
!LLVM::isCompatibleType(sourceElementTy) ||
|
|
!LLVM::isCompatibleType(targetElementTy) ||
|
|
!LLVM::isCompatibleType(targetDescTy))
|
|
return failure();
|
|
|
|
// Extract the offset and strides from the type.
|
|
int64_t offset;
|
|
SmallVector<int64_t, 4> strides;
|
|
auto successStrides = getStridesAndOffset(inferredType, strides, offset);
|
|
if (failed(successStrides))
|
|
return failure();
|
|
|
|
// Create the descriptor.
|
|
if (!LLVM::isCompatibleType(operands.front().getType()))
|
|
return failure();
|
|
MemRefDescriptor sourceMemRef(operands.front());
|
|
auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
|
|
|
|
// Copy the buffer pointer from the old descriptor to the new one.
|
|
Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
|
|
Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
|
|
loc,
|
|
LLVM::LLVMPointerType::get(targetElementTy,
|
|
viewMemRefType.getMemorySpaceAsInt()),
|
|
extracted);
|
|
targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
|
|
|
|
// Copy the aligned pointer from the old descriptor to the new one.
|
|
extracted = sourceMemRef.alignedPtr(rewriter, loc);
|
|
bitcastPtr = rewriter.create<LLVM::BitcastOp>(
|
|
loc,
|
|
LLVM::LLVMPointerType::get(targetElementTy,
|
|
viewMemRefType.getMemorySpaceAsInt()),
|
|
extracted);
|
|
targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
|
|
|
|
auto shape = viewMemRefType.getShape();
|
|
auto inferredShape = inferredType.getShape();
|
|
size_t inferredShapeRank = inferredShape.size();
|
|
size_t resultShapeRank = shape.size();
|
|
llvm::SmallDenseSet<unsigned> unusedDims =
|
|
computeRankReductionMask(inferredShape, shape).getValue();
|
|
|
|
// Extract strides needed to compute offset.
|
|
SmallVector<Value, 4> strideValues;
|
|
strideValues.reserve(inferredShapeRank);
|
|
for (unsigned i = 0; i < inferredShapeRank; ++i)
|
|
strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
|
|
|
|
// Offset.
|
|
auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
|
|
if (!ShapedType::isDynamicStrideOrOffset(offset)) {
|
|
targetMemRef.setConstantOffset(rewriter, loc, offset);
|
|
} else {
|
|
Value baseOffset = sourceMemRef.offset(rewriter, loc);
|
|
// `inferredShapeRank` may be larger than the number of offset operands
|
|
// because of trailing semantics. In this case, the offset is guaranteed
|
|
// to be interpreted as 0 and we can just skip the extra dimensions.
|
|
for (unsigned i = 0, e = std::min(inferredShapeRank,
|
|
subViewOp.getMixedOffsets().size());
|
|
i < e; ++i) {
|
|
Value offset =
|
|
// TODO: need OpFoldResult ODS adaptor to clean this up.
|
|
subViewOp.isDynamicOffset(i)
|
|
? operands[subViewOp.getIndexOfDynamicOffset(i)]
|
|
: rewriter.create<LLVM::ConstantOp>(
|
|
loc, llvmIndexType,
|
|
rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
|
|
Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
|
|
baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
|
|
}
|
|
targetMemRef.setOffset(rewriter, loc, baseOffset);
|
|
}
|
|
|
|
// Update sizes and strides.
|
|
SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
|
|
SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
|
|
assert(mixedSizes.size() == mixedStrides.size() &&
|
|
"expected sizes and strides of equal length");
|
|
for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
|
|
i >= 0 && j >= 0; --i) {
|
|
if (unusedDims.contains(i))
|
|
continue;
|
|
|
|
// `i` may overflow subViewOp.getMixedSizes because of trailing semantics.
|
|
// In this case, the size is guaranteed to be interpreted as Dim and the
|
|
// stride as 1.
|
|
Value size, stride;
|
|
if (static_cast<unsigned>(i) >= mixedSizes.size()) {
|
|
size = rewriter.create<LLVM::DialectCastOp>(
|
|
loc, llvmIndexType,
|
|
rewriter.create<memref::DimOp>(loc, subViewOp.source(), i));
|
|
stride = rewriter.create<LLVM::ConstantOp>(
|
|
loc, llvmIndexType, rewriter.getI64IntegerAttr(1));
|
|
} else {
|
|
// TODO: need OpFoldResult ODS adaptor to clean this up.
|
|
size =
|
|
subViewOp.isDynamicSize(i)
|
|
? operands[subViewOp.getIndexOfDynamicSize(i)]
|
|
: rewriter.create<LLVM::ConstantOp>(
|
|
loc, llvmIndexType,
|
|
rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
|
|
if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
|
|
stride = rewriter.create<LLVM::ConstantOp>(
|
|
loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
|
|
} else {
|
|
stride = subViewOp.isDynamicStride(i)
|
|
? operands[subViewOp.getIndexOfDynamicStride(i)]
|
|
: rewriter.create<LLVM::ConstantOp>(
|
|
loc, llvmIndexType,
|
|
rewriter.getI64IntegerAttr(
|
|
subViewOp.getStaticStride(i)));
|
|
stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
|
|
}
|
|
}
|
|
targetMemRef.setSize(rewriter, loc, j, size);
|
|
targetMemRef.setStride(rewriter, loc, j, stride);
|
|
j--;
|
|
}
|
|
|
|
rewriter.replaceOp(subViewOp, {targetMemRef});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Conversion pattern that transforms a transpose op into:
|
|
/// 1. A function entry `alloca` operation to allocate a ViewDescriptor.
|
|
/// 2. A load of the ViewDescriptor from the pointer allocated in 1.
|
|
/// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
|
|
/// and stride. Size and stride are permutations of the original values.
|
|
/// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
|
|
/// The transpose op is replaced by the alloca'ed pointer.
|
|
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
|
|
public:
|
|
using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(memref::TransposeOp transposeOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto loc = transposeOp.getLoc();
|
|
memref::TransposeOpAdaptor adaptor(operands);
|
|
MemRefDescriptor viewMemRef(adaptor.in());
|
|
|
|
// No permutation, early exit.
|
|
if (transposeOp.permutation().isIdentity())
|
|
return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
|
|
|
|
auto targetMemRef = MemRefDescriptor::undef(
|
|
rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
|
|
|
|
// Copy the base and aligned pointers from the old descriptor to the new
|
|
// one.
|
|
targetMemRef.setAllocatedPtr(rewriter, loc,
|
|
viewMemRef.allocatedPtr(rewriter, loc));
|
|
targetMemRef.setAlignedPtr(rewriter, loc,
|
|
viewMemRef.alignedPtr(rewriter, loc));
|
|
|
|
// Copy the offset pointer from the old descriptor to the new one.
|
|
targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
|
|
|
|
// Iterate over the dimensions and apply size/stride permutation.
|
|
for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) {
|
|
int sourcePos = en.index();
|
|
int targetPos = en.value().cast<AffineDimExpr>().getPosition();
|
|
targetMemRef.setSize(rewriter, loc, targetPos,
|
|
viewMemRef.size(rewriter, loc, sourcePos));
|
|
targetMemRef.setStride(rewriter, loc, targetPos,
|
|
viewMemRef.stride(rewriter, loc, sourcePos));
|
|
}
|
|
|
|
rewriter.replaceOp(transposeOp, {targetMemRef});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Conversion pattern that transforms an op into:
|
|
/// 1. An `llvm.mlir.undef` operation to create a memref descriptor
|
|
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
|
|
/// and stride.
|
|
/// The view op is replaced by the descriptor.
|
|
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
|
|
using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;
|
|
|
|
// Build and return the value for the idx^th shape dimension, either by
|
|
// returning the constant shape dimension or counting the proper dynamic size.
|
|
Value getSize(ConversionPatternRewriter &rewriter, Location loc,
|
|
ArrayRef<int64_t> shape, ValueRange dynamicSizes,
|
|
unsigned idx) const {
|
|
assert(idx < shape.size());
|
|
if (!ShapedType::isDynamic(shape[idx]))
|
|
return createIndexConstant(rewriter, loc, shape[idx]);
|
|
// Count the number of dynamic dims in range [0, idx]
|
|
unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
|
|
return ShapedType::isDynamic(v);
|
|
});
|
|
return dynamicSizes[nDynamic];
|
|
}
|
|
|
|
// Build and return the idx^th stride, either by returning the constant stride
|
|
// or by computing the dynamic stride from the current `runningStride` and
|
|
// `nextSize`. The caller should keep a running stride and update it with the
|
|
// result returned by this function.
|
|
Value getStride(ConversionPatternRewriter &rewriter, Location loc,
|
|
ArrayRef<int64_t> strides, Value nextSize,
|
|
Value runningStride, unsigned idx) const {
|
|
assert(idx < strides.size());
|
|
if (!MemRefType::isDynamicStrideOrOffset(strides[idx]))
|
|
return createIndexConstant(rewriter, loc, strides[idx]);
|
|
if (nextSize)
|
|
return runningStride
|
|
? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
|
|
: nextSize;
|
|
assert(!runningStride);
|
|
return createIndexConstant(rewriter, loc, 1);
|
|
}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(memref::ViewOp viewOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto loc = viewOp.getLoc();
|
|
memref::ViewOpAdaptor adaptor(operands);
|
|
|
|
auto viewMemRefType = viewOp.getType();
|
|
auto targetElementTy =
|
|
typeConverter->convertType(viewMemRefType.getElementType());
|
|
auto targetDescTy = typeConverter->convertType(viewMemRefType);
|
|
if (!targetDescTy || !targetElementTy ||
|
|
!LLVM::isCompatibleType(targetElementTy) ||
|
|
!LLVM::isCompatibleType(targetDescTy))
|
|
return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
|
|
failure();
|
|
|
|
int64_t offset;
|
|
SmallVector<int64_t, 4> strides;
|
|
auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
|
|
if (failed(successStrides))
|
|
return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
|
|
assert(offset == 0 && "expected offset to be 0");
|
|
|
|
// Create the descriptor.
|
|
MemRefDescriptor sourceMemRef(adaptor.source());
|
|
auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
|
|
|
|
// Field 1: Copy the allocated pointer, used for malloc/free.
|
|
Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
|
|
auto srcMemRefType = viewOp.source().getType().cast<MemRefType>();
|
|
Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
|
|
loc,
|
|
LLVM::LLVMPointerType::get(targetElementTy,
|
|
srcMemRefType.getMemorySpaceAsInt()),
|
|
allocatedPtr);
|
|
targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
|
|
|
|
// Field 2: Copy the actual aligned pointer to payload.
|
|
Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
|
|
alignedPtr = rewriter.create<LLVM::GEPOp>(loc, alignedPtr.getType(),
|
|
alignedPtr, adaptor.byte_shift());
|
|
bitcastPtr = rewriter.create<LLVM::BitcastOp>(
|
|
loc,
|
|
LLVM::LLVMPointerType::get(targetElementTy,
|
|
srcMemRefType.getMemorySpaceAsInt()),
|
|
alignedPtr);
|
|
targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
|
|
|
|
// Field 3: The offset in the resulting type must be 0. This is because of
|
|
// the type change: an offset on srcType* may not be expressible as an
|
|
// offset on dstType*.
|
|
targetMemRef.setOffset(rewriter, loc,
|
|
createIndexConstant(rewriter, loc, offset));
|
|
|
|
// Early exit for 0-D corner case.
|
|
if (viewMemRefType.getRank() == 0)
|
|
return rewriter.replaceOp(viewOp, {targetMemRef}), success();
|
|
|
|
// Fields 4 and 5: Update sizes and strides.
|
|
if (strides.back() != 1)
|
|
return viewOp.emitWarning("cannot cast to non-contiguous shape"),
|
|
failure();
|
|
Value stride = nullptr, nextSize = nullptr;
|
|
for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
|
|
// Update size.
|
|
Value size =
|
|
getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i);
|
|
targetMemRef.setSize(rewriter, loc, i, size);
|
|
// Update stride.
|
|
stride = getStride(rewriter, loc, strides, nextSize, stride, i);
|
|
targetMemRef.setStride(rewriter, loc, i, stride);
|
|
nextSize = size;
|
|
}
|
|
|
|
rewriter.replaceOp(viewOp, {targetMemRef});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct AssumeAlignmentOpLowering
|
|
: public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
|
|
using ConvertOpToLLVMPattern<
|
|
memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(memref::AssumeAlignmentOp op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
memref::AssumeAlignmentOp::Adaptor transformed(operands);
|
|
Value memref = transformed.memref();
|
|
unsigned alignment = op.alignment();
|
|
auto loc = op.getLoc();
|
|
|
|
MemRefDescriptor memRefDescriptor(memref);
|
|
Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
|
|
|
|
// Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
|
|
// the asserted memref.alignedPtr isn't used anywhere else, as the real
|
|
// users like load/store/views always re-extract memref.alignedPtr as they
|
|
// get lowered.
|
|
//
|
|
// This relies on LLVM's CSE optimization (potentially after SROA), since
|
|
// after CSE all memref.alignedPtr instances get de-duplicated into the same
|
|
// pointer SSA value.
|
|
auto intPtrType =
|
|
getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
|
|
Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
|
|
Value mask =
|
|
createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
|
|
Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
|
|
rewriter.create<LLVM::AssumeOp>(
|
|
loc, rewriter.create<LLVM::ICmpOp>(
|
|
loc, LLVM::ICmpPredicate::eq,
|
|
rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));
|
|
|
|
rewriter.eraseOp(op);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Try to match the kind of a std.atomic_rmw to determine whether to use a
|
|
/// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
|
|
static Optional<LLVM::AtomicBinOp> matchSimpleAtomicOp(AtomicRMWOp atomicOp) {
|
|
switch (atomicOp.kind()) {
|
|
case AtomicRMWKind::addf:
|
|
return LLVM::AtomicBinOp::fadd;
|
|
case AtomicRMWKind::addi:
|
|
return LLVM::AtomicBinOp::add;
|
|
case AtomicRMWKind::assign:
|
|
return LLVM::AtomicBinOp::xchg;
|
|
case AtomicRMWKind::maxs:
|
|
return LLVM::AtomicBinOp::max;
|
|
case AtomicRMWKind::maxu:
|
|
return LLVM::AtomicBinOp::umax;
|
|
case AtomicRMWKind::mins:
|
|
return LLVM::AtomicBinOp::min;
|
|
case AtomicRMWKind::minu:
|
|
return LLVM::AtomicBinOp::umin;
|
|
default:
|
|
return llvm::None;
|
|
}
|
|
llvm_unreachable("Invalid AtomicRMWKind");
|
|
}
|
|
|
|
namespace {
|
|
|
|
struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
|
|
using Base::Base;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(AtomicRMWOp atomicOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (failed(match(atomicOp)))
|
|
return failure();
|
|
auto maybeKind = matchSimpleAtomicOp(atomicOp);
|
|
if (!maybeKind)
|
|
return failure();
|
|
AtomicRMWOp::Adaptor adaptor(operands);
|
|
auto resultType = adaptor.value().getType();
|
|
auto memRefType = atomicOp.getMemRefType();
|
|
auto dataPtr =
|
|
getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
|
|
adaptor.indices(), rewriter);
|
|
rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
|
|
atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
|
|
LLVM::AtomicOrdering::acq_rel);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
|
|
/// retried until it succeeds in atomically storing a new value into memory.
|
|
///
|
|
/// +---------------------------------+
|
|
/// | <code before the AtomicRMWOp> |
|
|
/// | <compute initial %loaded> |
|
|
/// | br loop(%loaded) |
|
|
/// +---------------------------------+
|
|
/// |
|
|
/// -------| |
|
|
/// | v v
|
|
/// | +--------------------------------+
|
|
/// | | loop(%loaded): |
|
|
/// | | <body contents> |
|
|
/// | | %pair = cmpxchg |
|
|
/// | | %ok = %pair[0] |
|
|
/// | | %new = %pair[1] |
|
|
/// | | cond_br %ok, end, loop(%new) |
|
|
/// | +--------------------------------+
|
|
/// | | |
|
|
/// |----------- |
|
|
/// v
|
|
/// +--------------------------------+
|
|
/// | end: |
|
|
/// | <code after the AtomicRMWOp> |
|
|
/// +--------------------------------+
|
|
///
|
|
struct GenericAtomicRMWOpLowering
|
|
: public LoadStoreOpLowering<GenericAtomicRMWOp> {
|
|
using Base::Base;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(GenericAtomicRMWOp atomicOp, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
auto loc = atomicOp.getLoc();
|
|
GenericAtomicRMWOp::Adaptor adaptor(operands);
|
|
Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
|
|
|
|
// Split the block into initial, loop, and ending parts.
|
|
auto *initBlock = rewriter.getInsertionBlock();
|
|
auto *loopBlock =
|
|
rewriter.createBlock(initBlock->getParent(),
|
|
std::next(Region::iterator(initBlock)), valueType);
|
|
auto *endBlock = rewriter.createBlock(
|
|
loopBlock->getParent(), std::next(Region::iterator(loopBlock)));
|
|
|
|
// Operations range to be moved to `endBlock`.
|
|
auto opsToMoveStart = atomicOp->getIterator();
|
|
auto opsToMoveEnd = initBlock->back().getIterator();
|
|
|
|
// Compute the loaded value and branch to the loop block.
|
|
rewriter.setInsertionPointToEnd(initBlock);
|
|
auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
|
|
auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
|
|
adaptor.indices(), rewriter);
|
|
Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
|
|
rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
|
|
|
|
// Prepare the body of the loop block.
|
|
rewriter.setInsertionPointToStart(loopBlock);
|
|
|
|
// Clone the GenericAtomicRMWOp region and extract the result.
|
|
auto loopArgument = loopBlock->getArgument(0);
|
|
BlockAndValueMapping mapping;
|
|
mapping.map(atomicOp.getCurrentValue(), loopArgument);
|
|
Block &entryBlock = atomicOp.body().front();
|
|
for (auto &nestedOp : entryBlock.without_terminator()) {
|
|
Operation *clone = rewriter.clone(nestedOp, mapping);
|
|
mapping.map(nestedOp.getResults(), clone->getResults());
|
|
}
|
|
Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
|
|
|
|
// Prepare the epilog of the loop block.
|
|
// Append the cmpxchg op to the end of the loop block.
|
|
auto successOrdering = LLVM::AtomicOrdering::acq_rel;
|
|
auto failureOrdering = LLVM::AtomicOrdering::monotonic;
|
|
auto boolType = IntegerType::get(rewriter.getContext(), 1);
|
|
auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
|
|
{valueType, boolType});
|
|
auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
|
|
loc, pairType, dataPtr, loopArgument, result, successOrdering,
|
|
failureOrdering);
|
|
// Extract the %new_loaded and %ok values from the pair.
|
|
Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
|
|
loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
|
|
Value ok = rewriter.create<LLVM::ExtractValueOp>(
|
|
loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
|
|
|
|
// Conditionally branch to the end or back to the loop depending on %ok.
|
|
rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
|
|
loopBlock, newLoaded);
|
|
|
|
rewriter.setInsertionPointToEnd(endBlock);
|
|
moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
|
|
std::next(opsToMoveEnd), rewriter);
|
|
|
|
// The 'result' of the atomic_rmw op is the newly loaded value.
|
|
rewriter.replaceOp(atomicOp, {newLoaded});
|
|
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
// Clones a segment of ops [start, end) and erases the original.
|
|
void moveOpsRange(ValueRange oldResult, ValueRange newResult,
|
|
Block::iterator start, Block::iterator end,
|
|
ConversionPatternRewriter &rewriter) const {
|
|
BlockAndValueMapping mapping;
|
|
mapping.map(oldResult, newResult);
|
|
SmallVector<Operation *, 2> opsToErase;
|
|
for (auto it = start; it != end; ++it) {
|
|
rewriter.clone(*it, mapping);
|
|
opsToErase.push_back(&*it);
|
|
}
|
|
for (auto *it : opsToErase)
|
|
rewriter.eraseOp(it);
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Collect a set of patterns to convert from the Standard dialect to LLVM.
|
|
void mlir::populateStdToLLVMNonMemoryConversionPatterns(
|
|
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
|
// FIXME: this should be tablegen'ed
|
|
// clang-format off
|
|
patterns.add<
|
|
AbsFOpLowering,
|
|
AddFOpLowering,
|
|
AddIOpLowering,
|
|
AllocaOpLowering,
|
|
AllocaScopeOpLowering,
|
|
AndOpLowering,
|
|
AssertOpLowering,
|
|
AtomicRMWOpLowering,
|
|
BranchOpLowering,
|
|
CallIndirectOpLowering,
|
|
CallOpLowering,
|
|
CeilFOpLowering,
|
|
CmpFOpLowering,
|
|
CmpIOpLowering,
|
|
CondBranchOpLowering,
|
|
CopySignOpLowering,
|
|
CosOpLowering,
|
|
ConstantOpLowering,
|
|
DialectCastOpLowering,
|
|
DivFOpLowering,
|
|
ExpOpLowering,
|
|
Exp2OpLowering,
|
|
ExpM1OpLowering,
|
|
FloorFOpLowering,
|
|
FmaFOpLowering,
|
|
GenericAtomicRMWOpLowering,
|
|
LogOpLowering,
|
|
Log10OpLowering,
|
|
Log1pOpLowering,
|
|
Log2OpLowering,
|
|
FPExtOpLowering,
|
|
FPToSIOpLowering,
|
|
FPToUIOpLowering,
|
|
FPTruncOpLowering,
|
|
IndexCastOpLowering,
|
|
MulFOpLowering,
|
|
MulIOpLowering,
|
|
NegFOpLowering,
|
|
OrOpLowering,
|
|
PowFOpLowering,
|
|
PrefetchOpLowering,
|
|
RemFOpLowering,
|
|
ReturnOpLowering,
|
|
RsqrtOpLowering,
|
|
SIToFPOpLowering,
|
|
SelectOpLowering,
|
|
ShiftLeftOpLowering,
|
|
SignExtendIOpLowering,
|
|
SignedDivIOpLowering,
|
|
SignedRemIOpLowering,
|
|
SignedShiftRightOpLowering,
|
|
SinOpLowering,
|
|
SplatOpLowering,
|
|
SplatNdOpLowering,
|
|
SqrtOpLowering,
|
|
SubFOpLowering,
|
|
SubIOpLowering,
|
|
TruncateIOpLowering,
|
|
UIToFPOpLowering,
|
|
UnsignedDivIOpLowering,
|
|
UnsignedRemIOpLowering,
|
|
UnsignedShiftRightOpLowering,
|
|
XOrOpLowering,
|
|
ZeroExtendIOpLowering>(converter);
|
|
// clang-format on
|
|
}
|
|
|
|
void mlir::populateStdToLLVMMemoryConversionPatterns(
|
|
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
|
// clang-format off
|
|
patterns.add<
|
|
AssumeAlignmentOpLowering,
|
|
DimOpLowering,
|
|
GlobalMemrefOpLowering,
|
|
GetGlobalMemrefOpLowering,
|
|
LoadOpLowering,
|
|
MemRefCastOpLowering,
|
|
MemRefReinterpretCastOpLowering,
|
|
MemRefReshapeOpLowering,
|
|
RankOpLowering,
|
|
StoreOpLowering,
|
|
SubViewOpLowering,
|
|
TransposeOpLowering,
|
|
ViewOpLowering>(converter);
|
|
// clang-format on
|
|
auto allocLowering = converter.getOptions().allocLowering;
|
|
if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
|
|
patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
|
|
else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
|
|
patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
|
|
}
|
|
|
|
void mlir::populateStdToLLVMFuncOpConversionPattern(
|
|
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
|
if (converter.getOptions().useBarePtrCallConv)
|
|
patterns.add<BarePtrFuncOpConversion>(converter);
|
|
else
|
|
patterns.add<FuncOpConversion>(converter);
|
|
}
|
|
|
|
void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
|
RewritePatternSet &patterns) {
|
|
populateStdToLLVMFuncOpConversionPattern(converter, patterns);
|
|
populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
|
|
populateStdToLLVMMemoryConversionPatterns(converter, patterns);
|
|
}
|
|
|
|
/// Convert a non-empty list of types to be returned from a function into a
|
|
/// supported LLVM IR type. In particular, if more than one value is returned,
|
|
/// create an LLVM IR structure type with elements that correspond to each of
|
|
/// the MLIR types converted with `convertType`.
|
|
Type LLVMTypeConverter::packFunctionResults(TypeRange types) {
|
|
assert(!types.empty() && "expected non-empty list of type");
|
|
|
|
if (types.size() == 1)
|
|
return convertCallingConventionType(types.front());
|
|
|
|
SmallVector<Type, 8> resultTypes;
|
|
resultTypes.reserve(types.size());
|
|
for (auto t : types) {
|
|
auto converted = convertCallingConventionType(t);
|
|
if (!converted || !LLVM::isCompatibleType(converted))
|
|
return {};
|
|
resultTypes.push_back(converted);
|
|
}
|
|
|
|
return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
|
|
}

Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
                                                    OpBuilder &builder) {
  auto *context = builder.getContext();
  auto int64Ty = IntegerType::get(builder.getContext(), 64);
  auto indexType = IndexType::get(context);
  // Alloca with proper alignment. We do not expect optimizations of this
  // alloca op and so we omit allocating at the entry block.
  auto ptrType = LLVM::LLVMPointerType::get(operand.getType());
  Value one = builder.create<LLVM::ConstantOp>(loc, int64Ty,
                                               IntegerAttr::get(indexType, 1));
  Value allocated =
      builder.create<LLVM::AllocaOp>(loc, ptrType, one, /*alignment=*/0);
  // Store into the alloca'ed descriptor.
  builder.create<LLVM::StoreOp>(loc, operand, allocated);
  return allocated;
}
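
The helper above materializes a one-element `llvm.alloca` and an `llvm.store` of the descriptor, and returns the pointer to that slot. A hedged sketch of a call site (the wrapper name is hypothetical and the emitted-IR comment is approximate):

```
// Hypothetical wrapper, not part of this file, summarizing what the helper emits.
static mlir::Value storeDescriptorOnStack(mlir::LLVMTypeConverter &converter,
                                          mlir::Location loc,
                                          mlir::Value descriptor,
                                          mlir::OpBuilder &builder) {
  // Conceptually produces:
  //   %one  = llvm.mlir.constant(1 : index) : i64
  //   %slot = llvm.alloca %one x <descriptor struct type>
  //   llvm.store %descriptor, %slot
  // and returns %slot.
  return converter.promoteOneMemRefDescriptor(loc, descriptor, builder);
}
```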

SmallVector<Value, 4> LLVMTypeConverter::promoteOperands(Location loc,
                                                         ValueRange opOperands,
                                                         ValueRange operands,
                                                         OpBuilder &builder) {
  SmallVector<Value, 4> promotedOperands;
  promotedOperands.reserve(operands.size());
  for (auto it : llvm::zip(opOperands, operands)) {
    auto operand = std::get<0>(it);
    auto llvmOperand = std::get<1>(it);

    if (options.useBarePtrCallConv) {
      // For the bare-ptr calling convention, we only have to extract the
      // aligned pointer of a memref.
      if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
        MemRefDescriptor desc(llvmOperand);
        llvmOperand = desc.alignedPtr(builder, loc);
      } else if (operand.getType().isa<UnrankedMemRefType>()) {
        llvm_unreachable("Unranked memrefs are not supported");
      }
    } else {
      if (operand.getType().isa<UnrankedMemRefType>()) {
        UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
                                         promotedOperands);
        continue;
      }
      if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
        MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
                                 promotedOperands);
        continue;
      }
    }

    promotedOperands.push_back(llvmOperand);
  }
  return promotedOperands;
}
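
Note how the two calling conventions diverge here: by default each ranked memref descriptor is unpacked into its individual fields, while the bare-pointer convention forwards only the aligned pointer. A sketch of a call site (the helper is hypothetical, e.g. something a call-op lowering might do):

```
// Hypothetical helper, not part of this file: expand converted call operands.
static llvm::SmallVector<mlir::Value, 4>
expandCallOperands(mlir::LLVMTypeConverter &converter, mlir::Location loc,
                   mlir::ValueRange origOperands,
                   mlir::ValueRange llvmOperands, mlir::OpBuilder &builder) {
  // For a `memref<?xf32>` operand the default convention yields
  // (allocated ptr, aligned ptr, offset, size0, stride0); with
  // useBarePtrCallConv only the aligned pointer is forwarded.
  return converter.promoteOperands(loc, origOperands, llvmOperands, builder);
}
```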

namespace {
/// A pass converting MLIR operations into the LLVM IR dialect.
struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
  LLVMLoweringPass() = default;
  LLVMLoweringPass(bool useBarePtrCallConv, bool emitCWrappers,
                   unsigned indexBitwidth, bool useAlignedAlloc,
                   const llvm::DataLayout &dataLayout) {
    this->useBarePtrCallConv = useBarePtrCallConv;
    this->emitCWrappers = emitCWrappers;
    this->indexBitwidth = indexBitwidth;
    this->useAlignedAlloc = useAlignedAlloc;
    this->dataLayout = dataLayout.getStringRepresentation();
  }

  /// Run the dialect converter on the module.
  void runOnOperation() override {
    if (useBarePtrCallConv && emitCWrappers) {
      getOperation().emitError()
          << "incompatible conversion options: bare-pointer calling convention "
             "and C wrapper emission";
      signalPassFailure();
      return;
    }
    if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
            this->dataLayout, [this](const Twine &message) {
              getOperation().emitError() << message.str();
            }))) {
      signalPassFailure();
      return;
    }

    ModuleOp m = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();

    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(m));
    options.useBarePtrCallConv = useBarePtrCallConv;
    options.emitCWrappers = emitCWrappers;
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);
    options.dataLayout = llvm::DataLayout(this->dataLayout);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);

    RewritePatternSet patterns(&getContext());
    populateStdToLLVMConversionPatterns(typeConverter, patterns);

    LLVMConversionTarget target(getContext());
    if (failed(applyPartialConversion(m, target, std::move(patterns))))
      signalPassFailure();
    m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
               StringAttr::get(m.getContext(), this->dataLayout));
  }
};
} // end namespace

Value AllocLikeOpLLVMLowering::createAligned(
    ConversionPatternRewriter &rewriter, Location loc, Value input,
    Value alignment) {
  Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
  Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
  Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
  Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
  return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
}
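
`createAligned` is the usual round-up-to-a-multiple computation expressed as LLVM dialect ops: `aligned = (input + alignment - 1) - ((input + alignment - 1) % alignment)`. The same arithmetic in plain C++ for reference (illustrative only, assuming a non-zero alignment):

```
#include <cstdint>

// Illustrative scalar version of the arithmetic createAligned emits.
static uint64_t roundUpToAlignment(uint64_t input, uint64_t alignment) {
  uint64_t bumped = input + (alignment - 1);
  return bumped - bumped % alignment;
}
// e.g. roundUpToAlignment(18, 8) == 24 and roundUpToAlignment(16, 8) == 16.
```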

LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
    Operation *op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  MemRefType memRefType = getMemRefResultType(op);
  if (!isConvertibleAndHasIdentityMaps(memRefType))
    return rewriter.notifyMatchFailure(op, "incompatible memref type");
  auto loc = op->getLoc();

  // Get actual sizes of the memref as values: static sizes are constant
  // values and dynamic sizes are passed to 'alloc' as operands. In case of
  // zero-dimensional memref, assume a scalar (size 1).
  SmallVector<Value, 4> sizes;
  SmallVector<Value, 4> strides;
  Value sizeBytes;
  this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
                                 strides, sizeBytes);

  // Allocate the underlying buffer.
  Value allocatedPtr;
  Value alignedPtr;
  std::tie(allocatedPtr, alignedPtr) =
      this->allocateBuffer(rewriter, loc, sizeBytes, op);

  // Create the MemRef descriptor.
  auto memRefDescriptor = this->createMemRefDescriptor(
      loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);

  // Return the final value of the descriptor.
  rewriter.replaceOp(op, {memRefDescriptor});
  return success();
}

mlir::LLVMConversionTarget::LLVMConversionTarget(MLIRContext &ctx)
    : ConversionTarget(ctx) {
  this->addLegalDialect<LLVM::LLVMDialect>();
  this->addIllegalOp<LLVM::DialectCastOp>();
}

std::unique_ptr<OperationPass<ModuleOp>> mlir::createLowerToLLVMPass() {
  return std::make_unique<LLVMLoweringPass>();
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createLowerToLLVMPass(const LowerToLLVMOptions &options) {
  auto allocLowering = options.allocLowering;
  // There is no way to provide additional patterns for the pass, so
  // AllocLowering::None will always fail.
  assert(allocLowering != LowerToLLVMOptions::AllocLowering::None &&
         "LLVMLoweringPass doesn't support AllocLowering::None");
  bool useAlignedAlloc =
      (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc);
  return std::make_unique<LLVMLoweringPass>(
      options.useBarePtrCallConv, options.emitCWrappers,
      options.getIndexBitwidth(), useAlignedAlloc, options.dataLayout);
}
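
Putting the pieces together, a typical way to use the options-based pass factory looks like the following (the pipeline-builder function below is hypothetical, not part of this file):

```
// Hypothetical pipeline setup illustrating the options-based factory.
void buildStdToLLVMPipeline(mlir::PassManager &pm, mlir::MLIRContext *ctx) {
  mlir::LowerToLLVMOptions options(ctx);
  options.emitCWrappers = true;
  options.overrideIndexBitwidth(32);
  options.allocLowering = mlir::LowerToLLVMOptions::AllocLowering::AlignedAlloc;
  pm.addPass(mlir::createLowerToLLVMPass(options));
}
```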

mlir::LowerToLLVMOptions::LowerToLLVMOptions(MLIRContext *ctx)
    : LowerToLLVMOptions(ctx, DataLayout()) {}

mlir::LowerToLLVMOptions::LowerToLLVMOptions(MLIRContext *ctx,
                                             const DataLayout &dl) {
  indexBitwidth = dl.getTypeSizeInBits(IndexType::get(ctx));
}