Files
clang-p2996/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
Alex Zinenko 684dfe8adb [mlir] factor out ConvertToLLVMPattern
This class and classes that extend it are general utilities for any dialect
that is being converted into the LLVM dialect. They are in no way specific to
Standard-to-LLVM conversion and should not make their users depend on it.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D105542
2021-07-08 10:27:05 +02:00

143 lines
5.7 KiB
C++

//===- VectorPattern.cpp - Vector conversion pattern to the LLVM dialect --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
using namespace mlir;
// Collects, for a vector type of rank greater than one, the information needed
// to visit every 1-D subvector of its converted LLVM representation: the
// converted n-D type, the sizes of the nested LLVM array types (outermost
// first) and the innermost 1-D vector type.
//
// When the type converter produces no type, or a type that is not compatible
// with the LLVM dialect, `llvmNDVectorTy` is reset to null and the rest of the
// result is left untouched. When peeling the array types does not end on an
// LLVM-compatible vector type, `llvm1DVectorTy` is left unset.
LLVM::detail::NDVectorTypeInfo
LLVM::detail::extractNDVectorTypeInfo(VectorType vectorType,
                                      LLVMTypeConverter &converter) {
  assert(vectorType.getRank() > 1 && "expected >1D vector type");

  NDVectorTypeInfo info;
  info.llvmNDVectorTy = converter.convertType(vectorType);
  const bool isUsable =
      info.llvmNDVectorTy && LLVM::isCompatibleType(info.llvmNDVectorTy);
  if (!isUsable) {
    info.llvmNDVectorTy = nullptr;
    return info;
  }

  // One array level is expected per leading dimension of the vector.
  info.arraySizes.reserve(vectorType.getRank() - 1);

  // Peel array types one level at a time, recording each level's extent.
  auto innerTy = info.llvmNDVectorTy;
  while (auto arrayTy = innerTy.dyn_cast<LLVM::LLVMArrayType>()) {
    info.arraySizes.push_back(arrayTy.getNumElements());
    innerTy = arrayTy.getElementType();
  }

  // Only publish the innermost type if it really is an LLVM vector.
  if (LLVM::isCompatibleVectorType(innerTy))
    info.llvm1DVectorTy = innerTy;
  return info;
}
// Express `linearIndex` in terms of coordinates of `basis`.
//
// The returned coordinates are in the same order as `basis`; the last basis
// element is the fastest-varying one.
//
// Returns the empty vector when `linearIndex` is out of the range [0, P), where
// P is the product of all the basis elements. In particular, if any basis
// element is zero (P == 0) no index is in range and the empty vector is
// returned.
//
// Prerequisites:
//   Basis is an array of nonnegative integers (signed type inherited from
//   vector shape type). A negative element violates this contract and is also
//   answered with the empty vector.
SmallVector<int64_t, 4> LLVM::detail::getCoordinates(ArrayRef<int64_t> basis,
                                                     unsigned linearIndex) {
  SmallVector<int64_t, 4> res;
  res.reserve(basis.size());

  // Do the arithmetic in the basis' own 64-bit type so that large extents are
  // not silently truncated to `unsigned`.
  int64_t remaining = linearIndex;
  for (int64_t basisElement : llvm::reverse(basis)) {
    // A zero extent admits no valid coordinate; bail out instead of dividing
    // by zero.
    if (basisElement <= 0)
      return {};
    res.push_back(remaining % basisElement);
    remaining /= basisElement;
  }

  // Anything left over means the index exceeded the product of the basis.
  if (remaining > 0)
    return {};

  // Coordinates were produced innermost-first; restore basis order.
  std::reverse(res.begin(), res.end());
  return res;
}
// Iterates over the linear indices of the 1-D subvectors described by
// `info.arraySizes`, in increasing order. Each index is converted to
// coordinates, wrapped in an I64 array attribute created with `builder`, and
// handed to `fun`.
void LLVM::detail::nDVectorIterate(const LLVM::detail::NDVectorTypeInfo &info,
                                   OpBuilder &builder,
                                   function_ref<void(ArrayAttr)> fun) {
  // Total number of positions is the product of all array extents.
  unsigned numPositions = 1;
  for (auto extent : info.arraySizes)
    numPositions *= extent;

  for (unsigned index = 0; index < numPositions; ++index) {
    auto coordinates = getCoordinates(info.arraySizes, index);
    // An empty result signals an out-of-bounds linear index: stop iterating.
    if (coordinates.empty())
      break;
    assert(coordinates.size() == info.arraySizes.size());
    fun(builder.getI64ArrayAttr(coordinates));
  }
}
// Rewrites `op`, whose first result and operands are n-D vectors, by unrolling
// it over the 1-D subvectors of the converted LLVM representation.
//
// For every position of the result's nested array type, the matching 1-D
// subvector is extracted from each value in `operands`, `createOperand` is
// called with the 1-D result type and those extracted values, and the value it
// returns is inserted at that position into an aggregate that starts out as an
// `llvm.undef` of the n-D result type. `op` is finally replaced by that
// aggregate.
//
// Returns failure, without creating any IR, when the n-D or 1-D LLVM type of
// the result or the 1-D LLVM type of any operand of `op` could not be
// determined; returns success otherwise.
LogicalResult LLVM::detail::handleMultidimensionalVectors(
    Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
    std::function<Value(Type, ValueRange)> createOperand,
    ConversionPatternRewriter &rewriter) {
  auto resultNDVectorType = op->getResult(0).getType().cast<VectorType>();

  // Collect the 1-D vector type each operand decomposes into.
  SmallVector<Type> operand1DVectorTypes;
  for (Value operand : op->getOperands()) {
    auto operandNDVectorType = operand.getType().cast<VectorType>();
    auto operandTypeInfo =
        extractNDVectorTypeInfo(operandNDVectorType, typeConverter);
    // A null type here means the conversion failed; extracting a value of
    // null type below would produce invalid IR.
    if (!operandTypeInfo.llvm1DVectorTy)
      return failure();
    operand1DVectorTypes.push_back(operandTypeInfo.llvm1DVectorTy);
  }

  auto resultTypeInfo =
      extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
  auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
  auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy;
  // Same check for the result: both the aggregate type and the 1-D type are
  // needed before any operation is created.
  if (!resultNDVectoryTy || !result1DVectorTy)
    return failure();

  auto loc = op->getLoc();
  Value desc = rewriter.create<LLVM::UndefOp>(loc, resultNDVectoryTy);
  nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayAttr position) {
    // For this unrolled `position` corresponding to the `linearIndex`^th
    // element, extract operand vectors
    SmallVector<Value, 4> extractedOperands;
    for (auto operand : llvm::enumerate(operands)) {
      extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
          loc, operand1DVectorTypes[operand.index()], operand.value(),
          position));
    }
    // Build the 1-D operation and place its result into the aggregate.
    Value newVal = createOperand(result1DVectorTy, extractedOperands);
    desc = rewriter.create<LLVM::InsertValueOp>(loc, resultNDVectoryTy, desc,
                                                newVal, position);
  });
  rewriter.replaceOp(op, desc);
  return success();
}
// Rewrites `op` into the operation named `targetOp`.
//
// Fails if any value in `operands` does not have an LLVM-compatible type. If
// the first operand is not an LLVM array type, the rewrite is delegated to
// `oneToOneRewrite`. Otherwise the operation is unrolled through
// `handleMultidimensionalVectors`, creating one `targetOp` per 1-D subvector
// that carries the attributes of `op`.
LogicalResult LLVM::detail::vectorOneToOneRewrite(
    Operation *op, StringRef targetOp, ValueRange operands,
    LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
  assert(!operands.empty());

  // Cannot convert ops if their operands are not of LLVM type.
  for (Type operandTy : operands.getTypes())
    if (!isCompatibleType(operandTy))
      return failure();

  // The first operand decides between the direct and the unrolled rewrite.
  const bool isNDVector = operands[0].getType().isa<LLVM::LLVMArrayType>();
  if (!isNDVector)
    return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);

  // Builds a single `targetOp` on 1-D slices and yields its first result.
  auto build1DOp = [op, targetOp, &rewriter](Type llvm1DVectorTy,
                                             ValueRange slices) {
    OperationState state(op->getLoc(), targetOp);
    state.addTypes(llvm1DVectorTy);
    state.addOperands(slices);
    state.addAttributes(op->getAttrs());
    return rewriter.createOperation(state)->getResult(0);
  };
  return handleMultidimensionalVectors(op, operands, typeConverter, build1DOp,
                                       rewriter);
}